import torch
from colbert.indexing.loaders import load_doclens
from colbert.utils.utils import print_message, flatten
PyTorch .sort Behavior Changes from Version 2.0.1 to 2.1.0
While comparing colbert-ai index artifacts (for two installs using different PyTorch versions) I came across an unexpected finding: .sort indices are ordered differently in torch==2.1.0 than in torch==2.0.1, and thus an intermediate artifact changes even with all else equal. Thankfully, this doesn't break ColBERT's indexing functionality or final index artifacts.
Background
In this notebook I’m going to explore how (and hopefully why) you can start with different codes.indices
but end up with the same ivf
and ivf_lengths
when indexing a document collection using colbert-ai
.
I came across this behavior by accident. I was comparing final and intermediate colbert-ai
index artifacts between installs using torch==2.0.1
and torch==2.1.0
and found that even after swapping local_sample_embs.pt
(the document token embeddings clustered to find centroids) and embs_{chunk_idx}.pt
(the full set of document token embeddings) from 2.0.1 to 2.1.0, the intermediate codes.indices
(the document token IDs, ordered by centroid) did not pass torch.equal
but the final ivf.pid.pt
tensors did. How could that be possible? How can you start with different intermediate centroid-to-document token mappings and end up with the same final centroid-to-passage mappings? Furthermore, how can you end up with different codes.indices
when you're processing the same embs
?
First a bit of review of where codes
comes from. The highest-level abstraction we start with is the Indexer
:
with Run().context(RunConfig(nranks=nranks)):
    config = ColBERTConfig(...)
    indexer = Indexer(checkpoint="answerdotai/answerai-colbert-small-v1", config=config)
    _ = indexer.index(name="...", collection=collection)
Inside Indexer
, encode
is called within which CollectionIndexer
is instantiated:
def encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    encoder = CollectionIndexer(config=config, collection=collection, verbose=verbose)
    encoder.run(shared_lists)
Inside CollectionIndexer.index
the following line saves (i.e. compresses and stores the residuals of) document token embeddings (the input embs
are manually forced to be identical between the colbert-ai installs using different PyTorch versions):
self.saver.save_chunk(chunk_idx, offset, embs, doclens) # offset = first passage index in chunk
Once saved, the embeddings are deleted, which is why colbert-ai
is so memory efficient! It’s also why indexing and embedding are tied together with the same model.
IndexSaver.save_chunk
is defined as:
def save_chunk(self, chunk_idx, offset, embs, doclens):
    compressed_embs = self.codec.compress(embs)

    self.saver_queue.put((chunk_idx, offset, compressed_embs, doclens))
The codec
is a ResidualCodec
object and its compress
method contains the following line:
codes_ = self.compress_into_codes(batch, out_device=batch.device)
We’re almost there! compress_into_codes
is defined as:
def compress_into_codes(self, embs, out_device):
    codes = []

    bsize = (1 << 29) // self.centroids.size(0)
    for batch in embs.split(bsize):
        if self.use_gpu:
            indices = (self.centroids @ batch.T.cuda().half()).max(dim=0).indices.to(device=out_device)
        else:
            indices = (self.centroids @ batch.T.cpu().float()).max(dim=0).indices.to(device=out_device)
        codes.append(indices)

    return torch.cat(codes)
So, codes
are the indices (i.e. "IDs") of the centroids with the maximum cosine similarity to the document token embeddings.
Let’s say our centroids
have shape (1024, 96)
and the batch
contains thirty-two 96-dimensional embeddings (shape (32, 96)
), each corresponding to a different document token embedding. The transpose of batch
has shape (96, 32)
and the matrix multiplication centroids @ batch.T
has shape (1024, 32)
where the rows represent centroid indices and the columns represent token indices. Taking .max(dim=0).indices
returns the row indices corresponding to the maximum value in each of the 32 columns. In other words, the 32 centroid IDs that are closest to the document token embeddings. Note that since centroids
are normalized as are document token embeddings, the matrix multiplication is the cosine similarity between the two sets of vectors.
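To make that concrete, here is a minimal sketch of the argmax-over-cosine-similarity step, using made-up random tensors of those shapes (this is not code from colbert-ai):
import torch

torch.manual_seed(0)
centroids = torch.nn.functional.normalize(torch.randn(1024, 96), dim=-1)  # 1024 unit-norm centroids
batch = torch.nn.functional.normalize(torch.randn(32, 96), dim=-1)        # 32 unit-norm token embeddings

sims = centroids @ batch.T           # (1024, 32) cosine similarities
codes = sims.max(dim=0).indices      # (32,) nearest centroid ID for each token embedding
codes.shape
torch.Size([32])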
Which goes back to my question: how can different codes
yield the same final ivf
and ivf_lengths
? And why are codes
different to begin with?
To set the stage, let’s look at how ivf
and ivf_lengths
are created, starting with CollectionIndexer._build_ivf
, a trimmed-down version of which is:
def _build_ivf(self):
    codes = codes.sort()
    ivf, values = codes.indices, codes.values
    ivf_lengths = torch.bincount(values, minlength=self.num_partitions)

    _, _ = optimize_ivf(ivf, ivf_lengths, self.config.index_path_)
The codes
are first sorted. The sorted indices (the document token IDs) are assigned as ivf
and the values (the centroid IDs) after being bincount
-ed (i.e. the frequency of each centroid ID—the number of tokens associated with each centroid ID) are assigned as ivf_lengths
. These are the first iteration of ivf
and ivf_lengths
and will change later on in optimize_ivf
, the trimmed down version of which is:
def optimize_ivf(orig_ivf, orig_ivf_lengths, index_path, verbose: int = 3):
    all_doclens = load_doclens(index_path, flatten=False)
    all_doclens = flatten(all_doclens)
    total_num_embeddings = sum(all_doclens)

    emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)

    offset_doclens = 0
    for pid, dlength in enumerate(all_doclens):
        emb2pid[offset_doclens: offset_doclens + dlength] = pid
        offset_doclens += dlength

    ivf = emb2pid[orig_ivf]
    unique_pids_per_centroid = []
    ivf_lengths = []

    offset = 0
    for length in tqdm.tqdm(orig_ivf_lengths.tolist()):
        pids = torch.unique(ivf[offset:offset+length])
        unique_pids_per_centroid.append(pids)
        ivf_lengths.append(pids.shape[0])
        offset += length
    ivf = torch.cat(unique_pids_per_centroid)
    ivf_lengths = torch.tensor(ivf_lengths)

    original_ivf_path = os.path.join(index_path, 'ivf.pt')
    optimized_ivf_path = os.path.join(index_path, 'ivf.pid.pt')

    torch.save((ivf, ivf_lengths), optimized_ivf_path)

    return ivf, ivf_lengths
We’ll actually start from the bottom:
ivf = torch.cat(unique_pids_per_centroid)
ivf_lengths = torch.tensor(ivf_lengths)
ivf
is a flattened tensor of pids (unique passage IDs per centroid). Looking at the loop right above this:
offset = 0
for length in tqdm.tqdm(orig_ivf_lengths.tolist()):
    pids = torch.unique(ivf[offset:offset+length])
    unique_pids_per_centroid.append(pids)
    ivf_lengths.append(pids.shape[0])
    offset += length
ivf_lengths
is the flattened tensor of the number of pids per centroid.
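To make the bookkeeping concrete, here is a tiny toy walkthrough (made-up numbers, not taken from the real index) of how codes become ivf and ivf_lengths:
import torch

# toy collection: 8 token embeddings spread over 3 passages with doclens [3, 3, 2]
codes = torch.tensor([2, 0, 1, 1, 0, 2, 0, 1])                 # centroid ID per token
doclens = [3, 3, 2]

codes = codes.sort()
orig_ivf = codes.indices                                       # token IDs, grouped by centroid
orig_ivf_lengths = torch.bincount(codes.values, minlength=3)   # tokens per centroid

# emb2pid maps each token ID to the passage it came from
emb2pid = torch.zeros(sum(doclens), dtype=torch.int)
offset_doclens = 0
for pid, dlength in enumerate(doclens):
    emb2pid[offset_doclens: offset_doclens + dlength] = pid
    offset_doclens += dlength

ivf = emb2pid[orig_ivf]                                        # passage ID per grouped token
unique_pids_per_centroid, ivf_lengths, offset = [], [], 0
for length in orig_ivf_lengths.tolist():
    pids = torch.unique(ivf[offset:offset + length])
    unique_pids_per_centroid.append(pids)
    ivf_lengths.append(pids.shape[0])
    offset += length

torch.cat(unique_pids_per_centroid), torch.tensor(ivf_lengths)  # final ivf, ivf_lengths
(tensor([0, 1, 2, 0, 1, 2, 0, 1], dtype=torch.int32), tensor([3, 3, 2]))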
So again: how can we start with different codes
(a list of centroid IDs, where the indices are the document token embedding IDs) and end up with the same ivf
(unique pids corresponding to centroids) and ivf_lengths
(number of pids per centroid)?
I fed this background section to Sonnet 4 (with the stanford-futuredata/ColBERT repo attached as Project Knowledge) to fact check me and it said:
The key insight is that the final IVF only cares about which passages are associated with each centroid, not which specific token embeddings within those passages. If the different codes still result in the same set of passages being associated with each centroid (even if individual token assignments differ), the final ivf and ivf_lengths would be identical
TBD if that’s correct, certainly seems plausible!
Inspecting codes
First, I’ll show that codes.indices
(document token IDs) are not equal between my torch==2.0.1
install and the torch==2.1.0
install where I swapped its local_sample_embs.pt
and embs_{chunk_idx}.pt
with 2.0.1’s tensors. In other words, I “forced” the 2.1.0
install to cluster the same sample of document token embeddings when calculating centroids and then forced it to use the same document token embeddings to compress as residuals and centroid IDs.
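Before that, as a sanity check, one could verify that the swapped inputs really are identical across installs. This is just a sketch: the exact location of local_sample_embs.pt within each run directory is an assumption on my part.
# hypothetical path: wherever local_sample_embs.pt was saved for each run
sample_a = torch.load("20250909-0.2.22.main.torch.2.0.1-1/local_sample_embs.pt")
sample_b = torch.load("20250909-0.2.22.main.torch.2.1.0-swap-1/local_sample_embs.pt")
assert torch.equal(sample_a, sample_b)  # identical by construction, since these were swapped in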
codes_indices_a = torch.load("20250909-0.2.22.main.torch.2.0.1-1/ivf.pt")
codes_indices_b = torch.load("20250909-0.2.22.main.torch.2.1.0-swap-1/ivf.pt")
codes_values_a = torch.load("20250909-0.2.22.main.torch.2.0.1-1/values.pt")
codes_values_b = torch.load("20250909-0.2.22.main.torch.2.1.0-swap-1/values.pt")
torch.equal(codes_indices_a, codes_indices_b), torch.equal(codes_values_a, codes_values_b)
(False, True)
codes_values_a
tensor([ 0, 0, 0, ..., 16383, 16383, 16383])
codes_values_b
tensor([ 0, 0, 0, ..., 16383, 16383, 16383])
Note that codes.values
(the centroid IDs) are identical. So which centroid IDs are closest to the document token embeddings stays consistent across versions, but which document token IDs they correspond to does not.
ivf_a, ivf_lengths_a = torch.load("20250909-0.2.22.main.torch.2.0.1-1/indexing/ConditionalQA/ivf.pid.pt")
ivf_b, ivf_lengths_b = torch.load("20250909-0.2.22.main.torch.2.1.0-swap-1/indexing/ConditionalQA/ivf.pid.pt")
torch.equal(ivf_a, ivf_b), torch.equal(ivf_lengths_a, ivf_lengths_b)
(True, True)
Furthermore, the final unique passage IDs for each centroid (ivf
) and the number of passage IDs per centroid ID (ivf_lengths
) are equal across versions. What this tells me (re: Sonnet’s hypothesis) is that the document token IDs, while dissimilar across versions, come from the same passages!
Recreating optimize_ivf
To explore the relationship between document token IDs and passage IDs, I’ll use the code in optimize_ivf
, where initially, ivf
means codes.indices
and ivf_lengths
means torch.bincount(codes.values)
.
codes_values_a = torch.bincount(codes_values_a, minlength=16384)
codes_values_b = torch.bincount(codes_values_b, minlength=16384)
codes_values_a, codes_values_b
(tensor([1110, 36, 173, ..., 104, 95, 25]),
tensor([1110, 36, 173, ..., 104, 95, 25]))
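As a small aside (my own consistency check, not part of the original analysis), these per-centroid counts should sum to the collection's total number of token embeddings, which we compute from doclens below:
# both sums should equal total_num_embeddings, computed from doclens in the next cell
int(codes_values_a.sum()), int(codes_values_b.sum())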
I’ll start by loading the mapping between passages and tokens: doclens
.
all_doclens_a = load_doclens("20250909-0.2.22.main.torch.2.0.1-1/indexing/ConditionalQA/", flatten=False)
all_doclens_a = flatten(all_doclens_a)
total_num_embeddings_a = sum(all_doclens_a)

all_doclens_b = load_doclens("20250909-0.2.22.main.torch.2.1.0-swap-1/indexing/ConditionalQA", flatten=False)
all_doclens_b = flatten(all_doclens_b)
total_num_embeddings_b = sum(all_doclens_b)
all_doclens_a == all_doclens_b
True
total_num_embeddings_a == total_num_embeddings_b
True
total_num_embeddings_b
1146937
Next we create emb2pid
which is a tensor with 1146937 entries (one for each token embedding across the collection) whose values are passage IDs.
def _emb2pid(total_num_embeddings, all_doclens):
    emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)
    offset_doclens = 0
    for pid, dlength in enumerate(all_doclens):
        emb2pid[offset_doclens: offset_doclens + dlength] = pid
        offset_doclens += dlength
    return emb2pid
emb2pid_a = _emb2pid(total_num_embeddings_a, all_doclens_a)
emb2pid_b = _emb2pid(total_num_embeddings_b, all_doclens_b)
emb2pid_a
tensor([ 0, 0, 0, ..., 69198, 69198, 69198], dtype=torch.int32)
emb2pid_b
tensor([ 0, 0, 0, ..., 69198, 69198, 69198], dtype=torch.int32)
torch.equal(emb2pid_a, emb2pid_b)
True
The first three tokens we see correspond to passage ID 0
, and the last three tokens to passage ID 69198
.
Let’s now see if the tokens in the two codes_indices
come from the same passages.
codes_indices_a
tensor([377624, 285309, 285322, ..., 117986, 118780, 128088])
codes_indices_b
tensor([ 2776, 2808, 5974, ..., 309906, 579450, 884128])
pids_a = emb2pid_a[codes_indices_a]
pids_b = emb2pid_b[codes_indices_b]
pids_a
tensor([23120, 17145, 17145, ..., 7128, 7172, 7691], dtype=torch.int32)
pids_b
tensor([ 170, 170, 377, ..., 18739, 35561, 53527], dtype=torch.int32)
Looking at the resulting passage IDs: the first two tokens of pids_a
(torch==2.0.1
) come from passages 23120
and 17145
, respectively. The first two tokens of pids_b
(torch==2.1.0
swapped) come from passage 170
.
If we count the number of times each passage ID occurs in each tensor (pids_a
or pids_b
), the counts are identical! This is the first hint that Sonnet's hypothesis holds.
torch.equal(torch.bincount(pids_a), torch.bincount(pids_b))
True
Let’s keep moving along in recreating optimize_ivf
:
ivf = emb2pid[orig_ivf]
unique_pids_per_centroid = []
ivf_lengths = []

offset = 0
for length in tqdm.tqdm(orig_ivf_lengths.tolist()):
    pids = torch.unique(ivf[offset:offset+length])
    unique_pids_per_centroid.append(pids)
    ivf_lengths.append(pids.shape[0])
    offset += length
ivf = torch.cat(unique_pids_per_centroid)
ivf_lengths = torch.tensor(ivf_lengths)
Instead of:
ivf = emb2pid[orig_ivf]
I did:
pids_a = emb2pid_a[codes_indices_a]
I’ll move onto the for loop:
def _loop(orig_ivf_lengths, ivf):
    unique_pids_per_centroid = []
    ivf_lengths = []
    offset = 0
    for length in orig_ivf_lengths.tolist():
        pids = torch.unique(ivf[offset:offset+length])
        unique_pids_per_centroid.append(pids)
        ivf_lengths.append(pids.shape[0])
        offset += length
    return unique_pids_per_centroid, ivf_lengths
unique_pids_per_centroid_a, _ivf_lengths_a = _loop(orig_ivf_lengths=codes_values_a, ivf=codes_indices_a)
unique_pids_per_centroid_b, _ivf_lengths_b = _loop(orig_ivf_lengths=codes_values_b, ivf=codes_indices_b)
for idx, item in enumerate(unique_pids_per_centroid_a):
    assert torch.equal(item, unique_pids_per_centroid_b[idx])
And there we see it! While the order of IDs within each centroid's slice differs between versions, both codes.indices tensors yield the same unique IDs per centroid. The key reason is that the for-loop uses torch.unique, which returns its values sorted in ascending order. So as long as the set of IDs in each ivf[offset:offset+length] slice is identical across PyTorch versions, torch.unique produces the same output even if the slice was ordered differently. And because the emb2pid mapping is identical across versions, the same holds for the unique passage IDs per centroid, which is exactly what the final ivf and ivf_lengths are built from.
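A minimal illustration of that torch.unique property, using made-up values:
x = torch.tensor([9, 2, 5])   # one ordering of a slice of IDs
y = torch.tensor([5, 9, 2])   # same elements, ordered differently
torch.equal(torch.unique(x), torch.unique(y))
True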
Conclusion
Let’s revisit Sonnet 4’s hypothesis:
The key insight is that the final IVF only cares about which passages are associated with each centroid, not which specific token embeddings within those passages
While true, the reality was a bit different: the specific token IDs are identical across torch versions; they are just sorted differently! However, this raises the question: why are the token IDs sorted differently across torch versions? I'll explore that in the Appendix section below.
Appendix
To understand how the max cosine similarity calculation deviates between torch==2.0.1
and torch==2.1.0
(using 2.0.1
’s local_sample_embs.pt
) I’ll start by comparing the embs
that I torch.save
-d right before they were compressed. This is more of a sanity check as these were explicitly swapped from torch==2.0.1
.
for f in ["embs_0.pt", "embs_1.pt", "embs_2.pt"]:
    a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/{f}")
    b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/{f}")
    assert torch.allclose(a, b, atol=1e-4, rtol=1e-3)
They are all close enough! Next, I’ll compare the single batch and centroids that I saved in the ResidualCodec.compress_into_codes
method.
batch_a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/compress_batch.pt")
batch_b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/compress_batch.pt")
torch.allclose(batch_a, batch_b, atol=1e-4, rtol=1e-3)
True
centroids_a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/compress_centroids.pt")
centroids_b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/compress_centroids.pt")
torch.allclose(centroids_a, centroids_b, atol=1e-4, rtol=1e-3)
True
Both the token embeddings and the centroids are close enough (both are float16). Next I’ll compare a batch of codes
(indices
) saved inside compress_into_codes
:
indices_a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/compress_indices.pt")
indices_b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/compress_indices.pt")
torch.equal(indices_a, indices_b)
True
Interestingly, they are equal across the PyTorch versions. Next I’ll compare the codes
for each batch of embs
in compress
:
for f in ["compress_codes_0.pt", "compress_codes_1.pt", "compress_codes_2.pt"]:
    a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/{f}")
    b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/{f}")
    assert torch.equal(a, b)
They are all equal as well!
At this point it was clear to me that the cosine similarity calculation was not the root cause of the codes.indices
diverging between PyTorch versions. The next place to look: the sorting of codes! I added a line in CollectionIndexer._build_ivf
which saved the codes just before they were sorted.
Surprisingly, the codes
before being sorted are identical between PyTorch versions.
a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/presort_codes.pt")
b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/presort_codes.pt")
a
tensor([ 1269, 582, 10939, ..., 5013, 4582, 431])
b
tensor([ 1269, 582, 10939, ..., 5013, 4582, 431])
torch.equal(a,b)
True
However, after being sorted the codes.indices
diverge:
a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/codes.pt")
b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/codes.pt")
a
torch.return_types.sort(
values=tensor([ 0, 0, 0, ..., 16383, 16383, 16383]),
indices=tensor([377624, 285309, 285322, ..., 117986, 118780, 128088]))
b
torch.return_types.sort(
values=tensor([ 0, 0, 0, ..., 16383, 16383, 16383]),
indices=tensor([ 2776, 2808, 5974, ..., 309906, 579450, 884128]))
torch.equal(a.indices, b.indices)
False
There was the source of the discrepancy: the order of the indices after sorting!
Is this the case for all sort
calls between these PyTorch versions? To test this, I ran the following code with each PyTorch install (torch==2.0.1
and torch==2.1.0
) and saved t
before and after .sort
was called:
import torch

torch.manual_seed(42)
t = torch.randint(low=0, high=16383, size=(1146937,))
t = t.sort()
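Comparing the two saved results could then look something like this (a sketch; the filenames are hypothetical, not from my actual runs):
indices_201 = torch.load("sort_indices_2.0.1.pt")  # t.indices saved by the torch==2.0.1 install
indices_210 = torch.load("sort_indices_2.1.0.pt")  # t.indices saved by the torch==2.1.0 install
torch.equal(indices_201, indices_210)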
Between the two PyTorch versions, t.indices were not equal (i.e. torch.equal returned False). This is evidence that sort's behavior changed from 2.0.1 to 2.1.0. After keyword-searching the release notes, I couldn't find a PR that could be the culprit.
Thankfully, colbert-ai
is robust to such changes! Since we only care about the unique passage IDs (and number of passage IDs) for ivf
and ivf_lengths
, respectively, and not the order of token IDs, this PyTorch change does not break the indexing pipeline.