Debugging ColBERT Index Differences Between PyTorch 2.7.1 and 2.8.0
While comparing `colbert-ai` index artifacts generated with PyTorch 2.7.1 and 2.8.0, I found that `residuals.pt` fails `torch.allclose` (at bitsandbytes' `torch.allclose` tolerances). Through systematic debugging of intermediate tensors, I traced the root cause to precision changes in PyTorch’s vector normalization implementation.
Background
I’ve been redoing my `colbert-ai` index comparisons between PyTorch versions using bitsandbytes’ `torch.allclose` tolerances. There are three PyTorch version changes that cause index artifact changes: 2.0.1 to 2.1.0 (`BertModel` forward pass outputs diverge for all inputs), 2.4.1 to 2.5.0 (certain batch sizes cause `BertModel` output divergence), and 2.7.1 to 2.8.0 (detailed in this blog post).
Difference Between PyTorch Versions: `residuals.pt`
When using bitsandbytes’ `torch.allclose` tolerances, all final index artifacts pass `torch.allclose` except `residuals.pt`. Residuals are a key component in the indexing pipeline: they are the difference between document token embeddings and their nearest centroids. From residual.py’s `ResidualCodec.compress`:
```python
def compress(self, embs, chunk_idx): # chunk_idx ADDED BY VISHAL
    codes, residuals = [], []

    for batch in embs.split(1 << 18):
        if self.use_gpu:
            batch = batch.cuda().half()
        codes_ = self.compress_into_codes(batch, out_device=batch.device)
        centroids_ = self.lookup_centroids(codes_, out_device=batch.device)

        residuals_ = (batch - centroids_)
        torch.save(residuals_, f"{ROOT}/residuals__{chunk_idx}.pt") # ADDED BY VISHAL
        torch.save(codes_, f"{ROOT}/codes__{chunk_idx}.pt") # ADDED BY VISHAL
        torch.save(batch, f"{ROOT}/batch_{chunk_idx}.pt") # ADDED BY VISHAL
        torch.save(centroids_, f"{ROOT}/centroids__{chunk_idx}.pt") # ADDED BY VISHAL

        codes.append(codes_.cpu())
        residuals.append(self.binarize(residuals_).cpu())

    codes = torch.cat(codes)
    torch.save(codes, f"{ROOT}/compress_codes_{chunk_idx}.pt") # ADDED BY VISHAL
    residuals = torch.cat(residuals)

    return ResidualCodec.Embeddings(codes, residuals)
```
The key line is:
```python
residuals_ = (batch - centroids_)
```
As you can see above, I have added `torch.save` calls to compare those intermediate index artifacts between PyTorch versions.
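For each saved artifact, the comparison itself is a simple load-and-compare. A minimal sketch, with hypothetical directory names and placeholder `rtol`/`atol` values standing in for the bitsandbytes tolerances:

```python
import torch

# Load the same intermediate artifact saved under each PyTorch install.
# Directory names here are hypothetical, not my exact setup.
a = torch.load("artifacts/torch-2.7.1/residuals__0.pt")
b = torch.load("artifacts/torch-2.8.0/residuals__0.pt")

print(torch.equal(a, b))  # bitwise equality
print(torch.allclose(a.float(), b.float(), rtol=1e-5, atol=1e-8))  # tolerance-based
```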
I figured that since `residuals_` does not pass `torch.allclose` between torch versions, `batch` and `centroids_` must not either. I was wrong! `batch` not only passes `torch.allclose` but also passes `torch.equal` between torch versions. `centroids_` passes `torch.allclose` but not `torch.equal`. Even though the `centroids_` values are within floating-point tolerance (`torch.allclose` passes), the small differences get amplified during the subtraction operation that creates `residuals_`. This amplification pushes the final result outside the tolerance bounds, causing `residuals_` to fail `torch.allclose`.
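The mechanism is worth spelling out: `torch.allclose(a, b)` checks `|a - b| <= atol + rtol * |b|`. Subtraction leaves the absolute gap between versions unchanged, but it shrinks `|b|`, so the relative-tolerance budget collapses. A toy sketch with invented values (using PyTorch’s default tolerances rather than the bitsandbytes ones):

```python
import torch

# Invented values for illustration: the two "centroid" values differ by
# ~4e-6, which sits inside the default rtol=1e-5 band around |0.5|.
batch = torch.tensor([0.5])
centroids_a = torch.tensor([0.4999990])  # e.g. produced by torch 2.7.1
centroids_b = torch.tensor([0.4999950])  # e.g. produced by torch 2.8.0
print(torch.allclose(centroids_a, centroids_b))  # True

# The residuals are ~1e-6 and ~5e-6: the absolute gap is still ~4e-6, but
# rtol * |b| is now ~5e-11, so the same gap fails the check.
print(torch.allclose(batch - centroids_a, batch - centroids_b))  # False
```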
Difference Between PyTorch Versions: `centroids.pt`
This raises the question: why are the `centroids_` not exactly equal between PyTorch versions? In other words, why doesn’t `centroids_` pass `torch.equal` like `batch` does? To figure this out, I added `torch.save` calls to `_train_kmeans`, where the centroids are created:
```python
def _train_kmeans(self, sample, shared_lists):
    centroids = compute_faiss_kmeans(*args_)
    torch.save(centroids, f"{ROOT}/prenorm_centroids.pt") # ADDED BY VISHAL
    centroids = torch.nn.functional.normalize(centroids, dim=-1)
    if POSTNORM_CENTROIDS_SWAP == "True": centroids = torch.load(f"{POSTNORM_CENTROIDS_SWAP_ROOT}/postnorm_centroids.pt") # ADDED BY VISHAL
    torch.save(centroids, f"{ROOT}/postnorm_centroids.pt") # ADDED BY VISHAL
    if self.use_gpu:
        centroids = centroids.half()
        torch.save(centroids, f"{ROOT}/half_centroids.pt") # ADDED BY VISHAL
    else:
        centroids = centroids.float()

    return centroids
```
There are three versions of the centroids I save: `prenorm_centroids.pt` (the output of `compute_faiss_kmeans`), `postnorm_centroids.pt` (the output of `torch.nn.functional.normalize(centroids, dim=-1)`) and `half_centroids.pt` (the output of `centroids.half()`).

I compare each tensor (created with `torch==2.7.1` and `torch==2.8.0`) with both `torch.allclose` and `torch.equal`:
| Tensor | `torch.allclose` | `torch.equal` |
|---|---|---|
| `prenorm_centroids.pt` | True | True |
| `postnorm_centroids.pt` | True | False |
| `half_centroids.pt` | True | False |
The pre-norm centroids are exactly the same between PyTorch versions, but the post-norm centroids are not. To confirm that the normalization operation is the source of the divergence, I replace the 2.8.0 `postnorm_centroids.pt` with the 2.7.1 ones (the `if POSTNORM_CENTROIDS_SWAP == "True"` line in the code above), after which all final and intermediate index artifacts (including `residuals.pt`) pass `torch.allclose` between PyTorch versions.
To confirm that there exists a difference in normalization between PyTorch versions 2.7.1 and 2.8.0, I generate the following tensors with each install:
```python
torch.manual_seed(13)
t = torch.empty(1024, 96).uniform_(-0.4, 0.4)
torch.save(t, f"{MOUNT}/{project}/{date}-{source}-{nranks}/t.pt")
torch.save(t.half(), f"{MOUNT}/{project}/{date}-{source}-{nranks}/half_t.pt")

t = torch.nn.functional.normalize(t, dim=-1)
torch.save(t, f"{MOUNT}/{project}/{date}-{source}-{nranks}/norm.pt")
torch.save(t.half(), f"{MOUNT}/{project}/{date}-{source}-{nranks}/half_norm.pt")
```
Comparing the four tensors (`t.pt`, `half_t.pt`, `norm.pt` and `half_norm.pt`) between PyTorch versions:
| Tensor | `torch.allclose` | `torch.equal` |
|---|---|---|
| `t.pt` | True | True |
| `half_t.pt` | True | True |
| `norm.pt` | True | False |
| `half_norm.pt` | True | False |
While all tensors pass `torch.allclose` (at bnb tolerances), the normalized tensors (full precision and half precision) fail `torch.equal` between PyTorch versions. When used in further operations (as `centroids_` are when calculating `residuals_ = batch - centroids_`), this inequality compounds and amplifies floating-point differences enough to fail `torch.allclose` for the residuals.
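To see just how small the normalization divergence is, one option is to measure it in units of float32 spacing (ULPs). A sketch, with hypothetical filenames for the `norm.pt` tensors saved under each install:

```python
import torch

# Hypothetical filenames: the same norm.pt saved under each PyTorch install.
a = torch.load("norm_271.pt")  # saved under torch 2.7.1
b = torch.load("norm_280.pt")  # saved under torch 2.8.0

# Spacing between a and the next representable float32 value at each element.
ulp = torch.nextafter(a, torch.full_like(a, float("inf"))) - a

# A max ratio near 1 would mean the versions disagree only in the last bit.
print(((a - b).abs() / ulp).max())
```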
Closing Thoughts
When working with floating point values, it’s easy to dismiss minor differences. A recent Thinking Machines blog post communicated this sentiment:

> What’s wrong with bumping up the atol/rtol on the failing unit test?
As I’ve been exploring `colbert-ai` index artifact differences across PyTorch versions, it’s been tempting to consider that “fix”. However, by caring about failed `torch.allclose` or `torch.equal` comparisons, I’ve learned a lot about how small differences impact index artifacts downstream, and I’ve gained a better understanding of how changes in PyTorch can affect `colbert-ai`. While I may not cover all such impacts, I hope that documenting them here will help some engineer somewhere who is debugging why their RAG pipeline behaves subtly differently after bumping PyTorch versions.