RAGatouille/ColBERT Indexing Deep Dive

python
information retrieval
machine learning
deep learning
In this technical deep dive, I recreate the entire RAG indexing pipeline by directly using the internal methods of RAGatouille and ColBERT libraries. I methodically build all indexing artifacts from scratch - from processing document collections and sampling embeddings to calculating centroids with k-means clustering and compressing document vectors with residual encoding. If you’re curious about what actually happens when you call RAG.index(), this video breaks down the full process step by step.
Author

Vishal Bakshi

Published

March 12, 2025

Background

In this blog post I dive deeply into the internals of the RAGatouille and ColBERT libraries to understand the intermediate steps taken when building an index for a collection of documents.

!pip install datasets ragatouille -qq
from datasets import load_dataset
from ragatouille import RAGPretrainedModel
from fastcore.utils import Path
import torch
import srsly
import uuid
from ragatouille.data import CorpusProcessor
from llama_index.core.text_splitter import SentenceSplitter
import pandas as pd
from ragatouille.models.index import PLAIDModelIndex
from colbert.infra import ColBERTConfig, RunConfig
from colbert.data.collection import Collection
from colbert.modeling.checkpoint import Checkpoint
from colbert.indexing.collection_encoder import CollectionEncoder
from colbert.indexing.collection_indexer import CollectionIndexer
import numpy as np
import random
from colbert.indexing.collection_indexer import compute_faiss_kmeans
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from colbert.indexing.codecs.residual import ResidualCodec
from colbert.utils.utils import flatten
import tqdm
def set_all_seeds(seed=123):
    """Set seeds for all random number generators"""
    import random
    import numpy as np
    import torch
    import os

    # Python's random module
    random.seed(seed)

    # NumPy
    np.random.seed(seed)

    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU

    # Set PYTHONHASHSEED for reproducibility across runs
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Set deterministic algorithms for PyTorch
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(f"All seeds set to {seed}")

# Call this at the beginning of your script
set_all_seeds(123)
All seeds set to 123

RAGatouille RAG.index

Everything in this notebook will be compared to what’s generated with RAG.index.

For this exercise, I’ll use 1000 passages from the UKPLab/DAPR ConditionalQA dataset.

passages = load_dataset("UKPLab/dapr", "ConditionalQA-corpus", split="test[:1000]")
passages
/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: 
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
  warnings.warn(
Dataset({
    features: ['_id', 'text', 'title', 'doc_id', 'paragraph_no', 'total_paragraphs', 'is_candidate'],
    num_rows: 1000
})
passages[0]['text']
'Overview'
model_nm = "answerdotai/answerai-colbert-small-v1"
RAG = RAGPretrainedModel.from_pretrained(model_nm)
index_path = RAG.index(index_name="cqa_index", collection=passages['text'])
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:12: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  self.scaler = torch.cuda.amp.GradScaler()
---- WARNING! You are using PLAID with an experimental replacement for FAISS for greater compatibility ----
This is a behaviour change from RAGatouille 0.8.0 onwards.
This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations.
If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour.
--------------------


[Mar 13, 01:15:09] #> Creating directory .ragatouille/colbert/indexes/cqa_index 


[Mar 13, 01:15:11] [0]       #> Encoding 1000 passages..
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  return torch.cuda.amp.autocast() if self.activated else NullContextManager()
[Mar 13, 01:15:14] [0]       avg_doclen_est = 15.197999954223633     len(local_sample) = 1,000
[Mar 13, 01:15:14] [0]       Creating 1,024 partitions.
[Mar 13, 01:15:14] [0]       *Estimated* 15,197 embeddings.
[Mar 13, 01:15:14] [0]       #> Saving the indexing plan to .ragatouille/colbert/indexes/cqa_index/plan.json ..
/usr/local/lib/python3.11/dist-packages/colbert/indexing/collection_indexer.py:256: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  sub_sample = torch.load(sub_sample_path)
used 20 iterations (0.5189s) to cluster 14439 items into 1024 clusters
[Mar 13, 01:15:14] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
[Mar 13, 01:16:51] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation. 
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
  warnings.warn(
[0.015, 0.016, 0.015, 0.015, 0.013, 0.016, 0.015, 0.016, 0.017, 0.014, 0.013, 0.015, 0.017, 0.014, 0.015, 0.017, 0.014, 0.015, 0.015, 0.014, 0.015, 0.016, 0.015, 0.015, 0.014, 0.014, 0.015, 0.015, 0.015, 0.015, 0.014, 0.014, 0.015, 0.014, 0.015, 0.015, 0.014, 0.015, 0.016, 0.015, 0.014, 0.015, 0.015, 0.014, 0.014, 0.017, 0.016, 0.017, 0.014, 0.015, 0.016, 0.015, 0.016, 0.016, 0.012, 0.015, 0.016, 0.015, 0.016, 0.016, 0.015, 0.015, 0.016, 0.014, 0.015, 0.017, 0.016, 0.015, 0.014, 0.015, 0.015, 0.015, 0.014, 0.016, 0.016, 0.016, 0.014, 0.015, 0.015, 0.014, 0.014, 0.014, 0.016, 0.015, 0.016, 0.015, 0.014, 0.014, 0.014, 0.015, 0.016, 0.014, 0.014, 0.016, 0.014, 0.015]
/usr/local/lib/python3.11/dist-packages/colbert/indexing/codecs/residual.py:141: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  centroids = torch.load(centroids_path, map_location='cpu')
/usr/local/lib/python3.11/dist-packages/colbert/indexing/codecs/residual.py:142: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  avg_residual = torch.load(avgresidual_path, map_location='cpu')
/usr/local/lib/python3.11/dist-packages/colbert/indexing/codecs/residual.py:143: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
0it [00:00, ?it/s]
[Mar 13, 01:18:16] [0]       #> Encoding 1000 passages..
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  return torch.cuda.amp.autocast() if self.activated else NullContextManager()
1it [00:00,  1.39it/s]
  0%|          | 0/1 [00:00<?, ?it/s]/usr/local/lib/python3.11/dist-packages/colbert/indexing/codecs/residual_embeddings.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
  return torch.load(codes_path, map_location='cpu')
100%|██████████| 1/1 [00:00<00:00, 787.81it/s]
[Mar 13, 01:18:16] #> Optimizing IVF to store map from centroids to list of pids..
[Mar 13, 01:18:16] #> Building the emb2pid mapping..
[Mar 13, 01:18:16] len(emb2pid) = 15198

100%|██████████| 1024/1024 [00:00<00:00, 61154.86it/s]
[Mar 13, 01:18:16] #> Saved optimized IVF to .ragatouille/colbert/indexes/cqa_index/ivf.pid.pt
Done indexing!
index_path = Path(index_path)
index_path
Path('.ragatouille/colbert/indexes/cqa_index')
for o in index_path.ls(): print(o)
.ragatouille/colbert/indexes/cqa_index/buckets.pt
.ragatouille/colbert/indexes/cqa_index/collection.json
.ragatouille/colbert/indexes/cqa_index/metadata.json
.ragatouille/colbert/indexes/cqa_index/ivf.pid.pt
.ragatouille/colbert/indexes/cqa_index/doclens.0.json
.ragatouille/colbert/indexes/cqa_index/0.residuals.pt
.ragatouille/colbert/indexes/cqa_index/centroids.pt
.ragatouille/colbert/indexes/cqa_index/0.metadata.json
.ragatouille/colbert/indexes/cqa_index/plan.json
.ragatouille/colbert/indexes/cqa_index/0.codes.pt
.ragatouille/colbert/indexes/cqa_index/avg_residual.pt
.ragatouille/colbert/indexes/cqa_index/pid_docid_map.json

While it’s a bit tedious to do so (since I’m chomping at the bit to get to the deep dive!), I think it’s worth analyzing the contents of each of these files, as we’ll be recreating them in this notebook.

buckets.pt

Looking at Line 160 of the ColBERT repo’s residual.py, buckets.pt stores bucket_cutoffs and bucket_weights. We’ll go into detail about what exactly these are later on.

_bucket_cutoffs, _bucket_weights = torch.load(index_path/'buckets.pt', weights_only=True)
_bucket_cutoffs.shape, _bucket_weights.shape
(torch.Size([15]), torch.Size([16]))
_bucket_cutoffs
tensor([-0.0307, -0.0205, -0.0146, -0.0099, -0.0064, -0.0037, -0.0016,  0.0000,
         0.0017,  0.0038,  0.0066,  0.0102,  0.0150,  0.0211,  0.0313],
       device='cuda:0')
_bucket_weights
tensor([-0.0411, -0.0247, -0.0173, -0.0121, -0.0081, -0.0050, -0.0026, -0.0007,
         0.0007,  0.0027,  0.0052,  0.0083,  0.0124,  0.0178,  0.0253,  0.0421],
       device='cuda:0', dtype=torch.float16)

0.residuals.pt

IIUC, there are 15198 tokens in our collection, and each token’s residual is stored as a row of 48 bytes. Each uint8 value packs two 4-bit codes, each code corresponding to a quantized value, so there are actually 96 quantized values in each vector (one per embedding dimension).

_residuals = torch.load(index_path/'0.residuals.pt', weights_only=True)
_residuals
tensor([[ 30, 225, 225,  ..., 238, 238,  30],
        [230,  22, 158,  ..., 233, 106, 170],
        [238, 238, 238,  ..., 238, 238, 238],
        ...,
        [ 43,  22,  23,  ..., 104,  31, 208],
        [222, 254,  91,  ..., 128,   8, 189],
        [229,  82,  22,  ..., 170,  94, 154]], dtype=torch.uint8)
_residuals.shape
torch.Size([15198, 48])
48*4/2
96.0
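
To make the two-codes-per-byte packing concrete, here’s a small sketch that unpacks the first stored byte by hand. The bit layout assumed here (np.packbits ordering, with each 4-bit code laid out least-significant-bit first) is the one we’ll derive when we walk through the compression step later in the post, so treat the specific bucket codes as a preview rather than something we’ve established yet.

byte = int(_residuals[0, 0])                          # 30, i.e. 0b00011110
bits = [(byte >> (7 - i)) & 1 for i in range(8)]      # bits in the order np.packbits stored them
code_a = sum(b << i for i, b in enumerate(bits[:4]))  # first 4-bit bucket code -> 8
code_b = sum(b << i for i, b in enumerate(bits[4:]))  # second 4-bit bucket code -> 7
code_a, code_b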

ivf.pid.pt

IIRC, ivf contains a flattened sequence of passage IDs corresponding to each centroid. There are 1024 centroids and the first 8 passage IDs in ivf correspond to the 0-th centroid.

_ivf, _ivf_lengths = torch.load(index_path/'ivf.pid.pt', weights_only=True)
_ivf.shape, _ivf_lengths.shape
(torch.Size([11759]), torch.Size([1024]))
_ivf[:5]
tensor([895, 896, 902, 904, 909], dtype=torch.int32)
_ivf_lengths[0]
tensor(8)
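
As a quick sanity check on how these two tensors fit together, here’s a sketch (my own, not library code) that slices out the passage IDs for a single centroid using cumulative offsets of _ivf_lengths:

_centroid_offsets = torch.cat([torch.zeros(1, dtype=torch.long), _ivf_lengths.long().cumsum(0)])
centroid_id = 0
start, end = int(_centroid_offsets[centroid_id]), int(_centroid_offsets[centroid_id + 1])
_ivf[start:end]  # the 8 passage IDs assigned to centroid 0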

0.metadata.json

There are 1000 passages totaling 15198 tokens.

def load_json(path, filename): return srsly.read_json(str(Path(path) / filename))
load_json(index_path, "0.metadata.json")
{'passage_offset': 0,
 'num_passages': 1000,
 'num_embeddings': 15198,
 'embedding_offset': 0}

collection.json

This JSON contains, as a list, the strings of the 1000 passages in our collection.

_collection = load_json(index_path, "collection.json")
len(_collection), _collection[0]
(1000, 'Overview')

avg_residual.pt

I believe this is the average residual across the 15198 tokens (i.e. the average distance in vector-space between the tokens and their closest centroids).

_avg_residual = torch.load(index_path/'avg_residual.pt', weights_only=True)
_avg_residual
tensor(0.0150, device='cuda:0', dtype=torch.float16)

doclens.0.json

This contains a mapping (list) between passage IDs (indices) and the number of tokens in each passage (values).

_doclens = load_json(index_path, "doclens.0.json")
len(_doclens), _doclens[:5]
(1000, [4, 20, 18, 23, 8])
sum(_doclens)
15198
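
Since doclens gives the token count per passage, its cumulative sum tells us which rows of the embedding matrix (which we’ll build later) belong to which passage. A small sketch of that bookkeeping, assuming passages are laid out back-to-back in passage-ID order:

_doc_offsets = np.cumsum([0] + _doclens)
pid = 2
_doc_offsets[pid], _doc_offsets[pid + 1]  # passage 2's 18 token embeddings occupy rows [24, 42)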

metadata.json

There’s lots of information in here; I’ll highlight the number of centroids (partitions) and the number of token embeddings in the collection:

'num_partitions': 1024,
'num_embeddings': 15198
load_json(index_path, "metadata.json")
{'config': {'query_token_id': '[unused0]',
  'doc_token_id': '[unused1]',
  'query_token': '[Q]',
  'doc_token': '[D]',
  'ncells': None,
  'centroid_score_threshold': None,
  'ndocs': None,
  'load_index_with_mmap': False,
  'index_path': None,
  'index_bsize': 32,
  'nbits': 4,
  'kmeans_niters': 20,
  'resume': False,
  'pool_factor': 1,
  'clustering_mode': 'hierarchical',
  'protected_tokens': 0,
  'similarity': 'cosine',
  'bsize': 64,
  'accumsteps': 1,
  'lr': 1e-05,
  'maxsteps': 15626,
  'save_every': None,
  'warmup': 781,
  'warmup_bert': None,
  'relu': False,
  'nway': 32,
  'use_ib_negatives': False,
  'reranker': False,
  'distillation_alpha': 1.0,
  'ignore_scores': False,
  'model_name': 'answerdotai/AnswerAI-ColBERTv2.5-small',
  'query_maxlen': 32,
  'attend_to_mask_tokens': False,
  'interaction': 'colbert',
  'dim': 96,
  'doc_maxlen': 256,
  'mask_punctuation': True,
  'checkpoint': 'answerdotai/answerai-colbert-small-v1',
  'triples': '/home/bclavie/colbertv2.5_en/data/msmarco/triplets.jsonl',
  'collection': ['list with 1000 elements starting with...',
   ['Overview',
    'You can only make a claim for Child Tax Credit if you already get Working Tax Credit.',
    'If you cannot apply for Child Tax Credit, you can apply for Universal Credit instead.']],
  'queries': '/home/bclavie/colbertv2.5_en/data/msmarco/queries.tsv',
  'index_name': 'cqa_index',
  'overwrite': False,
  'root': '.ragatouille/',
  'experiment': 'colbert',
  'index_root': None,
  'name': '2025-03/13/01.14.35',
  'rank': 0,
  'nranks': 1,
  'amp': True,
  'gpus': 1,
  'avoid_fork_if_possible': False},
 'num_chunks': 1,
 'num_partitions': 1024,
 'num_embeddings': 15198,
 'avg_doclen': 15.198,
 'RAGatouille': {'index_config': {'index_type': 'PLAID',
   'index_name': 'cqa_index'}}}

centroids.pt

There are 1024 96-dimension centroid vectors stored.

_centroids = torch.load(index_path/'centroids.pt', weights_only=True)
_centroids.shape
torch.Size([1024, 96])

The full, uncompressed centroid values are stored.

_centroids[0][:5]
tensor([-0.0649,  0.1193, -0.0551,  0.0561, -0.0826], device='cuda:0',
       dtype=torch.float16)

0.codes.pt

I believe this is a mapping (list) between tokens (indices) and centroid IDs (values).

_codes = torch.load(index_path/'0.codes.pt', weights_only=True)
_codes.shape
torch.Size([15198])
_codes[:5]
tensor([138, 843, 273, 138, 561], dtype=torch.int32)
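
Since 0.codes.pt maps each token embedding to its nearest centroid, a quick bincount (my own check, not something the index stores in this form) shows how the 15198 tokens are distributed across the 1024 clusters:

_tokens_per_centroid = torch.bincount(_codes.long(), minlength=1024)
_tokens_per_centroid.shape, _tokens_per_centroid.sum()  # sums back to the 15198 token embeddings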

pid_docid_map.json

A mapping between passage ID (0-999) and document ID (UUID).

_pid_docid_map = load_json(index_path, "pid_docid_map.json")
_pid_docid_map['999']
'2be086c6-04cc-4d73-b372-08236f76cbe6'

plan.json

This seems to contain the same information as metadata.json.

_plan = load_json(index_path, "plan.json")
_plan
{'config': {'query_token_id': '[unused0]',
  'doc_token_id': '[unused1]',
  'query_token': '[Q]',
  'doc_token': '[D]',
  'ncells': None,
  'centroid_score_threshold': None,
  'ndocs': None,
  'load_index_with_mmap': False,
  'index_path': None,
  'index_bsize': 32,
  'nbits': 4,
  'kmeans_niters': 20,
  'resume': False,
  'pool_factor': 1,
  'clustering_mode': 'hierarchical',
  'protected_tokens': 0,
  'similarity': 'cosine',
  'bsize': 64,
  'accumsteps': 1,
  'lr': 1e-05,
  'maxsteps': 15626,
  'save_every': None,
  'warmup': 781,
  'warmup_bert': None,
  'relu': False,
  'nway': 32,
  'use_ib_negatives': False,
  'reranker': False,
  'distillation_alpha': 1.0,
  'ignore_scores': False,
  'model_name': 'answerdotai/AnswerAI-ColBERTv2.5-small',
  'query_maxlen': 32,
  'attend_to_mask_tokens': False,
  'interaction': 'colbert',
  'dim': 96,
  'doc_maxlen': 256,
  'mask_punctuation': True,
  'checkpoint': 'answerdotai/answerai-colbert-small-v1',
  'triples': '/home/bclavie/colbertv2.5_en/data/msmarco/triplets.jsonl',
  'collection': ['list with 1000 elements starting with...',
   ['Overview',
    'You can only make a claim for Child Tax Credit if you already get Working Tax Credit.',
    'If you cannot apply for Child Tax Credit, you can apply for Universal Credit instead.']],
  'queries': '/home/bclavie/colbertv2.5_en/data/msmarco/queries.tsv',
  'index_name': 'cqa_index',
  'overwrite': False,
  'root': '.ragatouille/',
  'experiment': 'colbert',
  'index_root': None,
  'name': '2025-03/13/01.14.35',
  'rank': 0,
  'nranks': 1,
  'amp': True,
  'gpus': 1,
  'avoid_fork_if_possible': False},
 'num_chunks': 1,
 'num_partitions': 1024,
 'num_embeddings_est': 15197.999954223633,
 'avg_doclen_est': 15.197999954223633}

In the following sections, I’ll try to recreate each of these index_path elements.

_process_corpus

Inside RAG.index, _process_corpus is called on the documents and document IDs.

passage_ids = [str(uuid.uuid4()) for _ in range(len(passages))]
passage_ids[0]
'd4cdfec5-a949-43e0-94b3-feb24caeac5e'

Use the corpus processor to convert the passages into {'document_id': '...', 'content': '...'} dictionaries with 256-token max length.

cp = CorpusProcessor()
cp
<ragatouille.data.corpus_processor.CorpusProcessor at 0x794fa0323690>
collection_with_ids = cp.process_corpus(passages['text'], passage_ids, chunk_size=256)
len(collection_with_ids), collection_with_ids[0]
(1000,
 {'document_id': 'd4cdfec5-a949-43e0-94b3-feb24caeac5e',
  'content': 'Overview'})

As a brief aside, I’ll take a look at the maximum token length of the passages.

node_parser = SentenceSplitter(chunk_size=256)
node_parser._token_size
llama_index.core.node_parser.text.sentence.SentenceSplitter._token_size
def _token_size(text: str) -> int
<no docstring>
tk_szs = []
for p in passages['text']: tk_szs.append(node_parser._token_size(p))
pd.Series(tk_szs).describe()
0
count 1000.000000
mean 13.217000
std 10.192635
min 1.000000
25% 5.000000
50% 10.000000
75% 19.000000
max 65.000000

This collection of passages has relatively short passages (a max of 65 tokens).

_process_corpus then creates a mapping from passage ID (index) to document ID:

pid_docid_map = {index: item["document_id"] for index, item in enumerate(collection_with_ids)}
pid_docid_map[999]
'096e054e-3041-4881-ac48-b20f1804f650'

This matches the structure of pid_docid_map.json (the UUIDs themselves differ because they are regenerated each run).

_process_corpus also defines a list of strings, collection:

collection = [x["content"] for x in collection_with_ids]
collection[0]
'Overview'

_process_corpus also calls _process_metadata which defines docid_metadata_map as None when document_metadatas is None (which it is in our case).

docid_metadata_map = None

RAG.index Internals

After calling _process_corpus, RAG.index calls model.index, where model is:

instance.model = ColBERT(
            pretrained_model_name_or_path, n_gpu, index_root=index_root, verbose=verbose
        )

ColBERT.index in turn calls:

ModelIndexFactory.construct

By default the type of index is PLAID, so the following is called:

PLAIDModelIndex(config).build(
            checkpoint, collection, index_name, overwrite, verbose, **kwargs
        )

PLAIDModelIndex.build

A couple of key configuration values are set in this method, starting with the bsize (which I think is batch size?) defaulting to 32.

PLAIDModelIndex._DEFAULT_INDEX_BSIZE
32
bsize = PLAIDModelIndex._DEFAULT_INDEX_BSIZE
bsize
32

The number of bits (nbits) used to encode each compressed residual value is determined by the size of the collection.

if len(collection) < 10000: nbits = 4
nbits
4

It then defines a ColBERTConfig object, which I believe is created as follows when the ColBERT checkpoint is loaded:

ckpt_config = ColBERTConfig.load_from_checkpoint(str(model_nm))
ckpt_config
ColBERTConfig(query_token_id='[unused0]', doc_token_id='[unused1]', query_token='[Q]', doc_token='[D]', ncells=None, centroid_score_threshold=None, ndocs=None, load_index_with_mmap=False, index_path=None, index_bsize=64, nbits=1, kmeans_niters=4, resume=False, pool_factor=1, clustering_mode='hierarchical', protected_tokens=0, similarity='cosine', bsize=32, accumsteps=1, lr=1e-05, maxsteps=15626, save_every=None, warmup=781, warmup_bert=None, relu=False, nway=32, use_ib_negatives=False, reranker=False, distillation_alpha=1.0, ignore_scores=False, model_name='answerdotai/AnswerAI-ColBERTv2.5-small', query_maxlen=32, attend_to_mask_tokens=False, interaction='colbert', dim=96, doc_maxlen=300, mask_punctuation=True, checkpoint='/root/.cache/huggingface/hub/models--answerdotai--answerai-colbert-small-v1/snapshots/be1703c55532145a844da800eea4c9a692d7e267/', triples='/home/bclavie/colbertv2.5_en/data/msmarco/triplets.jsonl', collection='/home/bclavie/colbertv2.5_en/data/msmarco/collection.tsv', queries='/home/bclavie/colbertv2.5_en/data/msmarco/queries.tsv', index_name=None, overwrite=False, root='/home/bclavie/colbertv2.5_en/experiments', experiment='minicolbertv2.5', index_root=None, name='2024-08/07/08.16.20', rank=0, nranks=4, amp=True, gpus=4, avoid_fork_if_possible=False)
config = ColBERTConfig.from_existing(ckpt_config)
config
ColBERTConfig(query_token_id='[unused0]', doc_token_id='[unused1]', query_token='[Q]', doc_token='[D]', ncells=None, centroid_score_threshold=None, ndocs=None, load_index_with_mmap=False, index_path=None, index_bsize=64, nbits=1, kmeans_niters=4, resume=False, pool_factor=1, clustering_mode='hierarchical', protected_tokens=0, similarity='cosine', bsize=32, accumsteps=1, lr=1e-05, maxsteps=15626, save_every=None, warmup=781, warmup_bert=None, relu=False, nway=32, use_ib_negatives=False, reranker=False, distillation_alpha=1.0, ignore_scores=False, model_name='answerdotai/AnswerAI-ColBERTv2.5-small', query_maxlen=32, attend_to_mask_tokens=False, interaction='colbert', dim=96, doc_maxlen=300, mask_punctuation=True, checkpoint='/root/.cache/huggingface/hub/models--answerdotai--answerai-colbert-small-v1/snapshots/be1703c55532145a844da800eea4c9a692d7e267/', triples='/home/bclavie/colbertv2.5_en/data/msmarco/triplets.jsonl', collection='/home/bclavie/colbertv2.5_en/data/msmarco/collection.tsv', queries='/home/bclavie/colbertv2.5_en/data/msmarco/queries.tsv', index_name=None, overwrite=False, root='/home/bclavie/colbertv2.5_en/experiments', experiment='minicolbertv2.5', index_root=None, name='2024-08/07/08.16.20', rank=0, nranks=4, amp=True, gpus=4, avoid_fork_if_possible=False)

We also need to create a RunConfig object:

run_config = RunConfig(nranks=-1, experiment="colbert", root=".ragatouille/")
run_config
RunConfig(overwrite=False, root='.ragatouille/', experiment='colbert', index_root=None, name='2025-03/13/01.14.35', rank=0, nranks=-1, amp=True, gpus=1, avoid_fork_if_possible=False)

A couple more config values are set:

config.avoid_fork_if_possible = True

if len(collection) > 100000:
    config.kmeans_niters = 4
elif len(collection) > 50000:
    config.kmeans_niters = 10
else:
    config.kmeans_niters = 20
config.avoid_fork_if_possible, config.kmeans_niters
(True, 20)

After determining whether the PyTorch or FAISS k-means implementation will be used, Indexer.index is called.

Indexer.index

The Indexer comes from the ColBERT repo, so this is essentially the connection point between the RAGatouille and ColBERT libraries.

Launcher

Inside Indexer.index, __launch is called, from within which a Launcher instance is created with the encode function.

I’m a bit fuzzy on the next part but I’ll give it a shot:

when Launcher.launch is called, the following two lines are executed (where callee is the encode function):

args_ = (self.callee, port, return_value_queue, new_config, *args)
all_procs.append(mp.Process(target=setup_new_process, args=args_))

setup_new_process contains the following lines:

with Run().context(config, inherit_config=False):
    return_val = callee(config, *args)

Since callee gets called there, let’s look at the function callee actually is: encode, which is part of the collection_indexer.py file.

encode

def encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    encoder = CollectionIndexer(config=config, collection=collection, verbose=verbose)
    encoder.run(shared_lists)

This leads us to encoder.run which is CollectionIndexer.run. But before that, we need to look at how the collection is transformed when CollectionIndexer is instantiated.

CollectionIndexer.__init__

There are two important objects created when the CollectionIndexer is instantiated. First is the Collection object, which turns our list collection:

type(collection)
list

into a Collection object:

collection = Collection.cast(collection)
collection
<colbert.data.collection.Collection at 0x794fa037e410>
type(collection)
colbert.data.collection.Collection
def __init__(path=None, data=None)
<no docstring>

Next, it creates a CollectionEncoder object:

checkpoint = Checkpoint(config.checkpoint, colbert_config=config)
encoder = CollectionEncoder(config, checkpoint)
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:12: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
  self.scaler = torch.cuda.amp.GradScaler()

checkpoint is our model:

checkpoint
Checkpoint(
  (model): HF_ColBERT(
    (linear): Linear(in_features=384, out_features=96, bias=False)
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 384, padding_idx=0)
        (position_embeddings): Embedding(512, 384)
        (token_type_embeddings): Embedding(2, 384)
        (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x BertLayer(
            (attention): BertAttention(
              (self): BertSdpaSelfAttention(
                (query): Linear(in_features=384, out_features=384, bias=True)
                (key): Linear(in_features=384, out_features=384, bias=True)
                (value): Linear(in_features=384, out_features=384, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=384, out_features=384, bias=True)
                (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): BertIntermediate(
              (dense): Linear(in_features=384, out_features=1536, bias=True)
              (intermediate_act_fn): GELUActivation()
            )
            (output): BertOutput(
              (dense): Linear(in_features=1536, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
      (pooler): BertPooler(
        (dense): Linear(in_features=384, out_features=384, bias=True)
        (activation): Tanh()
      )
    )
  )
)

The CollectionEncoder will be used later on to encode the passages.

encoder
<colbert.indexing.collection_encoder.CollectionEncoder at 0x794fa0683dd0>

Next we’ll dive into CollectionIndexer.run within which all of the indexing operations that I’m most interested in take place, starting with setup.

CollectionIndexer.run

CollectionIndexer.setup

'''
Calculates and saves plan.json for the whole collection.

plan.json { config, num_chunks, num_partitions, num_embeddings_est, avg_doclen_est}
num_partitions is the number of centroids to be generated.
'''

We’ll see where num_chunks is used, but for now I’ll just define it:

num_chunks = int(np.ceil(len(collection) / collection.get_chunksize()))
len(collection), collection.get_chunksize(), num_chunks
(1000, 1001, 1)

Next we look at _sample_pids and _sample_embeddings, which produce the sampled embeddings that are later clustered to get our centroids.

_sample_pids

num_passages = len(collection)
num_passages
1000

It’s awesome to see one of the heuristics mentioned in the ColBERTv2 paper:

To reduce memory consumption, we apply k-means clustering to the embeddings produced by invoking our BERT encoder over only a sample of all passages, proportional to the square root of the collection size, an approach we found to perform well in practice.

typical_doclen = 120
sampled_pids = 16 * np.sqrt(typical_doclen * num_passages)
sampled_pids
5542.562584220407
sampled_pids = min(1 + int(sampled_pids), num_passages)

In this case because my toy collection is so small (1000 passages) we will use all of them for centroid clustering.

sampled_pids = random.sample(range(num_passages), sampled_pids)
sampled_pids = set(sampled_pids)
len(sampled_pids), min(sampled_pids), max(sampled_pids)
(1000, 0, 999)
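
To get a feel for how that square-root heuristic scales, here’s a quick sketch with a few hypothetical collection sizes (typical_doclen = 120 is the hard-coded assumption from the library code above):

for n in [1_000, 100_000, 10_000_000]:
    n_sampled = min(1 + int(16 * np.sqrt(120 * n)), n)
    print(f"{n:>12,} passages -> {n_sampled:>10,} passages sampled for clustering")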

_sample_embeddings

local_pids = collection.enumerate(rank=config.rank)
local_pids
<generator object Collection.enumerate at 0x794fa067dc40>

sampled_pids contains all of our passage IDs:

local_sample = [passage for pid, passage in local_pids if pid in sampled_pids]
len(local_sample)
1000

Next comes another critical process: encoding our passages!

CollectionEncoder.encode_passages

Inside encode_passages we call checkpoint.docFromText.

checkpoint.docFromText

And inside checkpoint.docFromText we call checkpoint.doc

checkpoint.doc

Inside ColBERT.doc we finally call the lowest-level method in this chain:

D = self.bert(input_ids, attention_mask=attention_mask)[0]

One key point to visualize is that the BERT output is normalized:

 D = torch.nn.functional.normalize(D, p=2, dim=2)

I’ll zoom out again and call encode_passages.

local_sample_embs, doclens = encoder.encode_passages(local_sample)
[Mar 13, 01:18:18] [0]       #> Encoding 1000 passages..
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  return torch.cuda.amp.autocast() if self.activated else NullContextManager()
local_sample_embs.shape
torch.Size([15198, 96])

Note that each token embedding is a unit vector:

local_sample_embs[0].norm()
tensor(1., dtype=torch.float16)
len(doclens), doclens[:5]
(1000, [4, 20, 18, 23, 8])

We have 15198 token embeddings (embedded into answerai-colbert-small-v1’s 96-dimension space) and a mapping (list) of passage ID (indices) to number of tokens (values).

avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est
15.198

On average, each passage (document) is 15 tokens long.

Zooming back out to CollectionIndexer.setup we have a few more steps before our planning is complete:

num_passages = len(collection)
num_embeddings_est = num_passages * avg_doclen_est
num_partitions = int(2 ** np.floor(np.log2(16 * np.sqrt(num_embeddings_est))))
num_passages, num_embeddings_est, num_partitions
(1000, 15198.0, 1024)

num_partitions is the number of clusters we will cluster our 15198 token embeddings into.

This is the information that was in plan.json in index_path (in addition to the other ColBERTConfig information).
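
To make the partition heuristic concrete, here’s the arithmetic for this collection worked step by step (same formula as above, just unrolled):

raw = 16 * np.sqrt(num_embeddings_est)   # ~1972.6
exponent = np.floor(np.log2(raw))        # 10.0
raw, exponent, int(2 ** exponent)        # nearest power of two below 16 * sqrt(num_embeddings): 1024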

CollectionIndexer.train

After setup is complete, the next method called in run is train.

The first step in train is to split our local_sample_embs into sample and sample_heldout.

local_sample_embs = local_sample_embs[torch.randperm(local_sample_embs.size(0))]
heldout_fraction = 0.05
heldout_size = int(min(heldout_fraction * local_sample_embs.size(0), 50_000))
heldout_size
759
sample, sample_heldout = local_sample_embs.split([local_sample_embs.size(0) - heldout_size, heldout_size], dim=0)
sample.shape, sample_heldout.shape
(torch.Size([14439, 96]), torch.Size([759, 96]))

compute_faiss_kmeans

Next we get the centroids using compute_faiss_kmeans

args_ = [config.dim, num_partitions, config.kmeans_niters, [[sample]]]
centroids = compute_faiss_kmeans(*args_)
centroids.shape
torch.Size([1024, 96])
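
Under the hood, compute_faiss_kmeans wraps FAISS’s k-means implementation. As a rough sketch of what that amounts to (assuming faiss is actually installed, which it may not be given the PLAID warning earlier; the real function also handles GPU placement and receiving the sample through the launcher):

import faiss

kmeans = faiss.Kmeans(config.dim, num_partitions, niter=config.kmeans_niters)
kmeans.train(sample.float().numpy())                    # FAISS expects float32
centroids_faiss = torch.from_numpy(kmeans.centroids)    # shape (1024, 96), not yet normalized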

We then normalize the centroids

centroids = torch.nn.functional.normalize(centroids, dim=-1)
centroids.shape, centroids[0].norm()
(torch.Size([1024, 96]), tensor(1.))
centroids = centroids.half()

I was hoping to get the same values as RAG.index centroids by setting seeds at the start of this notebook, but I am not getting the same result.

_centroids[0][:5]
tensor([-0.0649,  0.1193, -0.0551,  0.0561, -0.0826], device='cuda:0',
       dtype=torch.float16)
centroids[0][:5]
tensor([-0.0587,  0.0379, -0.0847, -0.0224, -0.0636], dtype=torch.float16)

I’ll use PCA to compare the two sets of centroids:

# Project to 2D
pca = PCA(n_components=2)
prev_2d = pca.fit_transform(_centroids.cpu().numpy())
new_2d = pca.transform(centroids.cpu().numpy())

# Plot
plt.figure(figsize=(10, 8))
plt.scatter(prev_2d[:, 0], prev_2d[:, 1], alpha=0.5, label='Previous')
plt.scatter(new_2d[:, 0], new_2d[:, 1], alpha=0.5, label='New')
plt.legend()
plt.title('PCA projection of centroids')
plt.show()

I’m not super familiar with interpreting PCA plots, so I asked Claude what it thought about this result:

I would describe this as showing “good structural similarity but with expected local variations.” The centroids aren’t identical (which would show perfect overlap), but they capture similar patterns in the embedding space. This suggests that while individual centroid positions differ, the overall index structure should perform similarly for retrieval tasks.

For now, I’ll consider this part of indexing complete, as we have generated contents similar to what’s in centroids.pt.

_compute_avg_residual

This next section was quite eye-opening for me, as it was the first time I understood how quantization is implemented.

The ResidualCodec handles all of the compression/binarization/decompression of residuals.

compressor = ResidualCodec(config=config, centroids=centroids, avg_residual=None)
compressor
<colbert.indexing.codecs.residual.ResidualCodec at 0x794f9aacfdd0>
heldout_reconstruct = compressor.compress_into_codes(sample_heldout, out_device='cuda' )
heldout_reconstruct.shape
torch.Size([759])

compress_into_codes finds the nearest centroid IDs to the token embeddings. It does so using cosine similarity:

indices = (self.centroids @ batch.T.cuda().half()).max(dim=0).indices.to(device=out_device)
heldout_reconstruct[:5]
tensor([633, 667, 738, 641, 443], device='cuda:0')
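
Just to isolate what that one line does, here’s a toy version with made-up tensors (not library code): because both the centroids and the embeddings are unit vectors, the matrix product is cosine similarity, and the argmax over centroids gives each embedding’s nearest-centroid ID.

toy_centroids = torch.nn.functional.normalize(torch.randn(4, 96), dim=-1)  # 4 unit-norm "centroids"
toy_embs = torch.nn.functional.normalize(torch.randn(3, 96), dim=-1)       # 3 unit-norm "embeddings"
toy_codes = (toy_centroids @ toy_embs.T).max(dim=0).indices                # one centroid ID per embedding
toy_codes.shape  # torch.Size([3])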

lookup_centroids gets the full centroid vectors corresponding to the centroid IDs in heldout_reconstruct:

heldout_reconstruct = compressor.lookup_centroids(heldout_reconstruct, out_device='cuda')
heldout_reconstruct.shape
torch.Size([759, 96])

The residual between the heldout token embeddings and the closest centroids is then calculated:

heldout_avg_residual = sample_heldout.cuda() - heldout_reconstruct
heldout_avg_residual.shape
torch.Size([759, 96])

We then calculate the average absolute residual per dimension (a 96-dimension vector):

avg_residual = torch.abs(heldout_avg_residual).mean(dim=0).cpu()
avg_residual.shape
torch.Size([96])

The average residual is somewhat similar to the stored value in avg_residual.pt.

_avg_residual, avg_residual.mean()
(tensor(0.0150, device='cuda:0', dtype=torch.float16),
 tensor(0.0158, dtype=torch.float16))

To match the RAG.index defaults, I’m going to set nbits to 4.

config.nbits
1
config.nbits = 4
config.nbits
4
num_options = 2 ** config.nbits
config.nbits, num_options
(4, 16)

A 4-bit value consists of four binary digits (each 0 or 1), so there are 16 possible combinations:

Binary
0000
0001
0010
0011
0100
0101
0110
0111
1000
1001
1010
1011
1100
1101
1110
1111

We split 0-to-1 into 16 equal parts:

quantiles = torch.arange(0, num_options, device=heldout_avg_residual.device) * (1 / num_options)
quantiles.shape, quantiles
(torch.Size([16]),
 tensor([0.0000, 0.0625, 0.1250, 0.1875, 0.2500, 0.3125, 0.3750, 0.4375, 0.5000,
         0.5625, 0.6250, 0.6875, 0.7500, 0.8125, 0.8750, 0.9375],
        device='cuda:0'))
bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[1:], quantiles + (0.5 / num_options)
bucket_cutoffs_quantiles, bucket_weights_quantiles
(tensor([0.0625, 0.1250, 0.1875, 0.2500, 0.3125, 0.3750, 0.4375, 0.5000, 0.5625,
         0.6250, 0.6875, 0.7500, 0.8125, 0.8750, 0.9375], device='cuda:0'),
 tensor([0.0312, 0.0938, 0.1562, 0.2188, 0.2812, 0.3438, 0.4062, 0.4688, 0.5312,
         0.5938, 0.6562, 0.7188, 0.7812, 0.8438, 0.9062, 0.9688],
        device='cuda:0'))

IIUC, the weights’ quantiles are the midpoints between adjacent cutoffs’ quantiles.

(bucket_cutoffs_quantiles[1] + bucket_cutoffs_quantiles[2])/2
tensor(0.1562, device='cuda:0')
(bucket_cutoffs_quantiles[3] + bucket_cutoffs_quantiles[4])/2
tensor(0.2812, device='cuda:0')
(bucket_cutoffs_quantiles[-2] + bucket_cutoffs_quantiles[-1])/2
tensor(0.9062, device='cuda:0')
bucket_cutoffs = heldout_avg_residual.float().quantile(bucket_cutoffs_quantiles)
bucket_weights = heldout_avg_residual.float().quantile(bucket_weights_quantiles)

IIUC, bucket_cutoffs are the values with which we can group our (flattened) heldout_avg_residual values into 16 equally sized groups, visualized here by setting the histogram bins to bucket_cutoffs.

pd.Series(heldout_avg_residual.flatten().cpu()).hist(bins=bucket_cutoffs.cpu())
plt.xlim([-0.035, 0.035])
plt.show()

Perhaps due to randomness during the sample split, my manually calculated cutoffs are not quite the same as the RAG.index values.

bucket_cutoffs
tensor([-0.0322, -0.0219, -0.0156, -0.0108, -0.0070, -0.0040, -0.0017,  0.0000,
         0.0019,  0.0042,  0.0071,  0.0108,  0.0155,  0.0220,  0.0327],
       device='cuda:0')
_bucket_cutoffs
tensor([-0.0307, -0.0205, -0.0146, -0.0099, -0.0064, -0.0037, -0.0016,  0.0000,
         0.0017,  0.0038,  0.0066,  0.0102,  0.0150,  0.0211,  0.0313],
       device='cuda:0')
bucket_weights
tensor([-0.0434, -0.0261, -0.0185, -0.0131, -0.0088, -0.0054, -0.0028, -0.0009,
         0.0009,  0.0029,  0.0055,  0.0088,  0.0130,  0.0184,  0.0262,  0.0441],
       device='cuda:0')
_bucket_weights
tensor([-0.0411, -0.0247, -0.0173, -0.0121, -0.0081, -0.0050, -0.0026, -0.0007,
         0.0007,  0.0027,  0.0052,  0.0083,  0.0124,  0.0178,  0.0253,  0.0421],
       device='cuda:0', dtype=torch.float16)

There seem to be some rounding differences (or perhaps it depends on the distribution?), but the weights again look like roughly the midpoints between the cutoffs.

(bucket_cutoffs[0] + bucket_cutoffs[1])/2
tensor(-0.0270, device='cuda:0')
(bucket_cutoffs[3] + bucket_cutoffs[4])/2
tensor(-0.0089, device='cuda:0')
(bucket_cutoffs[-2] + bucket_cutoffs[-1])/2
tensor(0.0273, device='cuda:0')

CollectionIndexer.index

Thus far we have computed centroids from a sample of our token embeddings (the 14,439 non-heldout embeddings) and used the 5% heldout set (759 embeddings) to calculate bucket cutoffs and bucket weights for quantization. We also know what the average residual mean value is.

Now we find the closest centroids and residuals for all passages’ token embeddings, starting first by encoding all 15198 tokens with our model:

embs, doclens = encoder.encode_passages(collection)
embs.shape, len(doclens), doclens[:5]
[Mar 13, 01:18:21] [0]       #> Encoding 1000 passages..
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
  return torch.cuda.amp.autocast() if self.activated else NullContextManager()
(torch.Size([15198, 96]), 1000, [4, 20, 18, 23, 8])

Then save_chunk (a method on the IndexSaver) is called, within which some interesting things take place:

def save_chunk(self, chunk_idx, offset, embs, doclens):
        compressed_embs = self.codec.compress(embs)

Looking into ResidualCodec.compress:

def compress(self, embs):
        codes, residuals = [], []

        for batch in embs.split(1 << 18):
            if self.use_gpu:
                batch = batch.cuda().half()
            codes_ = self.compress_into_codes(batch, out_device=batch.device)
            centroids_ = self.lookup_centroids(codes_, out_device=batch.device)

            residuals_ = (batch - centroids_)

            codes.append(codes_.cpu())
            residuals.append(self.binarize(residuals_).cpu())

        codes = torch.cat(codes)
        residuals = torch.cat(residuals)

        return ResidualCodec.Embeddings(codes, residuals)

We’ve seen compress_into_codes and lookup_centroids before:

codes_ = compressor.compress_into_codes(embs, out_device='cuda')
codes_.shape
torch.Size([15198])

These codes are the centroid IDs closest to each token embedding.

codes_[:5]
tensor([654, 843, 401, 654, 926], device='cuda:0')

We then get those 15198 centroid vectors:

centroids_ = compressor.lookup_centroids(codes_, out_device='cuda')
centroids_.shape
torch.Size([15198, 96])

A reminder that our centroids and our token embeddings are unit vectors:

centroids_[0].norm(), embs[0].norm()
(tensor(1., device='cuda:0', dtype=torch.float16),
 tensor(1., dtype=torch.float16))

We then find the residuals between token embeddings and centroids:

residuals_ = (embs.cpu() - centroids_.cpu())
residuals_.shape
torch.Size([15198, 96])

The next piece is super cool. We binarize the residuals, starting by using bucketize:

residuals = torch.bucketize(residuals_.float().cpu(), bucket_cutoffs.cpu()).to(dtype=torch.uint8)
residuals.shape
torch.Size([15198, 96])
residuals[0][:5]
tensor([8, 7, 7, 8, 7], dtype=torch.uint8)
residuals_[0][:5]
tensor([6.1035e-05, 0.0000e+00, 0.0000e+00, 1.9073e-06, 0.0000e+00],
       dtype=torch.float16)
residuals_[1][10:20]
tensor([ 0.0234, -0.0100,  0.0046, -0.0078, -0.0111,  0.0249, -0.0081, -0.0048,
         0.0270,  0.0037], dtype=torch.float16)
bucket_cutoffs[6:10]
tensor([-0.0017,  0.0000,  0.0019,  0.0042], device='cuda:0')
residuals.min(), residuals.max()
(tensor(0, dtype=torch.uint8), tensor(15, dtype=torch.uint8))

The values of residuals are now the IDs (indices) of the buckets that the residual values fall into!

residuals = residuals.unsqueeze(-1).expand(*residuals.size(), config.nbits)
residuals.shape
torch.Size([15198, 96, 4])

We add a new dimension to hold the 4 bits for each residual value.

arange_bits = torch.arange(0, config.nbits, device='cuda', dtype=torch.uint8)
arange_bits
tensor([0, 1, 2, 3], device='cuda:0', dtype=torch.uint8)
residuals = residuals.cpu() >> arange_bits.cpu()
residuals.shape
torch.Size([15198, 96, 4])
residuals[0][:5]
tensor([[8, 4, 2, 1],
        [7, 3, 1, 0],
        [7, 3, 1, 0],
        [8, 4, 2, 1],
        [7, 3, 1, 0]], dtype=torch.uint8)
residuals = residuals & 1
residuals[0][:5]
tensor([[0, 0, 0, 1],
        [1, 1, 1, 0],
        [1, 1, 1, 0],
        [0, 0, 0, 1],
        [1, 1, 1, 0]], dtype=torch.uint8)

We have now converted the bucket ID into the actual 4-bit binary value it represents.
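
Tracing one value by hand makes the bit-twiddling clearer: take bucket code 8, the first residual value of token 0. Shifting right by 0 through 3 and masking with 1 pulls out its bits in least-significant-first order, which is exactly the [0, 0, 0, 1] row we see above.

code = 8
[(code >> b) & 1 for b in range(4)]  # [0, 0, 0, 1]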

residuals_packed = np.packbits(np.asarray(residuals.contiguous().flatten()))
residuals_packed = torch.as_tensor(residuals_packed, dtype=torch.uint8)
residuals_packed = residuals_packed.reshape(residuals.size(0), config.dim // 8 * config.nbits)
residuals_packed.shape
torch.Size([15198, 48])
residuals_packed[0][:5]
tensor([ 30, 225, 225, 238, 238], dtype=torch.uint8)
f"30 in binary: {bin(30)[2:].zfill(8)}"
'30 in binary: 00011110'

For each residual vector with 96 values, each value is represented with 4 bits (e.g. 0, 0, 0, 1). Every 8 bits are packed into a single integer (e.g. 0001 and 1110 concatenate to become the integer 30), so we have cut the number of stored values in half (from 96 to 48).

These residuals would be stored in 0.residuals.pt.
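
This is also a good place to close the loop on bucket_weights: at search time, ColBERT’s ResidualCodec decompresses by mapping each 4-bit bucket code back to its bucket weight and adding the centroid back in. Here’s a rough sketch of that idea for token 0 using tensors we already have in hand (not the library’s actual decompression path, which operates on the packed bytes via the compiled extension):

codes_tok0 = torch.bucketize(residuals_[0].float(), bucket_cutoffs.cpu())  # 96 bucket IDs in [0, 15]
approx_residual = bucket_weights.cpu()[codes_tok0]                         # each code becomes its bucket's weight
reconstructed = centroids_[0].cpu().float() + approx_residual
(reconstructed - embs[0].cpu().float()).abs().max()                        # small but nonzero: quantization is lossy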

_build_ivf

This is a critical piece—the mapping between passages and centroids!

codes = codes_.sort()
ivf, values = codes.indices, codes.values

Token embedding IDs:

ivf
tensor([  936,  1171,  2363,  ..., 12051, 12147, 12161], device='cuda:0')

Centroid IDs:

values
tensor([   0,    0,    0,  ..., 1023, 1023, 1023], device='cuda:0')
ivf.shape, values.shape
(torch.Size([15198]), torch.Size([15198]))

ivf contains the token embedding IDs (the indices of codes_) and values contains the centroid IDs (the sorted values of codes_).

We then get the number of tokens per centroid ID:

ivf_lengths = torch.bincount(values, minlength=num_partitions)
ivf_lengths
tensor([10, 11, 17,  ..., 17,  9, 29], device='cuda:0')
ivf_lengths.shape
torch.Size([1024])
ivf_lengths.sum()
tensor(15198, device='cuda:0')

colbert/indexing/utils.py: optimize_ivf

We have 1000 documents containing a total of 15198 tokens.

total_num_embeddings = sum(doclens)
len(doclens), total_num_embeddings
(1000, 15198)

We instantiate an empty mapping between token embedding IDs and passage IDs:

emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)
emb2pid.shape
torch.Size([15198])

The indices of doclens are passage IDs (pid) and the values are the number of tokens in each passage (dlength).

offset_doclens = 0
for pid, dlength in enumerate(doclens):
    emb2pid[offset_doclens: offset_doclens + dlength] = pid
    offset_doclens += dlength
emb2pid.shape
torch.Size([15198])

The first 4 token embeddings correspond to the first passage, the next 20 token embeddings to the second passage, and so on.

emb2pid[:50]
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
        3, 3], dtype=torch.int32)

Recall that ivf contains, as values, the token embedding IDs, which are the indices of emb2pid. The values of emb2pid are passage IDs, so indexing into emb2pid with ivf pulls out the passage IDs corresponding to those tokens. Note that ivf is sorted by centroid ID.

ivf
tensor([  936,  1171,  2363,  ..., 12051, 12147, 12161], device='cuda:0')
values
tensor([   0,    0,    0,  ..., 1023, 1023, 1023], device='cuda:0')
new_ivf = emb2pid[ivf.cpu()]
new_ivf.shape, new_ivf[:5]
(torch.Size([15198]), tensor([ 55,  69, 143, 416, 471], dtype=torch.int32))
new_ivf
tensor([ 55,  69, 143,  ..., 795, 800, 800], dtype=torch.int32)

The first token embedding assigned to centroid ID 0 corresponds to passage ID 55.

emb2pid[936]
tensor(55, dtype=torch.int32)

new_ivf maps its indices (token embedding positions, sorted by centroid) to passage IDs, and it is aligned with the ivf_lengths tensor, which contains the number of tokens per centroid ID (derived from values).

Next, we iterate through ivf_lengths, which contains the number of tokens per centroid ID. For each length we get the unique passage IDs from new_ivf and append them to unique_pids_per_centroid. The number of unique pids for that centroid is added to new_ivf_lengths.

unique_pids_per_centroid = []
new_ivf_lengths = []

offset = 0
for length in tqdm.tqdm(ivf_lengths.tolist()):
    pids = torch.unique(new_ivf[offset:offset+length])
    unique_pids_per_centroid.append(pids)
    new_ivf_lengths.append(pids.shape[0])
    offset += length
ivf = torch.cat(unique_pids_per_centroid)
new_ivf_lengths = torch.tensor(new_ivf_lengths)
100%|██████████| 1024/1024 [00:00<00:00, 35975.77it/s]
<ipython-input-138-6c68981e98f9>:11: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  new_ivf_lengths = torch.tensor(ivf_lengths)
ivf.shape, new_ivf_lengths.shape
(torch.Size([11593]), torch.Size([1024]))

Note that there are now fewer values in ivf than 15198 since we are only capturing the unique pids per centroid.

ivf[:5]
tensor([ 55,  69, 143, 416, 471], dtype=torch.int32)
new_ivf_lengths[:5]
tensor([10, 11, 17,  1, 34], device='cuda:0')

new_ivf_lengths is the count of unique passage IDs per centroid. So, for example, the first 10 pids in ivf correspond to centroid ID 0.

ivf and new_ivf_lengths would be stored in ivf.pid.pt.

After updating metadata, this completes the indexing process in RAGatouille and ColBERT!

Final Thoughts

There were of course many details that I didn’t fully explain in this walkthrough, and since I wasn’t able to exactly replicate some of the indexing artifacts there may be some errors in my code, but I think I both covered and understood the main components of creating an index. Getting to this stage involved a lot of discussion with Claude. I used AnswerAI’s toolslm to create context from the RAGatouille and ColBERT repos to provide as Claude project knowledge. I also pored through the codebase for hours, making sure to trace my steps from method to method. While I could do more deep dives into the individual components of indexing, I feel satisfied with this walkthrough for now. I hope you enjoyed this blog post!