PyTorch .sort Behavior Changes from Version 2.0.1 to 2.1.0

ColBERT
While comparing colbert-ai index artifacts between two installs using different PyTorch versions, I came across an unexpected finding: .sort indices are ordered differently in torch==2.1.0 than in torch==2.0.1, which changes an intermediate artifact even with all else held equal. Thankfully, this doesn’t break ColBERT’s indexing functionality or its final index artifacts.
Author

Vishal Bakshi

Published

September 9, 2025

Background

In this notebook I’m going to explore how (and hopefully why) you can start with different codes.indices but end up with the same ivf and ivf_lengths when indexing a document collection using colbert-ai.

I came across this behavior by accident. I was comparing final and intermediate colbert-ai index artifacts between installs using torch==2.0.1 and torch==2.1.0 and found that even after swapping local_sample_embs.pt (the document token embeddings clustered to find centroids) and embs_{chunk_idx}.pt (the full set of document token embeddings) from 2.0.1 into 2.1.0, the intermediate codes.indices (the document token IDs, ordered by their assigned centroid) did not pass torch.equal but the final ivf.pid.pt tensors did. How could that be possible? How can you start with different intermediate centroid-to-document token mappings and end up with the same final centroid-to-document token mappings? Furthermore, how can you end up with different codes.indices when you’re processing the same embs?

First a bit of review of where codes comes from. The highest-level abstraction we start with is the Indexer:

with Run().context(RunConfig(nranks=nranks)):
    config = ColBERTConfig(...)
    indexer = Indexer(checkpoint="answerdotai/answerai-colbert-small-v1", config=config)
    _ = indexer.index(name="...", collection=collection)

Inside Indexer, encode is called, which instantiates CollectionIndexer:

def encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    encoder = CollectionIndexer(config=config, collection=collection, verbose=verbose)
    encoder.run(shared_lists)

Inside CollectionIndexer.index, the following line saves (i.e. compresses and stores the residuals of) the document token embeddings (the input embs are manually forced to be identical between the two colbert-ai installs with different PyTorch versions):

self.saver.save_chunk(chunk_idx, offset, embs, doclens) # offset = first passage index in chunk

Once saved, the embeddings are deleted, which is why colbert-ai is so memory efficient! It’s also why indexing and embedding are tied together with the same model.

IndexSaver.save_chunk is defined as:

def save_chunk(self, chunk_idx, offset, embs, doclens):
    compressed_embs = self.codec.compress(embs)
    
    self.saver_queue.put((chunk_idx, offset, compressed_embs, doclens))

The codec is a ResidualCodec object and its compress method contains the following line:

codes_ = self.compress_into_codes(batch, out_device=batch.device)

We’re almost there! compress_into_codes is defined as:

def compress_into_codes(self, embs, out_device):
    codes = []

    bsize = (1 << 29) // self.centroids.size(0)
    for batch in embs.split(bsize):
        if self.use_gpu:
            indices = (self.centroids @ batch.T.cuda().half()).max(dim=0).indices.to(device=out_device)
        else:
            indices = (self.centroids @ batch.T.cpu().float()).max(dim=0).indices.to(device=out_device)
        codes.append(indices)

    return torch.cat(codes)

So, codes are the indices (i.e. “IDs”) of the centroids with the maximum cosine similarity to the document token embeddings.

Let’s say our centroids have shape (1024, 96) and the batch contains thirty-two 96-dimensional embeddings (shape (32, 96)), each corresponding to a different document token embedding. The transpose of batch has shape (96, 32) and the matrix multiplication centroids @ batch.T has shape (1024, 32) where the rows represent centroid indices and the columns represent token indices. Taking .max(dim=0).indices returns the row indices corresponding to the maximum value in each of the 32 columns. In other words, the 32 centroid IDs that are closest to the document token embeddings. Note that since centroids are normalized as are document token embeddings, the matrix multiplication is the cosine similarity between the two sets of vectors.
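
To make this concrete, here is a small self-contained sketch of the same computation with random, made-up tensors of those shapes; the real centroids and embs come from the index, so treat this purely as an illustration:

import torch
import torch.nn.functional as F

# Made-up stand-ins for the codec's centroids and one batch of token embeddings.
centroids = F.normalize(torch.randn(1024, 96), dim=-1)  # (num_centroids, dim)
batch = F.normalize(torch.randn(32, 96), dim=-1)        # (num_tokens, dim)

scores = centroids @ batch.T       # (1024, 32): cosine similarities, since both inputs are normalized
codes = scores.max(dim=0).indices  # (32,): closest centroid ID for each token
print(codes.shape)                 # torch.Size([32])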

Which goes back to my question: how can different codes yield the same final ivf and ivf_lengths? And why are codes different to begin with?

To set the stage, let’s look at how ivf and ivf_lengths are created, starting with CollectionIndexer._build_ivf, a trimmed-down version of which is:

def _build_ivf(self):
    # codes holds the centroid ID assigned to every document token embedding,
    # loaded from the per-chunk codes saved earlier (loading omitted in this trimmed-down version)
    codes = codes.sort()
    ivf, values = codes.indices, codes.values
    ivf_lengths = torch.bincount(values, minlength=self.num_partitions)

    _, _ = optimize_ivf(ivf, ivf_lengths, self.config.index_path_)
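
Before walking through these two steps, here is a tiny toy illustration (made-up values, 4 centroids and 8 tokens) of what the sort and the bincount produce:

import torch

codes = torch.tensor([2, 0, 3, 0, 2, 1, 0, 2])  # centroid ID per token (made up)
sorted_codes = codes.sort()

ivf = sorted_codes.indices    # token IDs grouped by centroid, e.g. tensor([1, 3, 6, 5, 0, 4, 7, 2])
values = sorted_codes.values  # tensor([0, 0, 0, 1, 2, 2, 2, 3])

ivf_lengths = torch.bincount(values, minlength=4)  # tokens per centroid: tensor([3, 1, 3, 1])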

The codes are first sorted. The sorted indices (the document token IDs) are assigned as ivf, and the bincount of the values (the centroid IDs), i.e. the number of tokens associated with each centroid ID, is assigned as ivf_lengths. These are the first iteration of ivf and ivf_lengths and will change later on in optimize_ivf, a trimmed-down version of which is:

def optimize_ivf(orig_ivf, orig_ivf_lengths, index_path, verbose:int=3):
    all_doclens = load_doclens(index_path, flatten=False)
    all_doclens = flatten(all_doclens)
    total_num_embeddings = sum(all_doclens)

    emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)

    offset_doclens = 0
    for pid, dlength in enumerate(all_doclens):
        emb2pid[offset_doclens: offset_doclens + dlength] = pid
        offset_doclens += dlength

    ivf = emb2pid[orig_ivf]
    unique_pids_per_centroid = []
    ivf_lengths = []

    offset = 0
    for length in tqdm.tqdm(orig_ivf_lengths.tolist()):
        pids = torch.unique(ivf[offset:offset+length])
        unique_pids_per_centroid.append(pids)
        ivf_lengths.append(pids.shape[0])
        offset += length
    ivf = torch.cat(unique_pids_per_centroid)
    ivf_lengths = torch.tensor(ivf_lengths)
    
    original_ivf_path = os.path.join(index_path, 'ivf.pt')
    optimized_ivf_path = os.path.join(index_path, 'ivf.pid.pt')
    torch.save((ivf, ivf_lengths), optimized_ivf_path)

    return ivf, ivf_lengths

We’ll actually start from the bottom:

ivf = torch.cat(unique_pids_per_centroid)
ivf_lengths = torch.tensor(ivf_lengths)

ivf is a flattened tensor of pids (unique passage IDs per centroid). Looking at the loop right above this:

offset = 0
for length in tqdm.tqdm(orig_ivf_lengths.tolist()):
    pids = torch.unique(ivf[offset:offset+length])
    unique_pids_per_centroid.append(pids)
    ivf_lengths.append(pids.shape[0])
    offset += length

ivf_lengths is the flattened tensor of the number of pids per centroid.

So again: how can we start with different codes (a list of centroid IDs, where the indices are the document token embedding IDs) and end up with the same ivf (unique pids corresponding to centroids) and ivf_lengths (number of pids per centroid)?

I fed this background section to Sonnet 4 (with the stanford-futuredata/ColBERT repo attached as Project Knowledge) to fact check me and it said:

The key insight is that the final IVF only cares about which passages are associated with each centroid, not which specific token embeddings within those passages. If the different codes still result in the same set of passages being associated with each centroid (even if individual token assignments differ), the final ivf and ivf_lengths would be identical

TBD if that’s correct, certainly seems plausible!

Inspecting codes

First, I’ll show that codes.indices (document token IDs) are not equal between my torch==2.0.1 install and the torch==2.1.0 install where I swapped its local_sample_embs.pt and embs_{chunk_idx}.pt with 2.0.1’s tensors. In other words, I “forced” the 2.1.0 install to cluster the same sample of document token embeddings when calculating centroids and then forced it to use the same document token embeddings to compress as residuals and centroid IDs.

import torch
from colbert.indexing.loaders import load_doclens
from colbert.utils.utils import print_message, flatten
codes_indices_a = torch.load("20250909-0.2.22.main.torch.2.0.1-1/ivf.pt")
codes_indices_b = torch.load("20250909-0.2.22.main.torch.2.1.0-swap-1/ivf.pt")
codes_values_a = torch.load("20250909-0.2.22.main.torch.2.0.1-1/values.pt")
codes_values_b = torch.load("20250909-0.2.22.main.torch.2.1.0-swap-1/values.pt")
torch.equal(codes_indices_a, codes_indices_b), torch.equal(codes_values_a, codes_values_b)
(False, True)
codes_values_a
tensor([    0,     0,     0,  ..., 16383, 16383, 16383])
codes_values_b
tensor([    0,     0,     0,  ..., 16383, 16383, 16383])

Note that codes.values (the centroid IDs) are identical. So which centroid IDs are closest to the document token embeddings stays consistent across versions, but which document token IDs they correspond to does not.
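
A tiny made-up example shows how this can happen when a sort has ties: both index orderings below are valid sorts of the same tensor, so the sorted values agree while the index order within ties differs.

import torch

t = torch.tensor([5, 3, 5, 3])

# Two valid index permutations that both sort t: the ties (two 3s, two 5s)
# can legitimately be broken in different orders by different implementations.
perm1 = torch.tensor([1, 3, 0, 2])
perm2 = torch.tensor([3, 1, 2, 0])

assert torch.equal(t[perm1], t[perm2])  # identical sorted values: tensor([3, 3, 5, 5])
assert not torch.equal(perm1, perm2)    # but different index orders

In the real data there are only 16384 distinct centroid IDs spread over more than a million tokens, so codes is full of ties; that is exactly the situation where the sorted values can match while the index order does not.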

ivf_a, ivf_lengths_a = torch.load("20250909-0.2.22.main.torch.2.0.1-1/indexing/ConditionalQA/ivf.pid.pt")
ivf_b, ivf_lengths_b = torch.load("20250909-0.2.22.main.torch.2.1.0-swap-1/indexing/ConditionalQA/ivf.pid.pt")
torch.equal(ivf_a, ivf_b), torch.equal(ivf_lengths_a, ivf_lengths_b)
(True, True)

Furthermore, the final unique passage IDs for each centroid (ivf) and the number of passage IDs per centroid ID (ivf_lengths) are equal across versions. What this tells me (re: Sonnet’s hypothesis) is that the document token IDs, while dissimilar across versions, come from the same passages!

Recreating optimize_ivf

To explore the relationship between document token IDs and passage IDs, I’ll use the code in optimize_ivf, where, initially, ivf means codes.indices and ivf_lengths means torch.bincount(codes.values).

codes_values_a = torch.bincount(codes_values_a, minlength=16384)
codes_values_b = torch.bincount(codes_values_b, minlength=16384)
codes_values_a, codes_values_b
(tensor([1110,   36,  173,  ...,  104,   95,   25]),
 tensor([1110,   36,  173,  ...,  104,   95,   25]))

I’ll start by loading the mapping between passages and tokens: doclens.

all_doclens_a = load_doclens("20250909-0.2.22.main.torch.2.0.1-1/indexing/ConditionalQA/", flatten=False)
all_doclens_a = flatten(all_doclens_a)
total_num_embeddings_a = sum(all_doclens_a)

all_doclens_b = load_doclens("20250909-0.2.22.main.torch.2.1.0-swap-1/indexing/ConditionalQA", flatten=False)
all_doclens_b = flatten(all_doclens_b)
total_num_embeddings_b = sum(all_doclens_b)
all_doclens_a == all_doclens_b
True
total_num_embeddings_a == total_num_embeddings_b
True
total_num_embeddings_b
1146937

Next we create emb2pid, a tensor with 1146937 entries (one for each token across the collection) whose values are passage IDs.

def _emb2pid(total_num_embeddings, all_doclens):
    emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)
    offset_doclens = 0
    for pid, dlength in enumerate(all_doclens):
        emb2pid[offset_doclens: offset_doclens + dlength] = pid
        offset_doclens += dlength
    return emb2pid
emb2pid_a = _emb2pid(total_num_embeddings_a, all_doclens_a)
emb2pid_b = _emb2pid(total_num_embeddings_b, all_doclens_b)
emb2pid_a
tensor([    0,     0,     0,  ..., 69198, 69198, 69198], dtype=torch.int32)
emb2pid_b
tensor([    0,     0,     0,  ..., 69198, 69198, 69198], dtype=torch.int32)
torch.equal(emb2pid_a, emb2pid_b)
True

The first three tokens we see correspond to passage ID 0, and the last three tokens to passage ID 69198.

Let’s now see if the tokens in the two codes_indices come from the same passages.

codes_indices_a
tensor([377624, 285309, 285322,  ..., 117986, 118780, 128088])
codes_indices_b
tensor([  2776,   2808,   5974,  ..., 309906, 579450, 884128])
pids_a = emb2pid_a[codes_indices_a]
pids_b = emb2pid_b[codes_indices_b]
pids_a
tensor([23120, 17145, 17145,  ...,  7128,  7172,  7691], dtype=torch.int32)
pids_b
tensor([  170,   170,   377,  ..., 18739, 35561, 53527], dtype=torch.int32)

Looking at the resulting passage IDs: the first two tokens of pids_a (torch==2.0.1) come from passages 23120 and 17145, respectively. The first two tokens of pids_b (torch==2.1.0 swapped) come from passage 170.

If we count the number of times each passage ID occurs in each tensor (pids_a and pids_b), the counts are identical! This is the first hint that Sonnet’s hypothesis is on the right track.

torch.equal(torch.bincount(pids_a), torch.bincount(pids_b))
True

Let’s keep moving along in recreating optimize_ivf:

ivf = emb2pid[orig_ivf]
unique_pids_per_centroid = []
ivf_lengths = []

offset = 0
for length in tqdm.tqdm(orig_ivf_lengths.tolist()):
    pids = torch.unique(ivf[offset:offset+length])
    unique_pids_per_centroid.append(pids)
    ivf_lengths.append(pids.shape[0])
    offset += length
ivf = torch.cat(unique_pids_per_centroid)
ivf_lengths = torch.tensor(ivf_lengths)

Instead of:

ivf = emb2pid[orig_ivf]

I did:

pids_a = emb2pid_a[codes_indices_a]

I’ll move on to the for loop:

def _loop(orig_ivf_lengths, ivf):
    unique_pids_per_centroid = []
    ivf_lengths = []
    offset = 0
    for length in orig_ivf_lengths.tolist():
        pids = torch.unique(ivf[offset:offset+length])
        unique_pids_per_centroid.append(pids)
        ivf_lengths.append(pids.shape[0])
        offset += length
    return unique_pids_per_centroid, ivf_lengths
unique_pids_per_centroid_a, _ivf_lengths_a = _loop(orig_ivf_lengths=codes_values_a, ivf=pids_a)
unique_pids_per_centroid_b, _ivf_lengths_b = _loop(orig_ivf_lengths=codes_values_b, ivf=pids_b)
for idx, item in enumerate(unique_pids_per_centroid_a): assert torch.equal(item, unique_pids_per_centroid_b[idx])

And there we see it! While the order of the passage IDs is different, both codes.indices tensors map to the same unique passage IDs per centroid. The key reason is that the for-loop uses torch.unique, which returns its output sorted in ascending order. So as long as the set of passage IDs in ivf[offset:offset+length] is identical across PyTorch versions, torch.unique returns them in the same order, even if they arrive in a different order.
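
Here is a minimal sketch of that mechanism with made-up values: two orderings of the same per-centroid segment of passage IDs (as if produced by two different sorts) yield identical torch.unique output.

import torch

# Made-up passage IDs for one centroid's segment, in two different orders.
segment_v1 = torch.tensor([17145, 23120, 17145, 7128])
segment_v2 = torch.tensor([7128, 17145, 23120, 17145])

pids_v1 = torch.unique(segment_v1)    # tensor([ 7128, 17145, 23120])
pids_v2 = torch.unique(segment_v2)    # tensor([ 7128, 17145, 23120])
assert torch.equal(pids_v1, pids_v2)  # same set in, same sorted output out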

Conclusion

Let’s revisit Sonnet 4’s hypothesis:

The key insight is that the final IVF only cares about which passages are associated with each centroid, not which specific token embeddings within those passages

While true, the reality was a bit more specific: the token IDs assigned to each centroid are identical across torch versions; they are just ordered differently! However, this raises the question: why are the token IDs ordered differently across torch versions? I’ll explore that in the Appendix section below.

Appendix

To understand how the max cosine similarity calculation deviates between torch==2.0.1 and torch==2.1.0 (using 2.0.1’s local_sample_embs.pt) I’ll start by comparing the embs that I torch.save-d right before they were compressed. This is more of a sanity check as these were explicitly swapped from torch==2.0.1.

for f in ["embs_0.pt", "embs_1.pt", "embs_2.pt"]:
    a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/{f}")
    b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/{f}")
    assert torch.allclose(a, b, atol=1e-4, rtol=1e-3)

They are all close enough! Next, I’ll compare the single batch and centroids that I saved in the ResidualCodec.compress_into_codes method.

batch_a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/compress_batch.pt")
batch_b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/compress_batch.pt")
torch.allclose(batch_a, batch_b, atol=1e-4, rtol=1e-3)
True
centroids_a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/compress_centroids.pt")
centroids_b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/compress_centroids.pt")
torch.allclose(centroids_a, centroids_b, atol=1e-4, rtol=1e-3)
True

Both the token embeddings and the centroids are close enough (both are float16). Next I’ll compare a batch of codes (indices) saved inside compress_into_codes:

indices_a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/compress_indices.pt")
indices_b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/compress_indices.pt")
torch.equal(indices_a, indices_b)
True

Interestingly, they are equal across the PyTorch versions. Next I’ll compare the codes for each batch of embs in compress:

for f in ["compress_codes_0.pt", "compress_codes_1.pt", "compress_codes_2.pt"]:
    a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/{f}")
    b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/{f}")
    assert torch.equal(a, b)

They are all equal as well!

At this point it was clear to me that the cosine similarity calculation was not the root cause of the codes.indices diverging between PyTorch versions. The next place to look: the sorting of codes! I added a line in CollectionIndexer._build_ivf which saved the codes before they were sorted.

Surprisingly, the codes before being sorted are identical between PyTorch versions.

a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/presort_codes.pt")
b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/presort_codes.pt")
a
tensor([ 1269,   582, 10939,  ...,  5013,  4582,   431])
b
tensor([ 1269,   582, 10939,  ...,  5013,  4582,   431])
torch.equal(a,b)
True

However, after being sorted the codes.indices diverge:

a = torch.load(f"20250909-0.2.22.main.torch.2.0.1-1/codes.pt")
b = torch.load(f"20250909-0.2.22.main.torch.2.1.0-swap-1/codes.pt")
a
torch.return_types.sort(
values=tensor([    0,     0,     0,  ..., 16383, 16383, 16383]),
indices=tensor([377624, 285309, 285322,  ..., 117986, 118780, 128088]))
b
torch.return_types.sort(
values=tensor([    0,     0,     0,  ..., 16383, 16383, 16383]),
indices=tensor([  2776,   2808,   5974,  ..., 309906, 579450, 884128]))
torch.equal(a.indices, b.indices)
False

There was the source of the discrepancy: the order of the indices after sorting!

Is this the case for all sort calls between these PyTorch versions? To test this, I ran the following code with each PyTorch install (torch==2.0.1 and torch==2.1.0) and saved t before and after .sort was called:

import torch
torch.manual_seed(42)
t = torch.randint(low=0, high=16383, size=(1146937,))
t = t.sort()

Comparing across the two PyTorch versions, t.indices was not equal (torch.equal returned False). This is evidence that .sort orders the indices of repeated values differently in 2.1.0 than in 2.0.1. That is allowed behavior: torch.sort does not guarantee the order of equivalent elements unless stable=True, so the tie-breaking order is free to change between implementations and versions. After keyword-searching the release notes, I couldn’t find a PR that could be the culprit.
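
If the exact index order ever mattered downstream, one way to make it reproducible would be a stable sort: with stable=True, torch.sort preserves the input order of equal elements, so for a given input the result is deterministic regardless of the underlying implementation. A minimal sketch (not something colbert-ai does):

import torch

torch.manual_seed(42)
t = torch.randint(low=0, high=16383, size=(1146937,))

# stable=True preserves the original order of equal elements, so the resulting
# indices are fully determined by the input (ties are broken by input position).
values, indices = torch.sort(t, stable=True)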

Thankfully, colbert-ai is robust to such changes! Since we only care about the unique passage IDs (and number of passage IDs) for ivf and ivf_lengths, respectively, and not the order of token IDs, this PyTorch change does not break the indexing pipeline.