Re-evaluating colbert-ai Index Artifacts Between PyTorch Versions with Precision-Based torch.allclose Tolerances

In this post I re-run my index artifact comparisons using precision-based tolerances and investigate the root causes when torch.allclose returns False. This analysis also led to multiple deep dives that are linked as separate blog posts.
Background
I recently learned that it’s best practice to use different torch.allclose
tolerances based on the precision of the floating point value. As a reminder, torch.allclose
uses absolute and relative tolerances as follows:
∣input_i − other_i∣ ≤ atol + rtol × ∣other_i∣
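For example, here is the check applied to a couple of made-up values (a minimal sketch, with the per-element bound computed explicitly):

```python
import torch

# Made-up values for illustration: only the second element falls within its bound.
inp   = torch.tensor([1.00000, 2.00001])
other = torch.tensor([1.00002, 2.00000])

atol, rtol = 1e-8, 1e-5                      # PyTorch's defaults
bound = atol + rtol * other.abs()            # per-element threshold
diff  = (inp - other).abs()

print(diff <= bound)                                      # tensor([False,  True])
print(torch.allclose(inp, other, atol=atol, rtol=rtol))   # False: every element must pass
```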
bitsandbytes uses the following heuristic:
```python
if dtype == torch.float32:
    atol, rtol = 1e-6, 1e-5
elif dtype == torch.bfloat16:
    atol, rtol = 1e-3, 1e-2
else:  # float16
    atol, rtol = 1e-4, 1e-3
```
Full precision (float32) has the tightest tolerances, followed by half precision (float16) and then bfloat16. I've been using the default tolerances in all my torch.allclose calls, regardless of precision (atol=1e-08, rtol=1e-05). Compared with the bitsandbytes tolerances, these defaults are:
- float32: 100x smaller for atol and the same for rtol
- float16: 10_000x smaller for atol and 100x smaller for rtol
- bfloat16: 100_000x smaller for atol and 1000x smaller for rtol
As you can see, the bitsandbytes tolerances are much more forgiving for lower precision, which intuitively makes sense.
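To see the practical effect, here's a tiny synthetic example (made-up float16 values, not actual index artifacts): the elements differ by roughly one half-precision rounding step, which the default tolerances flag as a mismatch but the bitsandbytes float16 tolerances absorb.

```python
import torch

# Synthetic float16 values: b differs from a by about one rounding step per element.
a = torch.tensor([1.0,    1.0,    1.0], dtype=torch.float16)
b = torch.tensor([1.0005, 0.9995, 1.0], dtype=torch.float16)

print(torch.allclose(a, b, atol=1e-8, rtol=1e-5))  # False with the default tolerances
print(torch.allclose(a, b, atol=1e-4, rtol=1e-3))  # True with the bnb float16 tolerances
```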
The first question I’ll explore in this post: how does changing my torch.allclose
tolerances affect index artifact comparison? In other words, are there tensors between PyTorch versions whose difference is larger than atol + rtol × ∣other_i∣
when using bitsandbytes’ more forgiving tolerances?
The second question I’ll answer: when torch.allclose
fails, what is the root cause?
I use the full (69_199 documents) UKPLab/DAPR/ConditionalQA dataset in this exercise. In previous blog posts I used a 1000-document subset.
Comparing All Consecutive Versions
In this section I’ll document tensor shape mismatches and torch.allclose
values (with default tolerances in the “Default” column and bitsandbytes tolerances in the “bnb” column) for tensor index artifacts between consecutive PyTorch versions from 1.13.1 (the version pinned in the latest colbert-ai
release) to 2.8.0 (the latest PyTorch version available as of 9/14/2025).
PyTorch Version A | PyTorch Version B | All Shapes Match | Default | bnb |
---|---|---|---|---|
1.13.1 | 2.0.0 | Yes | True | True |
2.0.0 | 2.0.1 | Yes | True | True |
2.0.1 | 2.1.0 | No (11/12 Match) | False (0/12 Match) | False (2/12 Match) |
2.1.0 | 2.1.1 | Yes | True | True |
2.1.1 | 2.1.2 | Yes | True | True |
2.1.2 | 2.2.0 | Yes | True | True |
2.2.0 | 2.2.1 | Yes | True | True |
2.2.1 | 2.2.2 | Yes | True | True |
2.2.2 | 2.3.0 | Yes | True | True |
2.3.0 | 2.3.1 | Yes | True | True |
2.3.1 | 2.4.0 | Yes | True | True |
2.4.0 | 2.4.1 | Yes | True | True |
2.4.1 | 2.5.0 | No (11/12 Match) | False (0/12 Match) | False (2/12 Match) |
2.5.0 | 2.5.1 | Yes | True | True |
2.5.1 | 2.6.0 | Yes | True | True |
2.6.0 | 2.7.0 | Yes | True | True |
2.7.0 | 2.7.1 | Yes | True | True |
2.7.1 | 2.8.0 | Yes | False (8/12 Match) | False (9/12 Match) |
In the three version comparisons where torch.allclose
failed using default atol
and rtol
values, using bitsandbytes values yielded the same overall result (not all tensors match) but with two more matches for 2.0.1 –> 2.1.0 and 2.4.1 –> 2.5.0, and one more match for 2.7.1 –> 2.8.0.
Here's my _close function to handle comparisons between tensors a and b:

```python
def _close(a, b, default=False):
    gtype = a.dtype
    if gtype in [torch.uint8, torch.int32, torch.int64]:
        if a.shape == b.shape: return torch.equal(a, b)
        return False
    if not default:
        if gtype == torch.float32:
            atol, rtol = 1e-6, 1e-5
        elif gtype == torch.bfloat16:
            atol, rtol = 1e-3, 1e-2
        else:
            atol, rtol = 1e-4, 1e-3
    else:
        atol, rtol = 1e-8, 1e-5
    return torch.allclose(a, b, rtol=rtol, atol=atol)
```
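And here's roughly how I call it over a pair of index directories (the paths and artifact list below are illustrative rather than my exact script):

```python
import torch

root_a = "indexes/torch-2.0.1"   # illustrative paths
root_b = "indexes/torch-2.1.0"

for name in ["centroids.pt", "avg_residual.pt", "buckets.pt", "ivf.pid.pt"]:
    a, b = torch.load(f"{root_a}/{name}"), torch.load(f"{root_b}/{name}")
    # Some artifacts load as tuples of tensors (e.g. buckets.pt, ivf.pid.pt).
    pairs = list(zip(a, b)) if isinstance(a, tuple) else [(a, b)]
    for x, y in pairs:
        print(name, "default:", _close(x, y, default=True), "bnb:", _close(x, y))
```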
Root Cause For Index Artifact Difference Between Consecutive PyTorch Versions
There were three consecutive-version upgrades that broke index artifact reproducibility in colbert-ai. Listed below are the tensors that failed torch.equal (for integers) or torch.allclose (with bitsandbytes' tolerances):

- 2.0.1 –> 2.1.0
  - ivf.pid.pt (ivf: unique passage IDs (pids) per centroid ID, ivf_lengths: number of pids per centroid ID)
  - codes.pt (centroid ID mapped to doc token IDs)
  - residuals.pt (distance between centroids and doc token embeddings)
  - centroids.pt (centroids of the clustered sample doc token embeddings local_sample_embs)
  - bucket_cutoffs (the quantization bins)
- 2.4.1 –> 2.5.0
  - ivf.pid.pt (ivf and ivf_lengths)
  - codes.pt
  - residuals.pt
  - centroids.pt
  - bucket_cutoffs
- 2.7.1 –> 2.8.0
  - residuals.pt
In the following sections I’ll detail the root cause for index artifact divergence.
2.0.1 –> 2.1.0: BertModel Forward Pass for Any input_ids
The first critical intermediate indexing tensor created is local_sample_embs. This is a sample of document token embeddings used to calculate centroids. The sample passages are passed to Checkpoint.docFromText, which calls Checkpoint.doc, which ultimately calls Checkpoint.bert.
sample_pids, the sample of passage IDs selected for encoding, were identical between 2.0.1 and 2.1.0, but local_sample_embs did not pass torch.allclose (with bnb tolerances). This was the smell that led me to compare the Checkpoint.bert model layer outputs between PyTorch versions using register_forward_hook. I tried a variety of input tokens (different batches of passages, random text, single-letter strings) and in all cases, model layer outputs between PyTorch versions failed torch.allclose. I thus concluded that something in PyTorch changed between 2.0.1 and 2.1.0 to cause this. You can read more details of this exploration in another blog post.
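Here's a minimal sketch of that hook setup (the layer filter, helper name, and file naming are mine; checkpoint.bert and the tokenized inputs come from the indexing run):

```python
import torch

def capture_layer_outputs(model, input_ids, attention_mask, out_path):
    """Save the output of every BERT dense projection so two PyTorch installs
    can be compared offline."""
    captured, handles = {}, []

    def make_hook(name):
        def hook(module, inputs, output):
            captured[name] = output.detach().cpu()
        return hook

    for name, module in model.named_modules():
        if name.endswith("output.dense"):   # two dense projections per encoder layer
            handles.append(module.register_forward_hook(make_hook(name)))

    with torch.inference_mode():
        model(input_ids=input_ids, attention_mask=attention_mask)

    for h in handles:
        h.remove()
    torch.save(captured, out_path)

# e.g. capture_layer_outputs(checkpoint.bert, input_ids, attention_mask,
#                            f"layer_outputs_torch_{torch.__version__}.pt")
```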
To confirm that the local_sample_embs divergence caused the divergence in downstream index artifacts, I replaced the local_sample_embs in the torch==2.1.0 install with local_sample_embs from the torch==2.0.1 install, and the final index artifacts passed torch.allclose. Interestingly, even though all final index artifacts were similar, the intermediate codes.pt (centroid ID mapped to doc token IDs) was not. I did a deep dive in a separate blog post where I discovered that using Tensor.sort results in different sort indices in torch==2.0.1 and torch==2.1.0.
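The linked post has the full details, but as a toy illustration of how sort indices can legitimately differ: when a tensor contains ties, the order of the tied indices is implementation-defined unless you ask for a stable sort.

```python
import torch

codes = torch.tensor([3, 1, 3, 2, 1, 3])

values, indices = codes.sort()                   # tie order may vary across builds
values_s, indices_s = codes.sort(stable=True)    # ties keep their original order

print(indices)      # the positions of the tied 1s and 3s are not guaranteed
print(indices_s)    # deterministic: tensor([1, 4, 3, 0, 2, 5])
```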
2.4.1 –> 2.5.0: BertModel Forward Pass for Some Batch Sizes
I saw a similar result when changing the colbert-ai PyTorch version from 2.4.1 to 2.5.0: identical sample_pids, diverging local_sample_embs. In this case, however, not all input_ids caused a divergence between PyTorch versions. Specifically, inputs with the following batch sizes resulted in model layer outputs passing torch.allclose: 71, 72, 70, 73, 68, 66, 115, 64, 63, 62, 61, 67, and 69. The following batch sizes failed torch.allclose: 79, 78, 77, 194, 82, 80, 90, 86, and 83. I concluded that something in PyTorch changed between 2.4.1 and 2.5.0 which made the BertModel forward pass output vary with batch size (more details in this blog post). Interestingly, it was at this time that I read the excellent Thinking Machines blog post about LLM non-determinism.
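A sketch of the kind of per-batch-size comparison involved (the helper and file naming are mine; in the actual indexing run this happens under mixed precision on GPU):

```python
import torch

def save_outputs_per_batch_size(model, input_ids, attention_mask, batch_sizes, tag):
    """Run the same tokens through the model at several batch sizes and save the
    outputs, then compare the saved files across PyTorch installs."""
    for bsize in batch_sizes:
        with torch.inference_mode():
            out = model(input_ids=input_ids[:bsize],
                        attention_mask=attention_mask[:bsize]).last_hidden_state
        torch.save(out.cpu(), f"bert_out_bsize{bsize}_{tag}.pt")

# e.g. save_outputs_per_batch_size(checkpoint.bert, input_ids, attention_mask,
#                                  [61, 71, 79, 83], f"torch{torch.__version__}")
```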
2.7.1 –> 2.8.0: Difference in torch.nn.functional.normalize Output
When comparing the 2.7.1 and 2.8.0 index artifacts, all artifacts but residuals.pt passed torch.allclose with bnb tolerances. residuals.pt are the difference between the document token embeddings and the centroids:

```python
residuals_ = batch - centroids_
```
batch not only passes torch.allclose between PyTorch versions, it also passes torch.equal, whereas centroids_ only passes torch.allclose. Looking deeper at how centroids_ are calculated: they are normalized and then stored in half precision. The pre-norm centroids pass torch.equal between PyTorch versions but the post-norm centroids do not. Additionally, testing this on random values, the pre-norm tensors are equal between PyTorch versions but the post-norm tensors are not. I thus concluded that the difference in index artifacts is caused by a difference in how PyTorch handles torch.nn.functional.normalize between versions 2.7.1 and 2.8.0. You can read more details on this in another blog post.
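A sketch of the standalone check (illustrative: fixed-seed random values, normalized and cast to half the way the centroids are stored, saved per install for offline comparison):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(42)
x = torch.randn(1024, 128)                 # pre-norm values
x_norm = F.normalize(x, dim=-1).half()     # post-norm, half precision like the centroids

# Save both so you can confirm the pre-norm tensors are identical (torch.equal)
# across installs before attributing any difference to normalize itself.
torch.save({"pre": x, "post": x_norm}, f"normalize_check_torch{torch.__version__}.pt")
```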
Conclusion
In all three cases, when changing PyTorch versions, colbert-ai
indexing functionality does not break, but reproducibility does. To recap the root causes:
- torch==2.0.1 –> torch==2.1.0: BertModel forward pass outputs diverge for any inputs + Tensor.sort index order changes.
- torch==2.4.1 –> torch==2.5.0: BertModel forward pass outputs diverge depending on batch size.
- torch==2.7.1 –> torch==2.8.0: torch.nn.functional.normalize outputs diverge.
I don’t think these root causes can be addressed in the colbert-ai
codebase as they seem to be purely PyTorch changes. However, I’m documenting them here (and will link this blog post in the next colbert-ai
release notes) as users will experience index artifact changes when using different PyTorch versions.
Next up: comparing and documenting search and training artifacts across PyTorch versions.
Appendix
In this section I'll detail final and intermediate index tensor artifact comparisons between PyTorch versions where torch.allclose was False using default tolerances. I'll also document integer tensor artifacts separately, using torch.equal for tensors (which, embarrassingly, I had been comparing with torch.allclose until now 🤦) and == for non-tensors.
torch==2.0.1 vs torch==2.1.0
Final Index Artifacts
Using the more lenient bitsandbytes tolerances, avg_residual.pt
and bucket_weights.pt
pass torch.allclose
while bucket_cutoffs
and centroids
do not.
Integer Tensors
Artifact | Description | dtype | torch.equal |
---|---|---|---|
codes.pt | centroid id mapped to doc token embeddings | torch.int32 | False |
residuals.pt | difference between centroid and doc token embeddings | torch.uint8 | False |
ivf.pid.pt (ivf) | unique pids per centroid id | torch.int32 | shape mismatch |
ivf.pid.pt (ivf_lengths) | number of pids per centroid id | torch.int64 | False |
Float Tensors
Artifact | Description | dtype | Default | bnb |
---|---|---|---|---|
avg_residual.pt | Average difference between centroids and doc token embeddings | torch.float16 | False | True |
buckets.pt (bucket_cutoffs) | The quantization bins | torch.float32 | False | False |
buckets.pt (bucket_weights) | The quantization values for each bin | torch.float16 | False | True |
centroids.pt | Centroids of clustered sample doc token embeddings | torch.float16 | False | False |
Intermediate Index Artifacts
“Intermediate” artifacts are tensors saved in the middle of the indexing pipeline by adding torch.save calls in /colbert/indexing/collection_indexer.py or /colbert/modeling/checkpoint.py.
Integer Tensors
Some of the intermediate artifacts are not tensors, so I've titled the equality column “Equal” instead of torch.equal.
Artifact | Description | dtype | Equal |
---|---|---|---|
sample_pids.pt | A sample of passage ids used to calculate centroids | int | True |
num_passages.pt | Number of sampled passages | int | True |
doclens.pt | List of number of tokens per document | int | True |
Float Tensors
Using the more lenient bitsandbytes tolerances, none of the torch.allclose
calls pass.
Artifact | Description | dtype | Default | bnb |
---|---|---|---|---|
local_sample_embs.pt | Embeddings of sample document passages used to calculate centroids | torch.float16 | False | False |
sample.pt | 95% of the values from local_sample_embs.half() | torch.float16 | False | False |
sample_heldout.pt | 5% of the values from local_sample_embs.half() | torch.float16 | False | False |
batches.pt | 1 batch of encoded passages | torch.float16 | False | False |
D.pt | sorted and reshaped batches | torch.float16 | False | False |
Core Difference: BertModel Forward Pass
Swapping the local_sample_embs.pt and embs_{chunk_idx}.pt tensors in the torch==2.1.0 ColBERT install with the ones generated in the torch==2.0.1 install resolves all final index artifact discrepancies, even when using default tolerances. This led me to uncover that the core difference between 2.0.1 and 2.1.0 is the BertModel forward pass. The intermediate and final BertModel layer outputs all fail torch.allclose (for both sets of tolerances), no matter what the input tokens are (I tried different batch sizes and also a single letter "a").
“Swapping” means loading the tensor right before it’s saved:
```python
if SWAP == 'True': local_sample_embs = torch.load(f"{SWAP_ROOT}/local_sample_embs.pt") # ADDED BY VISHAL
torch.save(local_sample_embs, f"{ROOT}/local_sample_embs.pt") # ADDED BY VISHAL
torch.save(local_sample_embs.half(), os.path.join(self.config.index_path_, f'sample.{self.rank}.pt'))

if SWAP == 'True': embs = torch.load(f"{SWAP_ROOT}/embs_{chunk_idx}.pt") # ADDED BY VISHAL
torch.save(embs, f"{ROOT}/embs_{chunk_idx}.pt") # ADDED BY VISHAL
torch.save(doclens, f"{ROOT}/doclens.pt") # ADDED BY VISHAL
self.saver.save_chunk(chunk_idx, offset, embs, doclens) # offset = first passage index in chunk
```
Peculiar Finding: Different Intermediate codes Artifact Yields Identical Final ivf.pid.pt Artifact
Even after swapping local_sample_embs.pt and embs, the intermediate codes (not shown in the table above) did not pass torch.allclose between PyTorch versions (even with the more lenient bitsandbytes tolerances).
torch==2.4.1 vs torch==2.5.0
Final Index Artifacts
Integer Tensors
Artifact | Description | dtype | torch.equal |
---|---|---|---|
codes.pt | centroid id mapped to doc token embeddings | torch.int32 | False |
residuals.pt | difference between centroid and doc token embeddings | torch.uint8 | False |
ivf.pid.pt (ivf) | unique pids per centroid id | torch.int32 | shape mismatch |
ivf.pid.pt (ivf_lengths) | number of pids per centroid id | torch.int64 | False |
Float Tensors
With bnb tolerances, avg_residual.pt
and bucket_weights
pass torch.allclose
between PyTorch versions.
Artifact | Description | dtype | Default | bnb |
---|---|---|---|---|
avg_residual.pt | Average difference between centroids and doc token embeddings | torch.float16 | False | True |
buckets.pt (bucket_cutoffs) | The quantization bins | torch.float32 | False | False |
buckets.pt (bucket_weights) | The quantization values for each bin | torch.float16 | False | True |
centroids.pt | Centroids of clustered sample doc token embeddings | torch.float16 | False | False |
Intermediate Index Artifacts
Integer Tensors
Artifact | Description | dtype | Equal |
---|---|---|---|
sample_pids.pt | A sample of passage ids used to calculate centroids | int | True |
num_passages.pt | Number of sampled passages | int | True |
doclens.pt | List of number of tokens per document | int | True |
Float Tensors
Artifact | Description | dtype | Default | bnb |
---|---|---|---|---|
local_sample_embs.pt | Embeddings of sample document passages used to calculate centroids | torch.float16 | False | False |
sample.pt | 95% of the values from local_sample_embs.half() | torch.float16 | False | False |
sample_heldout.pt | 5% of the values from local_sample_embs.half() | torch.float16 | False | False |
batches.pt | 1 batch of encoded passages | torch.float16 | True | True |
D.pt | sorted and reshaped batches | torch.float16 | True | True |
batches.pt did not pass torch.allclose for a 1000-document subset, as the final batch had 8 items.
Core Difference: Something in BertModel
Swapping the local_sample_embs.pt and embs_{chunk_idx}.pt tensors in the torch==2.5.0 ColBERT install with the ones generated in the torch==2.4.1 install resolves all final and intermediate index artifact discrepancies, even when using the smaller default tolerances. However, it's unclear what is causing the divergence in the BertModel.
When sampling and embedding just the first 1000 passages (with checkpoint.bert), the BertModel intermediate dense layer outputs different tensors between PyTorch versions 2.4.1 and 2.5.0 when using mixed precision (for small batch sizes); this divergence also seems to be related to the number of tokens. When embedding the full dataset (69_199 passages), the third batch of 1600 passages caused a divergence in BertModel layer outputs.
“Swapping” means loading the tensor right before it’s saved:
```python
if SWAP == 'True': local_sample_embs = torch.load(f"{SWAP_ROOT}/local_sample_embs.pt") # ADDED BY VISHAL
torch.save(local_sample_embs, f"{ROOT}/local_sample_embs.pt") # ADDED BY VISHAL
torch.save(local_sample_embs.half(), os.path.join(self.config.index_path_, f'sample.{self.rank}.pt'))

if SWAP == 'True': embs = torch.load(f"{SWAP_ROOT}/embs_{chunk_idx}.pt") # ADDED BY VISHAL
torch.save(embs, f"{ROOT}/embs_{chunk_idx}.pt") # ADDED BY VISHAL
torch.save(doclens, f"{ROOT}/doclens.pt") # ADDED BY VISHAL
self.saver.save_chunk(chunk_idx, offset, embs, doclens) # offset = first passage index in chunk
```
torch==2.7.1 vs torch==2.8.0
Using the more lenient bitsandbytes tolerances, ALL torch.allclose calls pass. It's interesting to note that while centroids.pt (floats) passes torch.allclose, residuals.pt (integers) is not equal across PyTorch versions, presumably because small differences in the normalized centroids shift some quantized residual values into neighboring buckets.
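As a toy illustration of why floats that are close can still yield unequal integer artifacts (the cutoffs below are made up, not the learned ones): a residual value sitting near a bucket boundary can land in a different bucket after a tiny shift.

```python
import torch

bucket_cutoffs = torch.tensor([-0.02, 0.0, 0.02])   # made-up quantization bins
r1 = torch.tensor([0.019999])
r2 = torch.tensor([0.020001])                        # differs from r1 by 2e-6

print(torch.bucketize(r1, bucket_cutoffs))   # tensor([2])
print(torch.bucketize(r2, bucket_cutoffs))   # tensor([3])
```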
Final Index Artifacts
Integer Tensors
Artifact | Description | dtype | torch.equal |
---|---|---|---|
codes.pt | centroid id mapped to doc token embeddings | torch.int32 | True |
residuals.pt | difference between centroid and doc token embeddings | torch.uint8 | False |
ivf.pid.pt (ivf) | unique pids per centroid id | torch.int32 | True |
ivf.pid.pt (ivf_lengths) | number of pids per centroid id | torch.int64 | True |
Float Tensors
Artifact | Description | dtype | Default | bnb |
---|---|---|---|---|
avg_residual.pt | Average difference between centroids and doc token embeddings | torch.float16 | True | True |
buckets.pt (bucket_cutoffs) | The quantization bins | torch.float32 | True | True |
buckets.pt (bucket_weights) | The quantization values for each bin | torch.float16 | True | True |
centroids.pt | Centroids of clustered sample doc token embeddings | torch.float16 | False | True |
When using default tolerances, the normalized half-precision centroids cause the floating-point error.
Intermediate Index Artifacts
All of my intermediate index artifacts pass torch.allclose
regardless of which tolerances are used.
Integer Tensors
Artifact | Description | dtype | Equal |
---|---|---|---|
sample_pids.pt | A sample of passage ids used to calculate centroids | int | True |
num_passages.pt | Number of sampled passages | int | True |
doclens.pt | List of number of tokens per document | int | True |
Float Tensors
Artifact | Description | dtype | Default | bnb |
---|---|---|---|---|
local_sample_embs.pt | Embeddings of sample document passages used to calculate centroids | torch.float16 | True | True |
sample.pt | 95% of the values from local_sample_embs.half() | torch.float16 | True | True |
sample_heldout.pt | 5% of the values from local_sample_embs.half() | torch.float16 | True | True |
batches.pt | 1 batch of encoded passages | torch.float16 | True | True |
D.pt | sorted and reshaped batches | torch.float16 | True | True |