!pip install datasets ragatouille -qq
RAGatouille/ColBERT Indexing Deep Dive
Background
In this blog post I dive deeply into the internals of the RAGatouille and ColBERT libraries to understand the intermediate steps taken when building an index for a collection of documents.
- RAGatouille
  - ragatouille/RAGPretrainedModel.py
  - ragatouille/models/colbert.py
  - ragatouille/models/index.py
- ColBERT
  - colbert/indexer.py
  - colbert/indexing/collection_indexer.py
    - encode
    - Collection.cast
    - CollectionIndexer.run
    - CollectionIndexer.setup
      - CollectionIndexer._sample_pids
      - CollectionIndexer._sample_embeddings
        - colbert/indexing/collection_encoder.py: CollectionEncoder.encode_passages, Checkpoint.docFromText
      - CollectionIndexer._save_plan
    - CollectionIndexer.train
    - CollectionIndexer.index
      - colbert/indexing/collection_encoder.py: CollectionEncoder.encode_passages
      - IndexSaver.save_chunk
      - ResidualCodec.compress
    - CollectionIndexer.finalize
from datasets import load_dataset
from ragatouille import RAGPretrainedModel
from fastcore.utils import Path
import torch
import srsly
import uuid
from ragatouille.data import CorpusProcessor
from llama_index.core.text_splitter import SentenceSplitter
import pandas as pd
from ragatouille.models.index import PLAIDModelIndex
from colbert.infra import ColBERTConfig, RunConfig
from colbert.data.collection import Collection
from colbert.modeling.checkpoint import Checkpoint
from colbert.indexing.collection_encoder import CollectionEncoder
from colbert.indexing.collection_indexer import CollectionIndexer
import numpy as np
import random
from colbert.indexing.collection_indexer import compute_faiss_kmeans
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from colbert.indexing.codecs.residual import ResidualCodec
from colbert.utils.utils import flatten
import tqdm
def set_all_seeds(seed=123):
    """Set seeds for all random number generators"""
    import random
    import numpy as np
    import torch
    import os

    # Python's random module
    random.seed(seed)
    # NumPy
    np.random.seed(seed)
    # PyTorch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        # For multi-GPU
        torch.cuda.manual_seed_all(seed)

    # Set PYTHONHASHSEED for reproducibility across runs
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Set deterministic algorithms for PyTorch
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print(f"All seeds set to {seed}")

# Call this at the beginning of your script
set_all_seeds(123)
All seeds set to 123
RAGatouille RAG.index
Everything in this notebook will be compared to what's generated with RAG.index.
For this exercise, I’ll use 1000 passages from the UKPLab/DAPR ConditionalQA dataset.
passages = load_dataset("UKPLab/dapr", "ConditionalQA-corpus", split="test[:1000]")
passages
/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning:
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
warnings.warn(
Dataset({
features: ['_id', 'text', 'title', 'doc_id', 'paragraph_no', 'total_paragraphs', 'is_candidate'],
num_rows: 1000
})
passages[0]['text']
'Overview'
= "answerdotai/answerai-colbert-small-v1" model_nm
RAG = RAGPretrainedModel.from_pretrained(model_nm)
index_path = RAG.index(index_name="cqa_index", collection=passages['text'])
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:12: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
self.scaler = torch.cuda.amp.GradScaler()
---- WARNING! You are using PLAID with an experimental replacement for FAISS for greater compatibility ----
This is a behaviour change from RAGatouille 0.8.0 onwards.
This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations.
If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour.
--------------------
[Mar 13, 01:15:09] #> Creating directory .ragatouille/colbert/indexes/cqa_index
[Mar 13, 01:15:11] [0] #> Encoding 1000 passages..
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
return torch.cuda.amp.autocast() if self.activated else NullContextManager()
[Mar 13, 01:15:14] [0] avg_doclen_est = 15.197999954223633 len(local_sample) = 1,000
[Mar 13, 01:15:14] [0] Creating 1,024 partitions.
[Mar 13, 01:15:14] [0] *Estimated* 15,197 embeddings.
[Mar 13, 01:15:14] [0] #> Saving the indexing plan to .ragatouille/colbert/indexes/cqa_index/plan.json ..
/usr/local/lib/python3.11/dist-packages/colbert/indexing/collection_indexer.py:256: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
sub_sample = torch.load(sub_sample_path)
used 20 iterations (0.5189s) to cluster 14439 items into 1024 clusters
[Mar 13, 01:15:14] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
[Mar 13, 01:16:51] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...
/usr/local/lib/python3.11/dist-packages/torch/utils/cpp_extension.py:1964: UserWarning: TORCH_CUDA_ARCH_LIST is not set, all archs for visible cards are included for compilation.
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
warnings.warn(
[0.015, 0.016, 0.015, 0.015, 0.013, 0.016, 0.015, 0.016, 0.017, 0.014, 0.013, 0.015, 0.017, 0.014, 0.015, 0.017, 0.014, 0.015, 0.015, 0.014, 0.015, 0.016, 0.015, 0.015, 0.014, 0.014, 0.015, 0.015, 0.015, 0.015, 0.014, 0.014, 0.015, 0.014, 0.015, 0.015, 0.014, 0.015, 0.016, 0.015, 0.014, 0.015, 0.015, 0.014, 0.014, 0.017, 0.016, 0.017, 0.014, 0.015, 0.016, 0.015, 0.016, 0.016, 0.012, 0.015, 0.016, 0.015, 0.016, 0.016, 0.015, 0.015, 0.016, 0.014, 0.015, 0.017, 0.016, 0.015, 0.014, 0.015, 0.015, 0.015, 0.014, 0.016, 0.016, 0.016, 0.014, 0.015, 0.015, 0.014, 0.014, 0.014, 0.016, 0.015, 0.016, 0.015, 0.014, 0.014, 0.014, 0.015, 0.016, 0.014, 0.014, 0.016, 0.014, 0.015]
/usr/local/lib/python3.11/dist-packages/colbert/indexing/codecs/residual.py:141: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
centroids = torch.load(centroids_path, map_location='cpu')
/usr/local/lib/python3.11/dist-packages/colbert/indexing/codecs/residual.py:142: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
avg_residual = torch.load(avgresidual_path, map_location='cpu')
/usr/local/lib/python3.11/dist-packages/colbert/indexing/codecs/residual.py:143: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
bucket_cutoffs, bucket_weights = torch.load(buckets_path, map_location='cpu')
0it [00:00, ?it/s]
[Mar 13, 01:18:16] [0] #> Encoding 1000 passages..
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
return torch.cuda.amp.autocast() if self.activated else NullContextManager()
1it [00:00, 1.39it/s]
0%| | 0/1 [00:00<?, ?it/s]/usr/local/lib/python3.11/dist-packages/colbert/indexing/codecs/residual_embeddings.py:86: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
return torch.load(codes_path, map_location='cpu')
100%|██████████| 1/1 [00:00<00:00, 787.81it/s]
[Mar 13, 01:18:16] #> Optimizing IVF to store map from centroids to list of pids..
[Mar 13, 01:18:16] #> Building the emb2pid mapping..
[Mar 13, 01:18:16] len(emb2pid) = 15198
100%|██████████| 1024/1024 [00:00<00:00, 61154.86it/s]
[Mar 13, 01:18:16] #> Saved optimized IVF to .ragatouille/colbert/indexes/cqa_index/ivf.pid.pt
Done indexing!
index_path = Path(index_path)
index_path
Path('.ragatouille/colbert/indexes/cqa_index')
for o in index_path.ls(): print(o)
.ragatouille/colbert/indexes/cqa_index/buckets.pt
.ragatouille/colbert/indexes/cqa_index/collection.json
.ragatouille/colbert/indexes/cqa_index/metadata.json
.ragatouille/colbert/indexes/cqa_index/ivf.pid.pt
.ragatouille/colbert/indexes/cqa_index/doclens.0.json
.ragatouille/colbert/indexes/cqa_index/0.residuals.pt
.ragatouille/colbert/indexes/cqa_index/centroids.pt
.ragatouille/colbert/indexes/cqa_index/0.metadata.json
.ragatouille/colbert/indexes/cqa_index/plan.json
.ragatouille/colbert/indexes/cqa_index/0.codes.pt
.ragatouille/colbert/indexes/cqa_index/avg_residual.pt
.ragatouille/colbert/indexes/cqa_index/pid_docid_map.json
While it’s a bit tedious to do so (since I’m chomping at the bit to get to the deep dive!) I think it’s worth analyzing the contents of each of these files, as we’ll be recreating them in this notebook.
buckets.pt
Looking at line 160 of the ColBERT repo's residual.py, buckets.pt stores bucket_cutoffs and bucket_weights. We'll go into detail later on about what exactly these are.
_bucket_cutoffs, _bucket_weights = torch.load(index_path/'buckets.pt', weights_only=True)
_bucket_cutoffs.shape, _bucket_weights.shape
(torch.Size([15]), torch.Size([16]))
_bucket_cutoffs
tensor([-0.0307, -0.0205, -0.0146, -0.0099, -0.0064, -0.0037, -0.0016, 0.0000,
0.0017, 0.0038, 0.0066, 0.0102, 0.0150, 0.0211, 0.0313],
device='cuda:0')
_bucket_weights
tensor([-0.0411, -0.0247, -0.0173, -0.0121, -0.0081, -0.0050, -0.0026, -0.0007,
0.0007, 0.0027, 0.0052, 0.0083, 0.0124, 0.0178, 0.0253, 0.0421],
device='cuda:0', dtype=torch.float16)
0.residuals.pt
IIUC, there are 15198 tokens in our collection, each stored as a 48-dimensional uint8 vector; each uint8 packs two 4-bit codes that each correspond to a quantized value, so each vector actually encodes 96 values.
_residuals = torch.load(index_path/'0.residuals.pt', weights_only=True)
_residuals
tensor([[ 30, 225, 225, ..., 238, 238, 30],
[230, 22, 158, ..., 233, 106, 170],
[238, 238, 238, ..., 238, 238, 238],
...,
[ 43, 22, 23, ..., 104, 31, 208],
[222, 254, 91, ..., 128, 8, 189],
[229, 82, 22, ..., 170, 94, 154]], dtype=torch.uint8)
_residuals.shape
torch.Size([15198, 48])
48*4/2
96.0
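As a quick sanity check (my own sketch; the exact bit ordering inside each byte is covered later in the binarization section), we can unpack one packed row back into 96 groups of 4 bits:

row_bits = np.unpackbits(_residuals[0].numpy())   # 48 uint8 bytes -> 384 bits
row_bits.shape, row_bits.reshape(96, 4)[:2]       # one 4-bit group per embedding dimension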
ivf.pid.pt
IIRC, ivf contains a flattened sequence of passage IDs corresponding to each centroid. There are 1024 centroids, and the first 8 passage IDs in ivf correspond to the 0-th centroid.
_ivf, _ivf_lengths = torch.load(index_path/'ivf.pid.pt', weights_only=True)
_ivf.shape, _ivf_lengths.shape
(torch.Size([11759]), torch.Size([1024]))
_ivf[:5]
tensor([895, 896, 902, 904, 909], dtype=torch.int32)
_ivf_lengths[0]
tensor(8)
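As an illustration (my own sketch using only the two tensors just loaded), the flattened _ivf can be sliced into per-centroid groups via the cumulative sum of _ivf_lengths; the first slice should contain the 8 passage IDs assigned to centroid 0:

offsets = torch.cat([torch.zeros(1, dtype=torch.long), _ivf_lengths.long().cumsum(0)])
centroid_id = 0
_ivf[offsets[centroid_id]:offsets[centroid_id + 1]]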
0.metadata.json
There are 1000 passages totaling 15198 tokens.
def load_json(path, filename): return srsly.read_json(str(Path(path) / filename))
"0.metadata.json") load_json(index_path,
{'passage_offset': 0,
'num_passages': 1000,
'num_embeddings': 15198,
'embedding_offset': 0}
collection.json
This JSON contains, as a list, the strings of the 1000 passages in our collection.
= load_json(index_path, "collection.json")
_collection len(_collection), _collection[0]
(1000, 'Overview')
avg_residual.pt
I believe this is the average residual across the 15198 tokens (i.e. the average distance in vector-space between the tokens and their closest centroids).
_avg_residual = torch.load(index_path/'avg_residual.pt', weights_only=True)
_avg_residual
tensor(0.0150, device='cuda:0', dtype=torch.float16)
doclens.0.json
This contains a mapping (list) between passage IDs (indices) and the number of tokens in each document (values).
= load_json(index_path, "doclens.0.json")
_doclens len(_doclens), _doclens[:5]
(1000, [4, 20, 18, 23, 8])
sum(_doclens)
15198
metadata.json
Lots of information in here; I'll highlight the number of centroids and the number of token embeddings in the collection:
'num_partitions': 1024,
'num_embeddings': 15198
"metadata.json") load_json(index_path,
{'config': {'query_token_id': '[unused0]',
'doc_token_id': '[unused1]',
'query_token': '[Q]',
'doc_token': '[D]',
'ncells': None,
'centroid_score_threshold': None,
'ndocs': None,
'load_index_with_mmap': False,
'index_path': None,
'index_bsize': 32,
'nbits': 4,
'kmeans_niters': 20,
'resume': False,
'pool_factor': 1,
'clustering_mode': 'hierarchical',
'protected_tokens': 0,
'similarity': 'cosine',
'bsize': 64,
'accumsteps': 1,
'lr': 1e-05,
'maxsteps': 15626,
'save_every': None,
'warmup': 781,
'warmup_bert': None,
'relu': False,
'nway': 32,
'use_ib_negatives': False,
'reranker': False,
'distillation_alpha': 1.0,
'ignore_scores': False,
'model_name': 'answerdotai/AnswerAI-ColBERTv2.5-small',
'query_maxlen': 32,
'attend_to_mask_tokens': False,
'interaction': 'colbert',
'dim': 96,
'doc_maxlen': 256,
'mask_punctuation': True,
'checkpoint': 'answerdotai/answerai-colbert-small-v1',
'triples': '/home/bclavie/colbertv2.5_en/data/msmarco/triplets.jsonl',
'collection': ['list with 1000 elements starting with...',
['Overview',
'You can only make a claim for Child Tax Credit if you already get Working Tax Credit.',
'If you cannot apply for Child Tax Credit, you can apply for Universal Credit instead.']],
'queries': '/home/bclavie/colbertv2.5_en/data/msmarco/queries.tsv',
'index_name': 'cqa_index',
'overwrite': False,
'root': '.ragatouille/',
'experiment': 'colbert',
'index_root': None,
'name': '2025-03/13/01.14.35',
'rank': 0,
'nranks': 1,
'amp': True,
'gpus': 1,
'avoid_fork_if_possible': False},
'num_chunks': 1,
'num_partitions': 1024,
'num_embeddings': 15198,
'avg_doclen': 15.198,
'RAGatouille': {'index_config': {'index_type': 'PLAID',
'index_name': 'cqa_index'}}}
centroids.pt
There are 1024 96-dimension centroid vectors stored.
_centroids = torch.load(index_path/'centroids.pt', weights_only=True)
_centroids.shape
torch.Size([1024, 96])
They store the full uncompressed values for the centroids.
_centroids[0][:5]
tensor([-0.0649, 0.1193, -0.0551, 0.0561, -0.0826], device='cuda:0',
dtype=torch.float16)
0.codes.pt
I believe this is a mapping (list) between tokens (indices) and centroid IDs (values).
_codes = torch.load(index_path/'0.codes.pt', weights_only=True)
_codes.shape
torch.Size([15198])
_codes[:5]
tensor([138, 843, 273, 138, 561], dtype=torch.int32)
pid_docid_map.json
A mapping between passage ID (0-999) and document ID (UUID).
= load_json(index_path, "pid_docid_map.json")
_pid_docid_map '999'] _pid_docid_map[
'2be086c6-04cc-4d73-b372-08236f76cbe6'
plan.json
This seems to contain the same information as metadata.json.
= load_json(index_path, "plan.json")
_plan _plan
{'config': {'query_token_id': '[unused0]',
'doc_token_id': '[unused1]',
'query_token': '[Q]',
'doc_token': '[D]',
'ncells': None,
'centroid_score_threshold': None,
'ndocs': None,
'load_index_with_mmap': False,
'index_path': None,
'index_bsize': 32,
'nbits': 4,
'kmeans_niters': 20,
'resume': False,
'pool_factor': 1,
'clustering_mode': 'hierarchical',
'protected_tokens': 0,
'similarity': 'cosine',
'bsize': 64,
'accumsteps': 1,
'lr': 1e-05,
'maxsteps': 15626,
'save_every': None,
'warmup': 781,
'warmup_bert': None,
'relu': False,
'nway': 32,
'use_ib_negatives': False,
'reranker': False,
'distillation_alpha': 1.0,
'ignore_scores': False,
'model_name': 'answerdotai/AnswerAI-ColBERTv2.5-small',
'query_maxlen': 32,
'attend_to_mask_tokens': False,
'interaction': 'colbert',
'dim': 96,
'doc_maxlen': 256,
'mask_punctuation': True,
'checkpoint': 'answerdotai/answerai-colbert-small-v1',
'triples': '/home/bclavie/colbertv2.5_en/data/msmarco/triplets.jsonl',
'collection': ['list with 1000 elements starting with...',
['Overview',
'You can only make a claim for Child Tax Credit if you already get Working Tax Credit.',
'If you cannot apply for Child Tax Credit, you can apply for Universal Credit instead.']],
'queries': '/home/bclavie/colbertv2.5_en/data/msmarco/queries.tsv',
'index_name': 'cqa_index',
'overwrite': False,
'root': '.ragatouille/',
'experiment': 'colbert',
'index_root': None,
'name': '2025-03/13/01.14.35',
'rank': 0,
'nranks': 1,
'amp': True,
'gpus': 1,
'avoid_fork_if_possible': False},
'num_chunks': 1,
'num_partitions': 1024,
'num_embeddings_est': 15197.999954223633,
'avg_doclen_est': 15.197999954223633}
In the following sections, I'll try to recreate each of these index_path elements.
_process_corpus
Inside RAG.index, _process_corpus is called on the documents and document IDs.
passage_ids = [str(uuid.uuid4()) for _ in range(len(passages))]
passage_ids[0]
'd4cdfec5-a949-43e0-94b3-feb24caeac5e'
Use the corpus processor to convert the passages into {'document_id': '...', 'content': '...'} dictionaries with a 256-token max length.
cp = CorpusProcessor()
cp
<ragatouille.data.corpus_processor.CorpusProcessor at 0x794fa0323690>
collection_with_ids = cp.process_corpus(passages['text'], passage_ids, chunk_size=256)
len(collection_with_ids), collection_with_ids[0]
(1000,
{'document_id': 'd4cdfec5-a949-43e0-94b3-feb24caeac5e',
'content': 'Overview'})
As a brief aside, I’ll take a look at the maximum token length of the passages.
node_parser = SentenceSplitter(chunk_size=256)
node_parser._token_size
llama_index.core.node_parser.text.sentence.SentenceSplitter._token_size
def _token_size(text: str) -> int
<no docstring>
tk_szs = []
for p in passages['text']: tk_szs.append(node_parser._token_size(p))
pd.Series(tk_szs).describe()
|  | 0 |
|---|---|
| count | 1000.000000 |
| mean | 13.217000 |
| std | 10.192635 |
| min | 1.000000 |
| 25% | 5.000000 |
| 50% | 10.000000 |
| 75% | 19.000000 |
| max | 65.000000 |
This collection of passages has relatively short passages (a max of 65 tokens).
_process_corpus then creates pid_docid_map:
pid_docid_map = {index: item["document_id"] for index, item in enumerate(collection_with_ids)}
pid_docid_map[999]
'096e054e-3041-4881-ac48-b20f1804f650'
This matches the structure of pid_docid_map.json (the UUIDs themselves differ because they are regenerated on each run).
_process_corpus also defines a list of strings, collection:
= [x["content"] for x in collection_with_ids]
collection 0] collection[
'Overview'
_process_corpus also calls _process_metadata, which defines docid_metadata_map as None when document_metadatas is None (which it is in our case).
docid_metadata_map = None
RAG.index Internals
After calling _process_corpus, RAG.index calls model.index, where model is:
instance.model = ColBERT(
    pretrained_model_name_or_path, n_gpu, index_root=index_root, verbose=verbose
)
ColBERT.index in turn calls ModelIndexFactory.construct.
By default the type of index is PLAID, so the following is called:
PLAIDModelIndex(config).build(
    checkpoint, collection, index_name, overwrite, verbose, **kwargs
)
PLAIDModelIndex.build
A couple of key configuration values are set in this method, starting with bsize (the indexing batch size), which defaults to 32.
PLAIDModelIndex._DEFAULT_INDEX_BSIZE
32
bsize = PLAIDModelIndex._DEFAULT_INDEX_BSIZE
bsize
32
The number of bits used to store each compressed residual value (nbits) is determined by the size of the collection.
if len(collection) < 10000: nbits = 4
nbits
4
It then defines a ColBERTConfig object, which I believe is instantiated as follows when the ColBERT checkpoint is instantiated:
ckpt_config = ColBERTConfig.load_from_checkpoint(str(model_nm))
ckpt_config
ColBERTConfig(query_token_id='[unused0]', doc_token_id='[unused1]', query_token='[Q]', doc_token='[D]', ncells=None, centroid_score_threshold=None, ndocs=None, load_index_with_mmap=False, index_path=None, index_bsize=64, nbits=1, kmeans_niters=4, resume=False, pool_factor=1, clustering_mode='hierarchical', protected_tokens=0, similarity='cosine', bsize=32, accumsteps=1, lr=1e-05, maxsteps=15626, save_every=None, warmup=781, warmup_bert=None, relu=False, nway=32, use_ib_negatives=False, reranker=False, distillation_alpha=1.0, ignore_scores=False, model_name='answerdotai/AnswerAI-ColBERTv2.5-small', query_maxlen=32, attend_to_mask_tokens=False, interaction='colbert', dim=96, doc_maxlen=300, mask_punctuation=True, checkpoint='/root/.cache/huggingface/hub/models--answerdotai--answerai-colbert-small-v1/snapshots/be1703c55532145a844da800eea4c9a692d7e267/', triples='/home/bclavie/colbertv2.5_en/data/msmarco/triplets.jsonl', collection='/home/bclavie/colbertv2.5_en/data/msmarco/collection.tsv', queries='/home/bclavie/colbertv2.5_en/data/msmarco/queries.tsv', index_name=None, overwrite=False, root='/home/bclavie/colbertv2.5_en/experiments', experiment='minicolbertv2.5', index_root=None, name='2024-08/07/08.16.20', rank=0, nranks=4, amp=True, gpus=4, avoid_fork_if_possible=False)
config = ColBERTConfig.from_existing(ckpt_config)
config
ColBERTConfig(query_token_id='[unused0]', doc_token_id='[unused1]', query_token='[Q]', doc_token='[D]', ncells=None, centroid_score_threshold=None, ndocs=None, load_index_with_mmap=False, index_path=None, index_bsize=64, nbits=1, kmeans_niters=4, resume=False, pool_factor=1, clustering_mode='hierarchical', protected_tokens=0, similarity='cosine', bsize=32, accumsteps=1, lr=1e-05, maxsteps=15626, save_every=None, warmup=781, warmup_bert=None, relu=False, nway=32, use_ib_negatives=False, reranker=False, distillation_alpha=1.0, ignore_scores=False, model_name='answerdotai/AnswerAI-ColBERTv2.5-small', query_maxlen=32, attend_to_mask_tokens=False, interaction='colbert', dim=96, doc_maxlen=300, mask_punctuation=True, checkpoint='/root/.cache/huggingface/hub/models--answerdotai--answerai-colbert-small-v1/snapshots/be1703c55532145a844da800eea4c9a692d7e267/', triples='/home/bclavie/colbertv2.5_en/data/msmarco/triplets.jsonl', collection='/home/bclavie/colbertv2.5_en/data/msmarco/collection.tsv', queries='/home/bclavie/colbertv2.5_en/data/msmarco/queries.tsv', index_name=None, overwrite=False, root='/home/bclavie/colbertv2.5_en/experiments', experiment='minicolbertv2.5', index_root=None, name='2024-08/07/08.16.20', rank=0, nranks=4, amp=True, gpus=4, avoid_fork_if_possible=False)
We also need to create a RunConfig object:
run_config = RunConfig(nranks=-1, experiment="colbert", root=".ragatouille/")
run_config
RunConfig(overwrite=False, root='.ragatouille/', experiment='colbert', index_root=None, name='2025-03/13/01.14.35', rank=0, nranks=-1, amp=True, gpus=1, avoid_fork_if_possible=False)
A couple more config values are set:
config.avoid_fork_if_possible = True

if len(collection) > 100000:
    config.kmeans_niters = 4
elif len(collection) > 50000:
    config.kmeans_niters = 10
else:
    config.kmeans_niters = 20

config.avoid_fork_if_possible, config.kmeans_niters
(True, 20)
After determining whether the PyTorch or FAISS k-means implementation will be used, Indexer.index is called.
Indexer.index
The Indexer comes from the ColBERT repo, so this is essentially the connection point between the RAGatouille and ColBERT libraries.
Launcher
Inside Indexer.index, __launch is called, from within which a Launcher instance is created with the encode function.
I’m a bit fuzzy on the next part but I’ll give it a shot:
When Launcher.launch is called, the following two lines are executed (where callee is the encode function):
args_ = (self.callee, port, return_value_queue, new_config, *args)
all_procs.append(mp.Process(target=setup_new_process, args=args_))
setup_new_process contains the following lines:
with Run().context(config, inherit_config=False):
    return_val = callee(config, *args)
With callee being called, let's look at the function that callee refers to: encode, which lives in collection_indexer.py.
encode
def encode(config, collection, shared_lists, shared_queues, verbose: int = 3):
    encoder = CollectionIndexer(config=config, collection=collection, verbose=verbose)
    encoder.run(shared_lists)
This leads us to encoder.run, which is CollectionIndexer.run. But before that, we need to look at how the collection is transformed when CollectionIndexer is instantiated.
CollectionIndexer.__init__
There are two important objects created when the CollectionIndexer is instantiated. First is the Collection object, which turns our list collection:
type(collection)
list
into a Collection object:
collection = Collection.cast(collection)
collection
<colbert.data.collection.Collection at 0x794fa037e410>
type(collection)
colbert.data.collection.Collection
def __init__(path=None, data=None)
<no docstring>
Next, it creates a CollectionEncoder object:
checkpoint = Checkpoint(config.checkpoint, colbert_config=config)
encoder = CollectionEncoder(config, checkpoint)
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:12: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.
self.scaler = torch.cuda.amp.GradScaler()
checkpoint is our model:
checkpoint
Checkpoint(
(model): HF_ColBERT(
(linear): Linear(in_features=384, out_features=96, bias=False)
(bert): BertModel(
(embeddings): BertEmbeddings(
(word_embeddings): Embedding(30522, 384, padding_idx=0)
(position_embeddings): Embedding(512, 384)
(token_type_embeddings): Embedding(2, 384)
(LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(encoder): BertEncoder(
(layer): ModuleList(
(0-11): 12 x BertLayer(
(attention): BertAttention(
(self): BertSdpaSelfAttention(
(query): Linear(in_features=384, out_features=384, bias=True)
(key): Linear(in_features=384, out_features=384, bias=True)
(value): Linear(in_features=384, out_features=384, bias=True)
(dropout): Dropout(p=0.1, inplace=False)
)
(output): BertSelfOutput(
(dense): Linear(in_features=384, out_features=384, bias=True)
(LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(intermediate): BertIntermediate(
(dense): Linear(in_features=384, out_features=1536, bias=True)
(intermediate_act_fn): GELUActivation()
)
(output): BertOutput(
(dense): Linear(in_features=1536, out_features=384, bias=True)
(LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
(dropout): Dropout(p=0.1, inplace=False)
)
)
)
)
(pooler): BertPooler(
(dense): Linear(in_features=384, out_features=384, bias=True)
(activation): Tanh()
)
)
)
)
The CollectionEncoder will be used later on to encode the passages.
encoder
<colbert.indexing.collection_encoder.CollectionEncoder at 0x794fa0683dd0>
Next we'll dive into CollectionIndexer.run, within which all of the indexing operations that I'm most interested in take place, starting with setup.
CollectionIndexer.run
CollectionIndexer.setup
'''
Calculates and saves plan.json for the whole collection.
plan.json { config, num_chunks, num_partitions, num_embeddings_est, avg_doclen_est}
num_partitions is the number of centroids to be generated.
'''
We'll see where num_chunks is used, but for now I'll just define it:
num_chunks = int(np.ceil(len(collection) / collection.get_chunksize()))
len(collection), collection.get_chunksize(), num_chunks
(1000, 1001, 1)
Next we look at _sample_pids and _sample_embeddings, whose sampled embeddings are later clustered to get our centroids.
_sample_pids
num_passages = len(collection)
num_passages
1000
It’s awesome to see one of the heuristics mentioned in the ColBERTv2 paper:
To reduce memory consumption, we apply k-means clustering to the embeddings produced by invoking our BERT encoder over only a sample of all passages, proportional to the square root of the collection size, an approach we found to perform well in practice.
typical_doclen = 120
sampled_pids = 16 * np.sqrt(typical_doclen * num_passages)
sampled_pids
5542.562584220407
sampled_pids = min(1 + int(sampled_pids), num_passages)
In this case because my toy collection is so small (1000 passages) we will use all of them for centroid clustering.
sampled_pids = random.sample(range(num_passages), sampled_pids)
sampled_pids = set(sampled_pids)
len(sampled_pids), min(sampled_pids), max(sampled_pids)
(1000, 0, 999)
_sample_embeddings
local_pids = collection.enumerate(rank=config.rank)
local_pids
<generator object Collection.enumerate at 0x794fa067dc40>
sampled_pids contains all of our passages:
local_sample = [passage for pid, passage in local_pids if pid in sampled_pids]
len(local_sample)
1000
Next comes another critical process: encoding our passages!
CollectionEncoder.encode_passages
Inside encode_passages we call checkpoint.docFromText.
checkpoint.docFromText
And inside checkpoint.docFromText we call checkpoint.doc.
checkpoint.doc
Inside ColBERT.doc we finally call the lowest-level method in this chain:
D = self.bert(input_ids, attention_mask=attention_mask)[0]
One key point to visualize is that the BERT output is normalized:
D = torch.nn.functional.normalize(D, p=2, dim=2)
I'll zoom out again and call encode_passages.
local_sample_embs, doclens = encoder.encode_passages(local_sample)
[Mar 13, 01:18:18] [0] #> Encoding 1000 passages..
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
return torch.cuda.amp.autocast() if self.activated else NullContextManager()
local_sample_embs.shape
torch.Size([15198, 96])
Note that the token embeddings are unit vectors:
local_sample_embs[0].norm()
tensor(1., dtype=torch.float16)
len(doclens), doclens[:5]
(1000, [4, 20, 18, 23, 8])
We have 15198 token embeddings (embedded into answerai-colbert-small-v1’s 96-dimension space) and a mapping (list) of passage ID (indices) to number of tokens (values).
avg_doclen_est = sum(doclens) / len(doclens) if doclens else 0
avg_doclen_est
15.198
On average, each passage (document) is 15 tokens long.
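Since doclens records how many token embeddings each passage contributed, we can split local_sample_embs back into per-passage matrices (a quick sketch of mine using the variables defined above):

per_passage_embs = local_sample_embs.split(doclens)   # tuple of 1,000 per-passage matrices
per_passage_embs[0].shape, per_passage_embs[1].shape  # 4 tokens for passage 0, 20 for passage 1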
Zooming back out to CollectionIndexer.setup we have a few more steps before our planning is complete:
num_passages = len(collection)
num_embeddings_est = num_passages * avg_doclen_est
num_partitions = int(2 ** np.floor(np.log2(16 * np.sqrt(num_embeddings_est))))
num_passages, num_embeddings_est, num_partitions
(1000, 15198.0, 1024)
num_partitions is the number of clusters we will group our 15198 token embeddings into.
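In other words, num_partitions is the largest power of two that does not exceed 16·√(num_embeddings_est); a quick arithmetic check (my own, using numpy which is already imported):

raw = 16 * np.sqrt(15198)               # ≈ 1972.6
raw, 2 ** int(np.floor(np.log2(raw)))   # largest power of two below it: 1024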
This is the information that was in plan.json in index_path (in addition to the other ColBERTConfig information).
CollectionIndexer.train
After setup is complete, the next method called in run is train.
The first step in train is to split our local_sample_embs into sample and sample_heldout.
local_sample_embs = local_sample_embs[torch.randperm(local_sample_embs.size(0))]
heldout_fraction = 0.05
heldout_size = int(min(heldout_fraction * local_sample_embs.size(0), 50_000))
heldout_size
759
sample, sample_heldout = local_sample_embs.split([local_sample_embs.size(0) - heldout_size, heldout_size], dim=0)
sample.shape, sample_heldout.shape
(torch.Size([14439, 96]), torch.Size([759, 96]))
compute_faiss_kmeans
Next we get the centroids using compute_faiss_kmeans:
args_ = [config.dim, num_partitions, config.kmeans_niters, [[sample]]]
centroids = compute_faiss_kmeans(*args_)
centroids.shape
torch.Size([1024, 96])
We then normalize the centroids:
centroids = torch.nn.functional.normalize(centroids, dim=-1)
centroids.shape, centroids[0].norm()
(torch.Size([1024, 96]), tensor(1.))
centroids = centroids.half()
I was hoping to get the same values as the RAG.index centroids by setting seeds at the start of this notebook, but I am not getting the same result.
_centroids[0][:5]
tensor([-0.0649, 0.1193, -0.0551, 0.0561, -0.0826], device='cuda:0',
dtype=torch.float16)
centroids[0][:5]
tensor([-0.0587, 0.0379, -0.0847, -0.0224, -0.0636], dtype=torch.float16)
I’ll use PCA to compare the two sets of centroids:
# Project to 2D
pca = PCA(n_components=2)
prev_2d = pca.fit_transform(_centroids.cpu().numpy())
new_2d = pca.transform(centroids.cpu().numpy())

# Plot
plt.figure(figsize=(10, 8))
plt.scatter(prev_2d[:, 0], prev_2d[:, 1], alpha=0.5, label='Previous')
plt.scatter(new_2d[:, 0], new_2d[:, 1], alpha=0.5, label='New')
plt.legend()
plt.title('PCA projection of centroids')
plt.show()
I’m not super familiar with interpreting PCA plots, so I asked Claude what it thought about this result:
I would describe this as showing “good structural similarity but with expected local variations.” The centroids aren’t identical (which would show perfect overlap), but they capture similar patterns in the embedding space. This suggests that while individual centroid positions differ, the overall index structure should perform similarly for retrieval tasks.
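For a more direct check than the PCA plot (my own addition, not something RAG.index does), we can measure how close each freshly computed centroid is to its nearest stored centroid; since both sets are unit vectors, a plain dot product gives the cosine similarity:

# For each freshly computed centroid, find its best cosine match among the stored centroids.
sims = (centroids.cuda().float() @ _centroids.float().T).max(dim=1).values
sims.mean(), sims.min()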
For now, I'll consider this part of indexing complete, as we have generated contents similar to what's in centroids.pt.
_compute_avg_residual
This next section was quite eye opening for me, as it was the first time I understood how quantization is implemented.
The ResidualCodec does all of the compression/binarization/decompression of residuals.
compressor = ResidualCodec(config=config, centroids=centroids, avg_residual=None)
compressor
<colbert.indexing.codecs.residual.ResidualCodec at 0x794f9aacfdd0>
heldout_reconstruct = compressor.compress_into_codes(sample_heldout, out_device='cuda')
heldout_reconstruct.shape
torch.Size([759])
compress_into_codes finds the nearest centroid IDs to the token embeddings. It does so using cosine similarity:
indices = (self.centroids @ batch.T.cuda().half()).max(dim=0).indices.to(device=out_device)
heldout_reconstruct[:5]
tensor([633, 667, 738, 641, 443], device='cuda:0')
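As a sanity check (my own sketch, assuming compressor.centroids lives on the GPU as implied by the line above), we can reproduce the first few codes with the same cosine-similarity argmax:

# Both the centroids and the held-out embeddings are unit vectors, so the dot product is cosine similarity.
manual_codes = (compressor.centroids @ sample_heldout[:5].T.cuda().half()).max(dim=0).indices
manual_codes   # should match heldout_reconstruct[:5] above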
lookup_centroids gets the full vectors for the centroid IDs in heldout_reconstruct:
heldout_reconstruct = compressor.lookup_centroids(heldout_reconstruct, out_device='cuda')
heldout_reconstruct.shape
torch.Size([759, 96])
The residual between the heldout token embeddings and the closest centroids is then calculated:
heldout_avg_residual = sample_heldout.cuda() - heldout_reconstruct
heldout_avg_residual.shape
torch.Size([759, 96])
We then calculate the average residual vector (96 dimensions):
avg_residual = torch.abs(heldout_avg_residual).mean(dim=0).cpu()
avg_residual.shape
torch.Size([96])
The average residual is somewhat similar to the stored value in avg_residual.pt.
_avg_residual, avg_residual.mean()
(tensor(0.0150, device='cuda:0', dtype=torch.float16),
tensor(0.0158, dtype=torch.float16))
To match the RAG.index defaults, I'm going to set nbits to 4.
config.nbits
1
config.nbits = 4
config.nbits
4
num_options = 2 ** config.nbits
config.nbits, num_options
(4, 16)
A 4-bit value has four binary digits (each 0 or 1), so there are 16 possible combinations:
| Binary |
|---|
| 0000 |
| 0001 |
| 0010 |
| 0011 |
| 0100 |
| 0101 |
| 0110 |
| 0111 |
| 1000 |
| 1001 |
| 1010 |
| 1011 |
| 1100 |
| 1101 |
| 1110 |
| 1111 |
We split 0-to-1 into 16 equal parts:
quantiles = torch.arange(0, num_options, device=heldout_avg_residual.device) * (1 / num_options)
quantiles.shape, quantiles
(torch.Size([16]),
tensor([0.0000, 0.0625, 0.1250, 0.1875, 0.2500, 0.3125, 0.3750, 0.4375, 0.5000,
0.5625, 0.6250, 0.6875, 0.7500, 0.8125, 0.8750, 0.9375],
device='cuda:0'))
bucket_cutoffs_quantiles, bucket_weights_quantiles = quantiles[1:], quantiles + (0.5 / num_options)
bucket_cutoffs_quantiles, bucket_weights_quantiles
(tensor([0.0625, 0.1250, 0.1875, 0.2500, 0.3125, 0.3750, 0.4375, 0.5000, 0.5625,
0.6250, 0.6875, 0.7500, 0.8125, 0.8750, 0.9375], device='cuda:0'),
tensor([0.0312, 0.0938, 0.1562, 0.2188, 0.2812, 0.3438, 0.4062, 0.4688, 0.5312,
0.5938, 0.6562, 0.7188, 0.7812, 0.8438, 0.9062, 0.9688],
device='cuda:0'))
IIUC, the weights’ quantiles are the midpoints between adjacent cutoffs’ quantiles.
(bucket_cutoffs_quantiles[1] + bucket_cutoffs_quantiles[2])/2
tensor(0.1562, device='cuda:0')
(bucket_cutoffs_quantiles[3] + bucket_cutoffs_quantiles[4])/2
tensor(0.2812, device='cuda:0')
(bucket_cutoffs_quantiles[-2] + bucket_cutoffs_quantiles[-1])/2
tensor(0.9062, device='cuda:0')
bucket_cutoffs = heldout_avg_residual.float().quantile(bucket_cutoffs_quantiles)
bucket_weights = heldout_avg_residual.float().quantile(bucket_weights_quantiles)
IIUC, bucket_cutoffs are the values with which we can group our (flattened) heldout_avg_residual values into 16 equal groups, visualized here by setting the histogram bins to bucket_cutoffs.
pd.Series(heldout_avg_residual.flatten().cpu()).hist(bins=bucket_cutoffs.cpu())
plt.xlim([-0.035, 0.035])
plt.show()
Perhaps due to randomness during the sample split, my manually calculated cutoffs are not quite the same as the RAG.index values.
bucket_cutoffs
tensor([-0.0322, -0.0219, -0.0156, -0.0108, -0.0070, -0.0040, -0.0017, 0.0000,
0.0019, 0.0042, 0.0071, 0.0108, 0.0155, 0.0220, 0.0327],
device='cuda:0')
_bucket_cutoffs
tensor([-0.0307, -0.0205, -0.0146, -0.0099, -0.0064, -0.0037, -0.0016, 0.0000,
0.0017, 0.0038, 0.0066, 0.0102, 0.0150, 0.0211, 0.0313],
device='cuda:0')
bucket_weights
tensor([-0.0434, -0.0261, -0.0185, -0.0131, -0.0088, -0.0054, -0.0028, -0.0009,
0.0009, 0.0029, 0.0055, 0.0088, 0.0130, 0.0184, 0.0262, 0.0441],
device='cuda:0')
_bucket_weights
tensor([-0.0411, -0.0247, -0.0173, -0.0121, -0.0081, -0.0050, -0.0026, -0.0007,
0.0007, 0.0027, 0.0052, 0.0083, 0.0124, 0.0178, 0.0253, 0.0421],
device='cuda:0', dtype=torch.float16)
There seem to be some rounding differences (or perhaps it depends on the distribution?) but the weights again seem to be roughly the midpoints between the cutoffs.
(bucket_cutoffs[0] + bucket_cutoffs[1])/2
tensor(-0.0270, device='cuda:0')
(bucket_cutoffs[3] + bucket_cutoffs[4])/2
tensor(-0.0089, device='cuda:0')
(bucket_cutoffs[-2] + bucket_cutoffs[-1])/2
tensor(0.0273, device='cuda:0')
CollectionIndexer.index
Thus far we have found centroids by clustering a sample of our token embeddings (14,439 of them, with the remaining 5%, or 759, held out) and used the held-out embeddings to calculate bucket cutoffs and bucket weights for quantization. We also know what the average residual mean value is.
Now we find the closest centroids and residuals for all passages’ token embeddings, starting first by encoding all 15198 tokens with our model:
embs, doclens = encoder.encode_passages(collection)
embs.shape, len(doclens), doclens[:5]
[Mar 13, 01:18:21] [0] #> Encoding 1000 passages..
/usr/local/lib/python3.11/dist-packages/colbert/utils/amp.py:15: FutureWarning: `torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.
return torch.cuda.amp.autocast() if self.activated else NullContextManager()
(torch.Size([15198, 96]), 1000, [4, 20, 18, 23, 8])
Then we call save_chunk, which lives inside the IndexSaver and within which some interesting things take place:
def save_chunk(self, chunk_idx, offset, embs, doclens):
    compressed_embs = self.codec.compress(embs)
Looking into ResidualCodec.compress:
def compress(self, embs):
    codes, residuals = [], []

    for batch in embs.split(1 << 18):
        if self.use_gpu:
            batch = batch.cuda().half()
        codes_ = self.compress_into_codes(batch, out_device=batch.device)
        centroids_ = self.lookup_centroids(codes_, out_device=batch.device)

        residuals_ = (batch - centroids_)

        codes.append(codes_.cpu())
        residuals.append(self.binarize(residuals_).cpu())

    codes = torch.cat(codes)
    residuals = torch.cat(residuals)

    return ResidualCodec.Embeddings(codes, residuals)
We've seen compress_into_codes and lookup_centroids before:
codes_ = compressor.compress_into_codes(embs, out_device='cuda')
codes_.shape
torch.Size([15198])
These codes are the IDs of the centroid closest to each token embedding.
codes_[:5]
tensor([654, 843, 401, 654, 926], device='cuda:0')
We then get those 15198 centroid vectors:
centroids_ = compressor.lookup_centroids(codes_, out_device='cuda')
centroids_.shape
torch.Size([15198, 96])
A reminder that our centroids and our token embeddings are unit vectors:
centroids_[0].norm(), embs[0].norm()
(tensor(1., device='cuda:0', dtype=torch.float16),
tensor(1., dtype=torch.float16))
We then find the residuals between token embeddings and centroids:
residuals_ = (embs.cpu() - centroids_.cpu())
residuals_.shape
torch.Size([15198, 96])
The next piece is super cool. We binarize the residuals, starting by using bucketize:
residuals = torch.bucketize(residuals_.float().cpu(), bucket_cutoffs.cpu()).to(dtype=torch.uint8)
residuals.shape
torch.Size([15198, 96])
residuals[0][:5]
tensor([8, 7, 7, 8, 7], dtype=torch.uint8)
residuals_[0][:5]
tensor([6.1035e-05, 0.0000e+00, 0.0000e+00, 1.9073e-06, 0.0000e+00],
dtype=torch.float16)
residuals_[1][10:20]
tensor([ 0.0234, -0.0100, 0.0046, -0.0078, -0.0111, 0.0249, -0.0081, -0.0048,
0.0270, 0.0037], dtype=torch.float16)
bucket_cutoffs[6:10]
tensor([-0.0017, 0.0000, 0.0019, 0.0042], device='cuda:0')
residuals.min(), residuals.max()
(tensor(0, dtype=torch.uint8), tensor(15, dtype=torch.uint8))
The values of residuals are now the IDs (indices) of the buckets that the residual values fall into!
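A quick spot check (mine, not ColBERT code): the first residual value, 6.1035e-05, lies just above bucket_cutoffs[7] (0.0000), so bucketize should assign it to bucket 8, matching residuals[0][0]:

torch.bucketize(residuals_[0][0].float(), bucket_cutoffs.cpu())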
residuals = residuals.unsqueeze(-1).expand(*residuals.size(), config.nbits)
residuals.shape
torch.Size([15198, 96, 4])
We add a dimension to hold the 4 bits for each residual value.
arange_bits = torch.arange(0, config.nbits, device='cuda', dtype=torch.uint8)
arange_bits
tensor([0, 1, 2, 3], device='cuda:0', dtype=torch.uint8)
residuals = residuals.cpu() >> arange_bits.cpu()
residuals.shape
torch.Size([15198, 96, 4])
residuals[0][:5]
tensor([[8, 4, 2, 1],
[7, 3, 1, 0],
[7, 3, 1, 0],
[8, 4, 2, 1],
[7, 3, 1, 0]], dtype=torch.uint8)
residuals = residuals & 1
residuals[0][:5]
tensor([[0, 0, 0, 1],
[1, 1, 1, 0],
[1, 1, 1, 0],
[0, 0, 0, 1],
[1, 1, 1, 0]], dtype=torch.uint8)
We have now converted the bucket ID into the actual 4-bit binary value it represents.
residuals_packed = np.packbits(np.asarray(residuals.contiguous().flatten()))
residuals_packed = torch.as_tensor(residuals_packed, dtype=torch.uint8)
residuals_packed = residuals_packed.reshape(residuals.size(0), config.dim // 8 * config.nbits)
residuals_packed.shape
torch.Size([15198, 48])
residuals_packed[0][:5]
tensor([ 30, 225, 225, 238, 238], dtype=torch.uint8)
f"30 in binary: {bin(30)[2:].zfill(8)}"
'30 in binary: 00011110'
For each residual vector with 96 values, each value is represented with 4 bits (e.g. 0, 0, 0, 1). Every 8 bits are packed into one uint8 integer (e.g. 0001 and 1110 concatenate to become the integer 30), so we have cut the number of stored values in half (from 96 to 48).
These residuals would be stored in 0.residuals.pt.
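To close the loop (my own verification sketch, not part of ColBERT), we can unpack that first byte and reverse each LSB-first bit group to recover the bucket IDs 8 and 7 that we saw earlier:

bits = np.unpackbits(np.array([30], dtype=np.uint8))   # array([0, 0, 0, 1, 1, 1, 1, 0])
first_code, second_code = bits[:4], bits[4:]
# The bits were produced least-significant-bit first (via the >> arange_bits trick),
# so reversing each 4-bit group recovers the original bucket IDs.
int(''.join(map(str, first_code[::-1])), 2), int(''.join(map(str, second_code[::-1])), 2)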
_build_ivf
This is a critical piece—the mapping between passages and centroids!
codes = codes_.sort()
ivf, values = codes.indices, codes.values
Token embedding IDs:
ivf
tensor([ 936, 1171, 2363, ..., 12051, 12147, 12161], device='cuda:0')
Centroid IDs:
values
tensor([ 0, 0, 0, ..., 1023, 1023, 1023], device='cuda:0')
ivf.shape, values.shape
(torch.Size([15198]), torch.Size([15198]))
ivf contains the token embedding IDs (the indices of codes_) and values contains the centroid IDs (the values of codes_).
We then get the number of tokens per centroid ID:
ivf_lengths = torch.bincount(values, minlength=num_partitions)
ivf_lengths
tensor([10, 11, 17, ..., 17, 9, 29], device='cuda:0')
ivf_lengths.shape
torch.Size([1024])
ivf_lengths.sum()
tensor(15198, device='cuda:0')
colbert/indexing/utils.py: optimize_ivf
We have 1000 documents containing a total of 15198 tokens.
total_num_embeddings = sum(doclens)
len(doclens), total_num_embeddings
(1000, 15198)
Instantiating an empty mapping between token embedding IDs and passage IDs:
emb2pid = torch.zeros(total_num_embeddings, dtype=torch.int)
emb2pid.shape
torch.Size([15198])
The indices of doclens are passage IDs pid. The values are the number of tokens in the document, dlength.
offset_doclens = 0
for pid, dlength in enumerate(doclens):
    emb2pid[offset_doclens: offset_doclens + dlength] = pid
    offset_doclens += dlength
emb2pid.shape
torch.Size([15198])
The first 4 token embeddings correspond to the first passage, the next 20 token embeddings to the second passage, and so on.
emb2pid[:50]
tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3,
3, 3], dtype=torch.int32)
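As an aside, the same mapping can be built without the explicit loop (my own alternative, not how ColBERT does it) by repeating each passage ID by its document length:

emb2pid_alt = torch.repeat_interleave(torch.arange(len(doclens), dtype=torch.int), torch.tensor(doclens))
bool((emb2pid_alt == emb2pid).all())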
Recall that ivf contains as values the token embedding IDs, which are the indices of emb2pid. The values of emb2pid are passage IDs. Indexing into emb2pid with ivf pulls out the passage IDs corresponding to tokens. Note that ivf is sorted by centroid ID.
ivf
tensor([ 936, 1171, 2363, ..., 12051, 12147, 12161], device='cuda:0')
values
tensor([ 0, 0, 0, ..., 1023, 1023, 1023], device='cuda:0')
new_ivf = emb2pid[ivf.cpu()]
new_ivf.shape, new_ivf[:5]
(torch.Size([15198]), tensor([ 55, 69, 143, 416, 471], dtype=torch.int32))
new_ivf
tensor([ 55, 69, 143, ..., 795, 800, 800], dtype=torch.int32)
The first token embedding corresponding to centroid ID 0 belongs to passage ID 55.
emb2pid[936]
tensor(55, dtype=torch.int32)
new_ivf maps its indices (token embeddings, sorted by centroid) to passage IDs, and it is aligned with the ivf_lengths tensor, which contains the number of tokens per centroid ID (derived from values).
Next, we iterate through ivf_lengths, which contains the number of tokens per centroid ID. For each length we get the unique passage IDs from new_ivf and append them to unique_pids_per_centroid. The number of unique pids for that centroid is added to new_ivf_lengths.
unique_pids_per_centroid = []
new_ivf_lengths = []

offset = 0
for length in tqdm.tqdm(ivf_lengths.tolist()):
    pids = torch.unique(new_ivf[offset:offset+length])
    unique_pids_per_centroid.append(pids)
    new_ivf_lengths.append(pids.shape[0])
    offset += length

ivf = torch.cat(unique_pids_per_centroid)
new_ivf_lengths = torch.tensor(new_ivf_lengths)
100%|██████████| 1024/1024 [00:00<00:00, 35975.77it/s]
<ipython-input-138-6c68981e98f9>:11: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
new_ivf_lengths = torch.tensor(ivf_lengths)
ivf.shape, new_ivf_lengths.shape
(torch.Size([11593]), torch.Size([1024]))
Note that there are now fewer values in ivf than 15198, since we are only capturing the unique pids per centroid.
ivf[:5]
tensor([ 55, 69, 143, 416, 471], dtype=torch.int32)
new_ivf_lengths[:5]
tensor([10, 11, 17, 1, 34], device='cuda:0')
new_ivf_lengths is the count of unique passage IDs per centroid. So, for example, the first 10 pids in ivf correspond to centroid ID 0.
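To make that concrete (my own sketch, not ColBERT code), we can slice out the candidate passage IDs for centroid 0 using the cumulative sum of new_ivf_lengths:

offsets = torch.cat([torch.zeros(1, dtype=torch.long), new_ivf_lengths.cpu().long().cumsum(0)])
ivf[offsets[0]:offsets[1]]   # the unique passage IDs assigned to centroid 0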
ivf and new_ivf_lengths would be stored in ivf.pid.pt.
After updating metadata, this completes the indexing process in RAGatouille and ColBERT!
Final Thoughts
There were of course many details that I didn't fully explain in this walkthrough, and since I wasn't able to exactly replicate some of the indexing artifacts there may be some errors in my code, but I think I both covered and understood the main components of creating an index. Getting to this stage involved a lot of discussion with Claude. I used AnswerAI's toolslm to create context from the RAGatouille and ColBERT repos to provide as Claude project knowledge. I also pored through the codebase for hours, making sure to trace my steps from method to method. While I could do more deep dives into the individual components of indexing, I feel satisfied with this walkthrough for now. I hope you enjoyed this blog post!