PyTorch Version Impact on ColBERT Index Artifacts
Background
I recently released colbert-ai==0.2.22, which, among other changes, removed the deprecated transformers.AdamW import. I'm now turning my attention to upgrading the PyTorch dependency to 2.x, which will not only make colbert-ai compatible with modern torch installations but will also allow integrating the AnswerAI fastkmeans library as a replacement for the faiss-gpu and faiss-cpu libraries (which are no longer officially maintained on PyPI).
I started this PyTorch 2.x upgrade effort by analyzing the impact of torch==2.0.0 on colbert-ai, as this was the first upgrade from the existing torch==1.13.1 dependency. I approached that analysis by reviewing the 500+ PRs included in torch==2.0.0 and documenting whether each could impact colbert-ai. The resulting spreadsheet and blog post detail my findings; in short, I estimated that 28 PRs potentially impact colbert-ai.
In this blog post I'm detailing a different approach, from the "other end" so to speak: what changes in colbert-ai index artifacts when changing PyTorch versions?
Indexing ConditionalQA with 19 PyTorch Versions
I started by indexing the UKPLab/DAPR ConditionalQA document collection with 19 different colbert-ai installs (one for each version of PyTorch from 1.13.1 to 2.8.0), using Modal. Each Dockerfile looks something like this:
FROM mambaorg/micromamba:latest
USER root
RUN apt-get update && apt-get install -y git nano curl wget build-essential && apt-get clean && rm -rf /var/lib/apt/lists/*
RUN git clone https://github.com/stanford-futuredata/ColBERT.git /ColBERT && \
cd /ColBERT && \
micromamba create -n colbert python=3.11 cuda -c nvidia/label/11.7.1 -c conda-forge && \
micromamba install -n colbert faiss-gpu -c pytorch -c conda-forge && \
micromamba run -n colbert pip install -e . && \
micromamba run -n colbert pip install torch==2.2.0 transformers==4.38.2 pandas
ENV CONDA_DEFAULT_ENV=colbert
ENV PATH=/opt/conda/envs/colbert/bin:$PATH
WORKDIR /ColBERT
RUN echo "eval \"\$(micromamba shell hook --shell bash)\"" >> ~/.bashrc && \
echo "micromamba activate colbert" >> ~/.bashrc
CMD ["/bin/bash"]
I decided to git clone and pip install -e . the main branch of stanford-futuredata/ColBERT since I wanted to modify its files down the road to save/inject intermediate index artifacts (as we'll see later on in this blog post).
My indexing function looks like:
@app.function(gpu=GPU, image=image, timeout=3600,
              volumes={MOUNT: VOLUME},
              max_containers=1)
def _index(source, project, date, nranks, ndocs, root):
    import os
    import subprocess
    subprocess.run(['pwd'], text=True, shell=True)

    from colbert import Indexer
    from colbert.infra import RunConfig, ColBERTConfig
    from colbert.infra.run import Run
    from datasets import load_dataset

    os.environ["ROOT"] = root

    dataset_name = "ConditionalQA"
    passages = load_dataset("UKPLab/dapr", f"{dataset_name}-corpus", split="test")
    queries = load_dataset("UKPLab/dapr", f"{dataset_name}-queries", split="test")
    qrels_rows = load_dataset("UKPLab/dapr", f"{dataset_name}-qrels", split="test")

    with Run().context(RunConfig(nranks=nranks)):
        config = ColBERTConfig(
            doc_maxlen=256,
            nbits=4,
            dim=96,
            kmeans_niters=20,
            index_bsize=32,
            bsize=64,
            checkpoint="answerdotai/answerai-colbert-small-v1"
        )
        indexer = Indexer(checkpoint="answerdotai/answerai-colbert-small-v1", config=config)
        _ = indexer.index(name=f"{MOUNT}/{project}/{date}-{source}-{nranks}/indexing/{dataset_name}", collection=passages[:ndocs]["text"], overwrite=True)

    print("Index created!")
I would run the indexing function (in my main.py file) using a terminal command like so:
SOURCE="0.2.22.main.torch.1.13.1" DATE="20250818" PROJECT="torch2.x" NRANKS=1 GPU="L4" modal run main.py
Comparing Index Artifacts Across PyTorch Versions
Once indexed, I ran my comparison script which starts by comparing index file names:
print("\n[bold blue]INDEX FILE NAME COMPARISON[/bold blue]")
console.= os.listdir(a_path)
a = os.listdir(b_path)
b
try:
for i, f in enumerate(a): assert f == b[i]
print(f"[green]✓ All {len(a)} files match[/green]")
console.except:
print("[red]✗ File names don't match[/red]") console.
Then index tensor shapes:
for i, f in enumerate(a_pts):
    console.print(f"\n[bold]{f}[/bold]")
    a_pt = torch.load(a_path + f)
    b_pt = torch.load(b_path + f)

    if isinstance(a_pt, tuple):
        match1 = a_pt[0].shape == b_pt[0].shape
        match2 = a_pt[1].shape == b_pt[1].shape
        console.print(f" Tensor[0]: [{'green' if match1 else 'red'}]{a_pt[0].shape} vs {b_pt[0].shape}[/{'green' if match1 else 'red'}]")
        console.print(f" Tensor[1]: [{'green' if match2 else 'red'}]{a_pt[1].shape} vs {b_pt[1].shape}[/{'green' if match2 else 'red'}]")
        if not (match1 and match2):
            shape_mismatches += 1
    else:
        match = a_pt.shape == b_pt.shape
        console.print(f" Shape: [{'green' if match else 'red'}]{a_pt.shape} vs {b_pt.shape}[/{'green' if match else 'red'}]")
        if not match:
            shape_mismatches += 1
and finally compare tensor values between indexes:
for i, f in enumerate(a_pts):
    console.print(f"\n[bold]{f}[/bold]")
    a_pt = torch.load(a_path + f)
    b_pt = torch.load(b_path + f)

    if isinstance(a_pt, tuple):
        if a_pt[0].shape == b_pt[0].shape:
            match1 = torch.allclose(a_pt[0], b_pt[0])
            console.print(f" [{'green' if match1 else 'red'}]{'✓' if match1 else '✗'} Tensor[0] values {'match' if match1 else 'differ'}[/{'green' if match1 else 'red'}]")
        else:
            console.print(" [red]✗ Tensor[0] shape mismatch[/red]")
            match1 = False

        if a_pt[1].shape == b_pt[1].shape:
            match2 = torch.allclose(a_pt[1], b_pt[1])
            console.print(f" [{'green' if match2 else 'red'}]{'✓' if match2 else '✗'} Tensor[1] values {'match' if match2 else 'differ'}[/{'green' if match2 else 'red'}]")
        else:
            console.print(" [red]✗ Tensor[1] shape mismatch[/red]")
            match2 = False

        if not (match1 and match2):
            value_mismatches += 1
    else:
        if a_pt.shape == b_pt.shape:
            match = torch.allclose(a_pt, b_pt)
            console.print(f" [{'green' if match else 'red'}]{'✓' if match else '✗'} Values {'match' if match else 'differ'}[/{'green' if match else 'red'}]")
        else:
            console.print(" [red]✗ Shape mismatch[/red]")
            match = False

        if not match:
            value_mismatches += 1
I compared consecutive pairs of PyTorch-version colbert-ai installs to understand between which versions the index artifacts change. Here are my results:
| Version A | Version B | All .pt Shapes Match? (Matches) | All .pt Values Match? (Matches) |
|---|---|---|---|
| 1.13.1 | 2.0.0 | Yes (10/10) | Yes (10/10) |
| 2.0.0 | 2.0.1 | Yes (10/10) | Yes (10/10) |
| 2.0.1 | 2.1.0 | No (9/10) | No (0/10) |
| 2.1.0 | 2.1.1 | Yes (10/10) | Yes (10/10) |
| 2.1.1 | 2.1.2 | Yes (10/10) | Yes (10/10) |
| 2.1.2 | 2.2.0 | Yes (10/10) | Yes (10/10) |
| 2.2.0 | 2.2.1 | Yes (10/10) | Yes (10/10) |
| 2.2.1 | 2.2.2 | Yes (10/10) | Yes (10/10) |
| 2.2.2 | 2.3.0 | Yes (10/10) | Yes (10/10) |
| 2.3.0 | 2.3.1 | Yes (10/10) | Yes (10/10) |
| 2.3.1 | 2.4.0 | Yes (10/10) | Yes (10/10) |
| 2.4.0 | 2.4.1 | Yes (10/10) | Yes (10/10) |
| 2.4.1 | 2.5.0 | No (9/10) | No (0/10) |
| 2.5.0 | 2.5.1 | Yes (10/10) | Yes (10/10) |
| 2.5.1 | 2.6.0 | Yes (10/10) | Yes (10/10) |
| 2.6.0 | 2.7.0 | Yes (10/10) | Yes (10/10) |
| 2.7.0 | 2.7.1 | Yes (10/10) | Yes (10/10) |
| 2.7.1 | 2.8.0 | Yes (10/10) | No (6/10) |
There are three PyTorch upgrades that cause a change in index artifacts: 2.0.1 –> 2.1.0, 2.4.1 –> 2.5.0, and 2.7.1 –> 2.8.0.
Comparing Intermediate Index Artifacts
To better understand exactly where the index artifacts changed when upgrading PyTorch, I created my own copies of two stanford-futuredata/ColBERT files and added torch.save lines to save the intermediate artifacts listed below (see the sketch after this list):
- colbert/indexing/collection_indexer.py
  - sampled_pids (a set of integers corresponding to sampled passage IDs)
  - num_passages (a single integer, the total number of passages in the collection)
  - local_sample_embs (BERT encodings of the passages for the sampled pids, created by Checkpoint.docFromText)
  - centroids (from _train_kmeans)
  - bucket_cutoffs (the bin “boundaries” used for quantization, from _compute_avg_residual)
  - bucket_weights (the quantized values, from _compute_avg_residual)
  - avg_residual (a single float, from _compute_avg_residual)
  - sample (95% of the values from local_sample_embs.half())
  - sample_heldout (5% of the values from local_sample_embs.half())
  - embs (encoded passages)
  - doclens (number of tokens in each passage)
  - codes (centroid IDs (values) and document token IDs (indices))
  - ivf (document token IDs)
  - values (centroid IDs)
- colbert/modeling/checkpoint.py
  - tensorize_output (the (text_batches, reverse_indices) tuple output by DocTokenizer.tensorize)
  - batches (BERT encodings, output from Checkpoint.doc)
  - D (sorted and reshaped batches)
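The instrumentation itself was just a torch.save dropped in next to where each value is computed. A sketch of what those lines look like in my copy of collection_indexer.py (the save directory is hypothetical; in my runs the files were written to the mounted Modal volume):

```python
import torch

# Hypothetical directory on the mounted Modal volume where intermediates go
SAVE_DIR = "/colbert-maintenance/torch2.x/20250818-0.2.22.main.torch.2.1.0-1"

# Each line is placed right after the corresponding value is computed
torch.save(sampled_pids, f"{SAVE_DIR}/sampled_pids.pt")
torch.save(local_sample_embs, f"{SAVE_DIR}/local_sample_embs.pt")
torch.save(centroids, f"{SAVE_DIR}/centroids.pt")
torch.save(bucket_cutoffs, f"{SAVE_DIR}/bucket_cutoffs.pt")
torch.save(bucket_weights, f"{SAVE_DIR}/bucket_weights.pt")
torch.save(avg_residual, f"{SAVE_DIR}/avg_residual.pt")
```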
I then replaced the corresponding files in the /ColBERT directory (which is why I used git clone and pip install -e .) with the following lines for Modal:
image = image.add_local_file("collection_indexer.py", "/ColBERT/colbert/indexing/collection_indexer.py")
image = image.add_local_file("checkpoint.py", "/ColBERT/colbert/modeling/checkpoint.py")
Here are the results when comparing these artifacts between colbert-ai installs using torch==1.13.1 and torch==2.1.0:
| Artifact | torch.allclose |
|---|---|
| sampled_pids | True |
| num_passages | True |
| local_sample_embs | False |
| centroids | False |
| bucket_cutoffs | False |
| bucket_weights | False |
| avg_residual | False |
| sample | False |
| sample_heldout | False |
| embs | False |
| doclens | True |
| codes | False |
| ivf | False |
| values | False |
| tensorize_output | True |
| batches | False |
| D | False |
After reviewing these comparisons, my hypothesis was that the first difference (in local_sample_embs) propagated to all subsequent artifacts. The difference in local_sample_embs can itself be traced back to the differences in batches and D. To test this hypothesis, I "injected" the local_sample_embs from the torch==1.13.1 install into collection_indexer.py when indexing with torch==2.1.0:
local_sample_embs = torch.load("/colbert-maintenance/torch2.x/20250818-0.2.22.main.torch.1.13.1-1k-1/local_sample_embs.pt")
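For context, that line sits in my modified copy of collection_indexer.py, immediately after the sample embeddings are computed, so that everything downstream starts from the torch==1.13.1 values. Roughly (the surrounding line follows the upstream CollectionIndexer and may differ slightly between versions):

```python
# ... inside my modified collection_indexer.py (sketch of the surrounding code) ...
local_sample_embs, doclens = self.encoder.encode_passages(local_sample)

# Overwrite the freshly computed sample embeddings with the ones saved from the
# torch==1.13.1 run, so all downstream artifacts start from identical inputs.
local_sample_embs = torch.load(
    "/colbert-maintenance/torch2.x/20250818-0.2.22.main.torch.1.13.1-1k-1/local_sample_embs.pt"
)
```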
I then re-compared the artifacts, and my hypothesis was correct: every artifact derived from local_sample_embs now matches. (embs, codes, ivf, and values are computed from the encodings of the full collection rather than from local_sample_embs, so they still differ for the same underlying reason, as we'll see below.)
| Artifact | torch.allclose |
|---|---|
| centroids | True |
| bucket_cutoffs | True |
| bucket_weights | True |
| avg_residual | True |
| sample | True |
| sample_heldout | True |
| embs | False |
| doclens | True |
| codes | False |
| ivf | False |
| values | False |
Comparing the BertModels
Where do local_sample_embs come from? The highest-level method is CollectionEncoder.encode_passages. Inside CollectionEncoder.encode_passages, the collection of texts, passages, is fed to Checkpoint.docFromText. In there, the tokenized text is passed to Checkpoint.doc, which passes it to ColBERT.doc, which finally passes the input_ids and attention_mask to ColBERT.bert. Since there was a divergence in local_sample_embs, I figured there would be a divergence in the weights and/or the outputs of ColBERT.bert between the two PyTorch version installs.
Using each colbert-ai image, I separately saved the BertModel weights as well as a dictionary with the outputs of each of the 10 BertEncoder layers. These outputs were captured using forward hooks:
outputs_dict = {}

def capture_output(name):
    def hook_fn(module, input, output):
        outputs_dict[name] = output[0].detach()
    return hook_fn

hooks = []
for i in range(10):
    hooks.append(checkpoint.bert.encoder.layer[i].register_forward_hook(capture_output(f"1.13.1_{i}")))

with torch.no_grad():
    D = checkpoint.bert(input_ids, attention_mask=attention_mask)[0]

for h in hooks: h.remove()
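The snippet above assumes checkpoint, input_ids, and attention_mask are already in scope. A minimal sketch of that setup, assuming the same answerai-colbert-small-v1 checkpoint and ColBERT's own document tokenizer (exact constructor arguments may vary by version):

```python
import torch
from colbert.infra import ColBERTConfig
from colbert.modeling.checkpoint import Checkpoint

# Load the same checkpoint used for indexing
config = ColBERTConfig(doc_maxlen=256, checkpoint="answerdotai/answerai-colbert-small-v1")
checkpoint = Checkpoint("answerdotai/answerai-colbert-small-v1", colbert_config=config).cuda().eval()

# Tokenize a handful of passages with the DocTokenizer (the same tensorize output
# saved as tensorize_output earlier), then grab the first batch's tensors
texts = ["passage one ...", "passage two ..."]  # placeholder passages
text_batches, reverse_indices = checkpoint.doc_tokenizer.tensorize(texts, bsize=len(texts))
input_ids, attention_mask = text_batches[0]
input_ids, attention_mask = input_ids.cuda(), attention_mask.cuda()
```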
Both colbert-ai installs (torch==1.13.1 and torch==2.1.0) have identical BertModel weights. Their BertEncoder outputs, however, diverge.
Here are the mean absolute differences between corresponding BertEncoder layer outputs for torch==1.13.1 and torch==2.1.0:
for i in range(10):
    a_ = a[f"1.13.1_{i}"]
    b_ = b[f"2.1_{i}"]
    print(i, torch.abs(a_ - b_).float().mean())
0 tensor(2.8141e-08, device='cuda:0')
1 tensor(5.9652e-08, device='cuda:0')
2 tensor(8.0172e-08, device='cuda:0')
3 tensor(7.8228e-08, device='cuda:0')
4 tensor(7.9968e-08, device='cuda:0')
5 tensor(8.3589e-08, device='cuda:0')
6 tensor(8.7348e-08, device='cuda:0')
7 tensor(8.5140e-08, device='cuda:0')
8 tensor(8.5651e-08, device='cuda:0')
9 tensor(8.1636e-08, device='cuda:0')
The mean difference grows roughly 3x from the first layer to the last.
Here are the max absolute differences, which also grow roughly 3x by the final layer:
0 tensor(3.5763e-07, device='cuda:0')
1 tensor(4.7684e-07, device='cuda:0')
2 tensor(5.9605e-07, device='cuda:0')
3 tensor(5.9605e-07, device='cuda:0')
4 tensor(7.1526e-07, device='cuda:0')
5 tensor(7.1526e-07, device='cuda:0')
6 tensor(7.1526e-07, device='cuda:0')
7 tensor(9.5367e-07, device='cuda:0')
8 tensor(9.5367e-07, device='cuda:0')
9 tensor(1.1921e-06, device='cuda:0')
Closing Thoughts
From this analysis, I can conclude that the difference in index artifacts generated by colbert-ai with torch==1.13.1 vs. torch==2.1.0 is due to floating point differences in the forward pass of the BertModel used to generate token-level embeddings from text passages. I have not yet analyzed the torch==2.1.0 release notes to make an educated guess about why these differences occur. But given that they arise during the forward pass of the model, I would wager there was some update to the underlying C++ code behind the torch.nn module.
I will move forward with comparing intermediate artifacts for the other version pairs where the final index artifacts differ: 2.4.1 –> 2.5.0 and 2.7.1 –> 2.8.0. Once that's complete, I'll dive into the PyTorch release notes and see if I can reasonably point to a few PRs behind these changes. Once I have a reasonable handle on colbert-ai indexing behavior across PyTorch 2.x versions, I'll perform a similar analysis with colbert-ai training and document my findings.
Thanks for reading until the end! I'll be posting more blog post and/or video updates around ColBERT maintenance as soon as I have something more to share.