Analyzing PyTorch 2.0 Release Notes for ColBERT Dependency Impact
Background
In this blog post, I walk through the PyTorch 2.0 Release Notes PRs that I estimate will have some kind of impact on ColBERT as I update the torch dependency to 2.0 (ColBERT currently pins torch==1.13.1). The level of detail in my analysis of a given PyTorch PR does not necessarily signify its importance; in some cases, I am using the analysis as an opportunity to get more familiar with details of the ColBERT codebase (such as the number of places where `torch.cat` is used).
Full Release Notes Analysis
You can find my item-by-item PyTorch 2.0 release notes analysis for ColBERT in this Google Sheet.
Overall, across 508 PyTorch PRs, I have estimated that 455 of them are not applicable to ColBERT and 42 (8.3%) have a potential impact. I was unclear if or how the remaining 11 PyTorch 2.0 PRs (2.2%) would affect ColBERT.
There are 5 sections in the PyTorch 2.0 Release Notes; here's a breakdown by section of the PRs that have a potential (or unclear) impact on ColBERT:
| Section | # of PRs |
|---|---|
| Improvements | 22 |
| Bug Fixes | 21 |
| Performance | 6 |
| Backwards Incompatible Changes | 3 |
| Deprecations | 1 |
In my estimation, the improvement and bug-fix PRs in PyTorch 2.0 should only improve ColBERT's performance. That being said, there may still be noticeable differences in indexing, search, and training artifacts, which may break the before/after comparison tests I write.
Fortunately, only two backwards-incompatible changes may affect ColBERT.
There are 11 subsections in the PyTorch 2.0 Release Notes; here's a breakdown by subsection of the PRs that have a potential (or unclear) impact on ColBERT:
| Subsection | # of PRs |
|---|---|
| MPS | 12 |
| Python API | 8 |
| Cuda | 8 |
| Releng | 4 |
| Distributed | 4 |
| Build | 4 |
| ONNX | 3 |
| Cpu | 2 |
| Cpp API | 2 |
| torch.nn API | 1 |
I am including MPS-related PRs in this analysis just in case we consider making ColBERT compatible with MPS in the future.
I’ll start by analyzing breaking changes, which are likely going to be the most impactful.
Backwards Incompatible Changes
PR #92731
Gradients are now set to None instead of zeros by default in `torch.optim.*.zero_grad()` and `torch.nn.Module.zero_grad()` (#92731)

There are two lines in ColBERT where `zero_grad` is called: in colbert/utils/amp.py and in colbert/training/training.py. I'm not sure how this change would affect ColBERT behavior, but I'm flagging it as something to keep in mind.
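As a minimal illustration of the new default (with a throwaway linear model standing in for ColBERT):

```python
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters())

model(torch.randn(1, 4)).sum().backward()
optimizer.zero_grad()                    # 2.0 default: set_to_none=True
print(model.weight.grad)                 # None, not a zero tensor

model(torch.randn(1, 4)).sum().backward()
optimizer.zero_grad(set_to_none=False)   # opt back into the 1.13-style zeros
print(model.weight.grad)                 # tensor of zeros
```

Any code that reads `.grad` right after `zero_grad` would now see `None` instead of a zero tensor, so that's the failure mode I'd look for.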
PR #92306
Algorithms {Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSProp, RProp, SGD} default to faster `foreach` implementation when on CUDA + differentiable=False
This PR adds the following lines to `AdamW`, which is used in ColBERT's training.py:

```python
if foreach is None:
    foreach = _default_to_foreach(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
        differentiable=differentiable,
    )
```

`foreach` is None in ColBERT, as it's not specified and that's what it defaults to (`foreach: Optional[bool] = None`):

```python
optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=config.lr, eps=1e-8)
```
Since this is described as a “faster implementation”, I would expect the training time to decrease. I’ll be on the lookout for this when comparing training time benchmarks before/after upgrading to PyTorch 2.0.
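If I want to isolate the foreach effect when benchmarking, one option is to pin the old single-tensor path (a sketch; `params` and `config` stand in for ColBERT's actual objects):

```python
from torch.optim import AdamW

# Hypothetical A/B baseline: force the pre-2.0 single-tensor implementation
# so any remaining training-time difference isn't attributable to foreach.
optimizer = AdamW(params, lr=config.lr, eps=1e-8, foreach=False)
```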
PR #88913
Update `torch.tensor` and `nn.Parameter` to serialize all their attributes (#88913)

It's unclear what this PR is doing, but since it touches the `nn.Parameter` definition, I'm flagging it.
Bug Fixes
I would expect PyTorch PRs that introduce bug fixes to affect ColBERT only positively. That being said, a positive effect is still a change, and it can alter concrete artifacts produced during indexing, search, and training. I plan to curate a baseline set of these artifacts before I test the upgrade to PyTorch 2.0.
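Concretely, the comparison I have in mind looks something like this (the paths and tolerance are hypothetical placeholders):

```python
import torch

# Compare a torch==1.13.1 artifact against its torch==2.0 counterpart.
baseline = torch.load("artifacts/torch-1.13.1/codes.pt")
candidate = torch.load("artifacts/torch-2.0/codes.pt")

if baseline.is_floating_point():
    # Floating-point artifacts may drift slightly; compare within a tolerance.
    assert torch.allclose(baseline, candidate, atol=1e-4)
else:
    # Integer artifacts (e.g. centroid ids) should match exactly.
    assert torch.equal(baseline, candidate)
```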
PR #92810
Fix SIGSEGV on a big-endian machine when reading pickle data (#92810)
The PR states:
This PR fixes SIGSEGV on a big-endian machine when reading pickle data.
I'm not familiar with the term "big-endian", so I had to look it up:
A big-endian system stores the most significant byte of a word at the smallest memory address and the least significant byte at the largest. A little-endian system, in contrast, stores the least-significant byte at the smallest address. (source)
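A quick illustration in plain Python:

```python
# The same 32-bit integer serialized in both byte orders.
value = 0x01020304
print(value.to_bytes(4, "big").hex())     # 01020304 -> most significant byte first
print(value.to_bytes(4, "little").hex())  # 04030201 -> least significant byte first
```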
Claude's understanding of the C++ method affected by this PR is that it affects `torch.load`. A number of ColBERT files use `torch.load`:
- colbert/utils/coalesce.py uses it to load `codes.pt` (centroid id for each embedding in chunk) and to load `residuals.pt` (16-bit residual for each embedding in chunk).
- colbert/search/index_loader.py uses it to load `ivf.pid.pt` or `ivf.pt`.
- colbert/utils/utils.py uses it in `torch_load_dnn`, `load_checkpoint_raw` and `load_ranking`.
- colbert/indexing/index_manager.py uses it in `load_index_part`.
- colbert/indexing/codecs/residual_embeddings.py uses it in `load_codes` and `load_residuals`.
- colbert/indexing/codecs/residual.py uses it in `load` to load centroids, `avg_residual`, and `bucket_cutoffs, bucket_weights`.
- colbert/indexing/collection_indexer.py uses it in `_concatenate_and_split_sample`.
- colbert/index_updater.py uses it in `_load_disk_ivf` to load `ivf.pid.pt` or `ivf.pt`, in `_load_chunk_codes`, and in `_load_chunk_residuals`.
- colbert/tests/index_coalesce_test.py uses it to load multi-file `codes.pt`, single-file `codes.pt`, multi-file `residuals.pt` and single-file `residuals.pt`.
PR #92315
Fix NVML visible device parsing (#92315)
CUDA_VISIBLE_DEVICES can contain either ordinals or UUIDs. Extend the logic to be able to parse it by UUID.

I don't think this would affect any artifacts created during indexing/searching/training, but it would make it easier for PyTorch to identify GPUs.
PR #93095
This PR fixes an error (#93006) when using `topk`, which is used in the following places in ColBERT:

- `get_cells` in colbert/search/candidate_generation.py
- `score_pids` in colbert/search/index_storage.py to filter centroids by the threshold, filter `pids` using pruned centroid scores, and filter `pids` using full centroid scores
- filter scores in `colbert_score_reduce` for the `"flipr"` interaction method: link 1, link 2
Claude also recommended considering uses of `max` and `argmax` as potentially impacted:

- `get_cells` if `ncells == 1` in colbert/search/candidate_generation.py
- colbert/modeling/colbert.py in `ColBERT.score`
- colbert/modeling/colbert.py in `colbert_score_reduce`
- colbert/indexing/codecs/residual.py in `ResidualCodec.compress_into_codes` on GPU and CPU
- colbert/search/strided_tensor.py in `StridedTensor.lookup`
- colbert/search/strided_tensor_core.py in `StridedTensorCore.__init__`
- colbert/modeling/checkpoint.py in `Checkpoint.score`
- colbert/indexing/utils.py in `optimize_ivf`
I would assume that the only impact this PR would have on ColBERT is avoiding errors during the use of `topk` (no such errors have been reported in the open issues).
PR #85596
Fix: half reduction with multiple sub-iterators (#85596)
Fixes "cuda low-precision reductions on large tensors produce wrong results" (#74438):

Reductions with low precision inputs (half, bfloat16) that need sub-iterators accumulate directly in output and thus truncate intermediate results
This would fix any issues related to the use of `half` in the following ColBERT files:

- in `ResidualCodec` for `centroids`, `avg_residual`, and `bucket_weights`; saving `centroids`; compressing token embeddings; and calculating cosine similarity between token embeddings and centroids on GPU.
- colbert/search/candidate_generation.py in `CandidateGeneration.generate_candidates` for queries when using the GPU.
- colbert/indexing/collection_indexer.py in `CollectionIndexer._sample_embeddings` when saving `local_sample_embs`, in `CollectionIndexer.train_kmeans` for centroids on the GPU, and in `CollectionIndexer.index` when saving token embeddings.
- colbert/modeling/colbert.py in `ColBERT.doc` for document token embeddings on the GPU.
If this PR's fix is relevant to the use of `half` in the files above, I would expect numeric differences in indexing/search artifacts.
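For reference, this is the kind of reduction the fix covers (a sketch that assumes a CUDA device; the tensor must be large enough to require sub-iterators):

```python
import torch

# Summing a large fp16 tensor exercises the multi-sub-iterator reduction path;
# before the fix, intermediate results could be truncated to half precision.
x = torch.full((64_000_000,), 1e-3, dtype=torch.half, device="cuda")
print(x.sum())          # half-precision reduction affected by #85596
print(x.float().sum())  # fp32 reference value to compare against
```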
PR #86492
Fixes a memory leak by making autocast cache global instead of thread-local (#86492)
This PR adds a PyTorch test which:
Verifies that the autocast cache is global. This is done by mocking out cache clearing at the end of the forward pass, running forward+backward with an explicit call to autocast in the backward, and verifying that the weights only get cast to float16 once.
Claude’s analysis:
This PyTorch enhancement directly benefits ColBERT. By making the autocast cache global, this PR provides a performance improvement when training ColBERT with mixed precision. It reduces redundant computations during the backward pass, leading to faster and more efficient training without changing the model’s functionality.
ColBERT uses `torch.cuda.amp.autocast` in the following files:

- colbert/utils/amp.py in `MixedPrecisionManager.context`, which is used in colbert/training/training.py during `train`, and in colbert/modeling/checkpoint.py in `Checkpoint.query` and `Checkpoint.doc` to calculate query and document token embeddings, respectively.
- colbert/distillation/scorer.py in `Scorer._score_pairs`.
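For context, the pattern that `MixedPrecisionManager` wraps looks roughly like this (a sketch with placeholder `model`, `batch`, and `optimizer` names, not ColBERT's exact code):

```python
import torch

scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    loss = model(batch).sum()   # weights are cast to fp16 inside the region
scaler.scale(loss).backward()   # the (now global) autocast cache avoids re-casting
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```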
PR #88898
Fixes PyTorch #88873:
torch_extension.py should be fixed or ninja compile will fail.
Gemini’s analysis: Because ColBERT uses the very feature this PR is fixing (torch.utils.cpp_extension.py), the change is directly relevant. This bug fix is important for any developer or user who needs to compile and run ColBERT on a Windows machine. It ensures that ColBERT’s performance-critical custom CUDA code can be built correctly, preventing potential compilation errors.
ColBERT uses `torch.utils.cpp_extension` in the following files:

- colbert/modeling/colbert.py in `ColBERT.try_load_torch_extensions` to load `segmented_lookup.cpp` on CPU.
- colbert/search/index_storage.py in `IndexScorer.try_load_torch_extensions` to load `filter_pids.cpp` and `decompress_residuals.cpp`.
- colbert/search/strided_tensor.py in `StridedTensor.try_load_torch_extensions` to load `segmented_lookup.cpp` on CPU.
- colbert/indexing/codecs/residual.py in `ResidualCodec.try_load_torch_extensions` to load `decompress_residuals.cpp`.
This might be related to ColBERT #317.
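For context, these extensions are JIT-compiled via `torch.utils.cpp_extension.load`, roughly like this (a sketch; the exact paths and flags in ColBERT differ):

```python
import os
import pathlib
from torch.utils import cpp_extension

# JIT-compile a C++ extension the first time it's needed, as ColBERT's
# try_load_torch_extensions methods do (the source path is illustrative).
segmented_lookup_cpp = cpp_extension.load(
    name="segmented_lookup_cpp",
    sources=[os.path.join(pathlib.Path(__file__).parent.resolve(), "segmented_lookup.cpp")],
    verbose=True,
)
```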
PR #90149
Fix a static initialization order fiasco in c10d (#90149)
Gemini’s analysis: Because ColBERT’s multi-GPU functionality is built directly on the PyTorch library that this PR is fixing, this change is highly relevant. This is a crucial stability improvement that makes ColBERT’s distributed training and inference more reliable by preventing potential crashes at startup.
If Gemini’s analysis is correct, this will make ColBERT’s multi-GPU functionality more reliable and might address related open issues.
PRs #86956, #86958
Fix issues with non-contiguous Tensor handling (#86956, #86958)
These are both MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PRs #94119, #86240, #91520, #94442, #94386
Fix issues with ops implementation: torch.median (#90326, #88807), torch.{std,var} correction argument (#91203), torch.index_select (#94117, #91064), torch.cumsum (#94119), torch.where (#86240), torch.nn.Embedding (#82809), torch.nn.Softplus (#88555), torch.nn.functional.pad (#89864), torch.max (#91520), padding functions (#91522), torch.nn.functional.upsample (#91669), pooling functions (#91519, #94348), torch.nn.{NLLLoss,SmoothL1Loss} (#94226), torch.nn.SoftPlus (#94256), torch.masked_fill (#94263), torch.fill_ (#94479), torch.median (#94489), torch.nonzero (#94442), torch.nn.BatchNorm (#94351), torch.{min,max} (#94386), torch.nn.GELU (#94529), torch.nn.LSTM (#94889, #95137), torch.nn.Conv2d (#95078), torch.nn.functional.bilinear (#94892), torch.copy_ (#95272), torch.max_pool2d (#94963), torch.div (#95769)
ColBERT uses `topk` in the following files:

- `get_cells` in colbert/search/candidate_generation.py
- `score_pids` in colbert/search/index_storage.py to filter centroids by the threshold, filter `pids` using pruned centroid scores, and filter `pids` using full centroid scores
- filter scores in `colbert_score_reduce` for the `"flipr"` interaction method: link 1, link 2
ColBERT uses `max`/`argmax` in the following files:

- `get_cells` if `ncells == 1` in colbert/search/candidate_generation.py
- colbert/modeling/colbert.py in `ColBERT.score`
- colbert/modeling/colbert.py in `colbert_score_reduce`
- colbert/indexing/codecs/residual.py in `ResidualCodec.compress_into_codes` on GPU and CPU
- colbert/search/strided_tensor.py in `StridedTensor.lookup`
- colbert/search/strided_tensor_core.py in `StridedTensorCore.__init__`
- colbert/modeling/checkpoint.py in `Checkpoint.score`
- colbert/indexing/utils.py in `optimize_ivf`
ColBERT uses `torch.cumsum` in the following files to calculate `offsets` (see the sketch after this list):

- colbert/indexing/utils.py
- colbert/search/strided_tensor_core.py
- colbert/search/strided_tensor.py
- colbert/search/index_storage.py
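The offsets idiom in question looks roughly like this (illustrative values):

```python
import torch

# Turn per-document lengths into start offsets into a packed tensor.
doclens = torch.tensor([3, 5, 2])
ends = torch.cumsum(doclens, dim=0)                            # tensor([ 3,  8, 10])
offsets = torch.cat((torch.zeros(1, dtype=torch.long), ends))  # tensor([ 0,  3,  8, 10])
```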
ColBERT uses `torch.where` in the following files:

- colbert/modeling/checkpoint.py to pool embeddings within each cluster.
ColBERT uses `torch.nonzero` in the following files:

- colbert/index_updater.py to construct a mask of where the pids to be removed appear in the ivf.

These are all MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PRs #91120, #94464
Fix issues with torch.bool for Unary ops (#91120), scatter ops (#94464)
Claude's analysis: The PR fixes compatibility issues where boolean tensors needed to be cast to int8 on older macOS versions, then cast back. This would be important for ColBERT's masking operations, which rely heavily on boolean tensors for attention and padding masks.
These are MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PR #94484
Properly cast torch.int64 to torch.int32 for reduction ops and raise warning. (#94484)
Claude's analysis: The PR changes TORCH_CHECK (which throws an error) to TORCH_WARN_ONCE (which just warns) and automatically casts int64 to int32 for min/max operations. This would allow ColBERT to run on MPS with int64 tensors instead of failing, though with potential precision loss.
MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PRs #91197, #91514
Fix handling of ops taking multiple dtypes as input (#91197, #91514)
Claude’s analysis: The PR fixes MPS scatter to handle type mismatches between source and destination tensors automatically.
MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PRs #91786, #94662
Fix handling of channels last for torch.cat (#91786, #94662), torch.Conv2d (#91822, #94384), torch.nn.{ELU,ReLU,Hardswish} (#94664), torch.nn.BatchNorm (#94760), torch.nn.MaxPool2d (#94877)
ColBERT uses `.cat` in the following files:

- colbert/indexing/utils.py to concatenate a list of tensors (`unique_pids_per_centroid`) into a single tensor (`ivf`).
- colbert/indexing/index_manager.py to concatenate multiple path names.
- to calculate `offsets`.
- to add padding to a tensor.
- to concatenate `scores`.
- to concatenate document token embeddings.
- to concatenate `codes` (centroid IDs).
- to concatenate `residuals`.
- to concatenate `codes` (centroid IDs).
- to concatenate `centroids`.
- to concatenate document token embeddings.
- to concatenate `packed_tensor`.
- to concatenate `all_orders`.
- to concatenate `all_lengths`.
- to concatenate `compressed_embs.codes` (centroid IDs corresponding to document token embeddings).
- to concatenate `compressed_embs.residuals`.
- to concatenate `doclens`.
- to concatenate `codes` (centroid IDs).
- to concatenate `residuals`.
- to concatenate `codes` (centroid IDs).
- to concatenate `residuals`.
- to concatenate `batches` (of queries).
- to concatenate document token embeddings (in order).
- to concatenate `mask` for document token embeddings (in order).
- to concatenate `code` chunks.
- to concatenate `offsets`.
- to concatenate `approx_scores`.
- to concatenate `ids` (for query tokens).
- to concatenate `masks` (for query tokens).
- to concatenate prefix token.
MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PRs #94259, #94278, #95145, #95762, #95905
Fix view operations handling (#94259, #94278, #95145, #95762, #95905)
Claude’s analysis: This PR fixes crashes in view operations when slicing with incorrect lengths, which ColBERT uses for tensor reshaping and indexing operations.
MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PR #87853
Move incorrectly placed closing curly brace of extern “C” block (#87853)
Gemini’s analysis: This pull request is a foundational C++ correctness fix for the PyTorch framework. Because ColBERT compiles its own C++ extensions that depend on these core headers, this change is directly beneficial. It ensures the stability and reliability of ColBERT’s own build process, preventing potential compilation failures.
Unsure how to measure the impact on ColBERT's C++ extensions, but if it's about reliability, perhaps it will address some open issues. TBD.
PR #93322
Fix MSVC compiler error in basic_ops.h (#93322)
Gemini’s take: This pull request is a crucial build-system and correctness fix. It directly impacts ColBERT by ensuring that its custom C++ code can be compiled successfully on Windows machines that use the affected MSVC compiler. Without this fix, users in that environment would be unable to run ColBERT. This change makes ColBERT’s build process more robust and widens its platform compatibility.
TBD if this addressed open issues related to Windows machines.
PR #89310
Fix a bug that redefines __STDC_FORMAT_MACROS (#89310)
Gemini’s take: This pull request provides a stability and correctness fix to the underlying PyTorch framework. Because ColBERT compiles its own C++ code that depends on these core PyTorch headers, this change is directly beneficial. It makes ColBERT’s own compilation process more reliable and prevents a potential class of build failures.
Unsure how to measure the impact but if it’s about reliability perhaps it will address some open issues. TBD.
PR #90411
Add manual cuda deps search logic (#90411)
Gemini’s take: This PyTorch pull request adds a new mechanism to help PyTorch find its essential CUDA libraries (cuBLAS and cuDNN) on Linux systems.
Unsure how to measure the impact but if it’s about reliability perhaps it will address some open issues. TBD.
PR #89759
Workaround for NumPy builds that ship with a broken Dlpack deleter (#89759)
TBD if this improves reliability, as ColBERT uses NumPy.
PR #86288
Workaround MSVC ICE due to constexpr char* template argument (#86288)
Gemini’s take: It directly impacts ColBERT by ensuring that its custom C++ code can be compiled successfully on Windows machines that use an affected MSVC compiler.
TBD if this addressed open issues related to Windows machines.
PR #85408
Add define to fix issue with compatibility with latest Windows SDK (#85408)
Gemini’s take: It directly impacts ColBERT by ensuring that the underlying PyTorch framework can be successfully built on modern Windows environments.
TBD if this addressed open issues related to Windows machines.
Improvements
These PyTorch PRs are improvements, which could affect ColBERT either by speeding things up (and therefore showing up as faster indexing/search/training times) or by changing baseline indexing/search/training artifacts if the improvements impact numeric precision.
PR #56398
Set std/var correction overloads default value to None (#56398)
Unclear if and how this affects ColBERT but highlighting it since it changes code in PyTorch’s aten/src/ATen/native.
PR #86309
Add support for int32 indices in index/index_put ops (#86309)
I think this PR is related to this ColBERT PR, which in turn relates to this line of code in `IndexScorer`.
PR #87022
Enable where to have cpu scalar args (#87022)
ColBERT uses `torch.where` in the following files:

- colbert/modeling/checkpoint.py to pool embeddings within each cluster.

Unclear if this will affect ColBERT, but there are currently no open issues related to `torch.where`.
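My reading of the PR title, sketched (assumes a CUDA device; before this change, mixing a CPU scalar tensor into a CUDA `torch.where` could fail):

```python
import torch

x = torch.randn(8, device="cuda")
zero = torch.tensor(0.0)            # 0-dim CPU scalar tensor
print(torch.where(x > 0, x, zero))  # CPU scalar arg alongside CUDA tensors
```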
PR #90914
Add support for NumPy scalars to torch.tensor.asarray (#90914)
Found one use of `asarray`, but it doesn't deal with a scalar, so it probably won't be affected.
PR #85926
Enable out variant of torch.max (#85926)

Unclear what this PR does, but highlighting it since ColBERT uses `torch.max`.
PR #91846
Implement faster gradient clipping using foreach function (#91846)
ColBERT uses `torch.nn.utils.clip_grad_norm_` in two lines.

ColBERT doesn't set `foreach` in `torch.nn.utils.clip_grad_norm_`, so it gets the default (`foreach=None`). As with the AdamW change above, my understanding is that a `None` default can opt into the faster `foreach` path on CUDA, so I'll watch for training-time differences here too.
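For benchmarking, it should be possible to pin each path explicitly (a sketch; `colbert` is the model as in ColBERT's training code, and the `max_norm` value is illustrative):

```python
import torch

# 2.0 default: foreach=None may take the fused multi-tensor path on CUDA.
torch.nn.utils.clip_grad_norm_(colbert.parameters(), max_norm=2.0)

# Hypothetical baseline run: force the old per-tensor implementation.
torch.nn.utils.clip_grad_norm_(colbert.parameters(), max_norm=2.0, foreach=False)
```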
PR #92334
Enable DDP to handle custom dataclass forward outputs (#92334)
ColBERT does use DistributedDataParallel (in `train`), but it's not being passed a custom dataclass; it's being passed a `colbert` model, so I don't think this PR applies.
PR #89137
Skip collective communications for NO_SHARD in clip_grad_norm_ (#89137)
ColBERT doesn't use FullyShardedDataParallel, but it does use `torch.nn.utils.clip_grad_norm_`, so I'm not sure if this PyTorch PR affects it.
PR #90028
Apply the “largest” dtype across all parameters/gradients as defined by PyTorch’s type promotion semantics for the total norm returned in clip_grad_norm_ for low prec grads (#90028)
ColBERT doesn't use FullyShardedDataParallel, but it does use `torch.nn.utils.clip_grad_norm_`, so I'm not sure if this PyTorch PR affects it.
PR #85692
Set CUDA_MODULE_LOADING to LAZY when not set by the user (#85692)
Unclear exactly what this does, but it relates to the CUDA_MODULE_LOADING env var, which is not set in ColBERT.
PR #89172
Add an option to disable reduced precision reductions for BF16 GEMM (#89172)
Unclear exactly what this does, but in the PR they mentioned it improves H100 usage, so I’ll keep that in mind.
PR #91436
Add an env variable to disable addmm_cuda_lt kernel (#91436)
Unclear what this does, but it adds an environment variable, so it's a new opt-in feature rather than a change in default behavior.
PR #86041, #93022
Clean up flatbuffer lib dependency and fixed its test to match pkl models (#86041, #93022)
I am not sure what these PRs are doing. The title refers to "pkl models", which ColBERT doesn't use to my knowledge.
PR #93898
Type corrections to avoid unnecessary static_casts (#93898)
Unclear what this PR does, but it touches a lot of what seem to be core files, so I'm flagging it.
PR #87245
Integrate all ONNX operators with a new JitScalarType API (#87245)
It's ONNX-related (ColBERT doesn't use ONNX), but it also says: "this PR addresses not only the issue above, but the entire family of issues related to torch._C.Value.type() parsing when scalarType() or dtype() is not available."
PR #87343
Add share_from_this to torch::jit::Graph (#87343)
It's ONNX-related, but it's unclear if it affects anything else.
PR #84789
Use optional op to keep None in results for ONNX internal tests (#84789)
It's ONNX-related, but it's unclear if it affects anything else.
PR #86218
Add fp16 support for torch.nn.Linear (#89774), torch.nn.GELU (#86218)
MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PR #91884
Add support for empty Tensors in torch.bitwise_not (#87286), torch.nn.LayerNorm (#94212), many backward functions (#94343), torch.nn.functional.hardswish (#94342), torch.topk (#91884), torch.arange (#94485), torch.linalg.inv (#94551)

MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PR #91734
Add support for reduction ops on multiple axis at a time (#91734)
MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PR #94639
Add support for k greater than 16 for torch.topk (#94639)
MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
PR #91576
Simplify OpenMP detection in CMake (#91576)
Claude’s take: While ColBERT should continue working, there could be subtle performance or compilation differences depending on how PyTorch’s simplified OpenMP detection affects the runtime compilation of ColBERT’s C++ extensions, particularly in multi-threaded scenarios.
Unclear what this PR is doing, but I'm flagging it since it might affect performance, as Claude suggests.
Deprecations
These are PRs that I would expect to have a significant impact if applicable.
PR #92143
Deprecate tensor.mT, tensor.T, tensor.mH, tensor.H on 0-D tensors (#92143)
There are five instances where `.T` is used, but I'm pretty sure none of them are 0-D tensors; I will confirm:
- colbert/search/candidate_generation.py#L13: cosine similarity between centroids and query token embeddings.
- colbert/search/candidate_generation.py#L43: used in `generate_candidate_scores`, which uses `lookup_eids`, which I can't find anywhere else in the codebase.
- colbert/modeling/colbert.py#L195: cosine similarity between query and document token embeddings.
- colbert/indexing/codecs/residual.py#L215: cosine similarity between centroids and document token embeddings (GPU).
- colbert/indexing/codecs/residual.py#L217: cosine similarity between centroids and document token embeddings (CPU).
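For completeness, this is what the deprecation looks like (illustrative):

```python
import torch

scalar = torch.tensor(3.14)  # 0-D tensor
scalar.T                     # deprecated in 2.0: warns (slated to become an error)
torch.randn(4, 8).T          # 2-D transpose, as in ColBERT's call sites: unaffected
```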
Performance
Similar to the improvements, I would expect this set of PRs only to improve ColBERT's performance, and I'll keep an eye on how different artifacts change because of it.
PR #93234
Improve performance for functional.multi_head_attention_forward() (#93234, #89847)
ColBERT uses BERT, which has its own attention implementation, so this likely wouldn't impact it unless the BERT model specifically uses `torch.nn.functional.multi_head_attention_forward` or `torch.nn.MultiheadAttention`.
PR #84981
Use atomicAdd for bfloat16 in Ampere and above (#84981)
Gemini’s take: This pull request directly accelerates a fundamental operation used during the training of ColBERT. By replacing a slow, emulated function with a fast, hardware-native instruction, this change leads to a noticeable increase in training speed for anyone training ColBERT with bfloat16 mixed precision on an Ampere or newer GPU.
If Gemini is correct, then I should see a training speedup (when training with bfloat16 mixed precision on an Ampere or newer GPU).
PR #94034
Add various performance fixes to c++ STL usage (#94034)
Gemini’s take: The changes in this PR touch several core PyTorch components that are critical to ColBERT’s operation:
- Autograd Engine (function.h): Every gradient calculation during training will benefit from these optimizations.
- CUDA Communication (comm.cpp): The code that handles broadcasting and gathering tensors across GPUs for multi-GPU training and inference is made more efficient.
- Mixed Precision (autocast_mode.h): The logic for automatic mixed precision, which is key for training ColBERT efficiently, is also slightly optimized.
If Gemini is correct, then I will see a speedup in all aspects of ColBERT.
PR #86568
Add fmsub to vectorization primitives (#86568)
Gemini’s take: This pull request is a CPU-specific performance optimization. It adds support for the fmsub (fused multiply-subtract) instruction to PyTorch’s CPU vectorization library. This allows PyTorch to perform the operation (a * b) - c in a single, faster instruction on modern CPUs that support it (e.g., via AVX or NEON).
I'm pretty sure ColBERT doesn't use fused multiply-subtract directly, but I'm keeping it here just in case it comes up.
PR #92300
Fix biasadd OMP perf issue for the packed MKL SGEMM (#92300)
Gemini’s take: This pull request is a CPU-specific performance optimization. It fixes a parallelization issue within the Intel MKL (Math Kernel Library) backend for linear layers. This change improves the efficiency of adding a bias term to the output of a matrix multiplication when running on a CPU.
If Gemini is correct, I would expect a speedup on CPU.
PR #91114
Increase performance of torch.add{cmul,cdiv,mm} (#94214, #94534), torch.multinomial (#86342), faster op launch time (#86437), torch.linear (#91114), view handling (#91743, #94218), convolutions (#94661), scatter/gather (#94663)
Gemini’s take: While the Adam optimizer used by ColBERT does use the addcdiv operation, this is executed on the GPU via CUDA, not MPS. This pull request is a performance optimization for the torch.nn.Linear layer, but it is exclusively for the MPS (Metal Performance Shaders) backend.
MPS-related and will only affect ColBERT if we choose to make it MPS-compatible in the future.
Closing Thoughts
Based on my analysis, I'm optimistic about upgrading ColBERT from torch==1.13.1 to 2.0. The upgrade should deliver concrete benefits with reasonable testing overhead. Performance-wise, I'll be watching for training-time improvements from the faster `foreach` optimizer implementations, and I expect speedups across all aspects of ColBERT from the C++ optimizations and CUDA improvements. For validation, I'll need to check for numeric differences in indexing/search artifacts from the half-precision bug fixes, and benchmark retrieval quality metrics after reindexing to guard against regressions. The reliability fixes should make ColBERT's multi-GPU functionality more robust and might address related open issues. Plus, there are fixes for operations like `topk` and `torch.load` that ColBERT uses extensively. Most MPS-related changes will only affect ColBERT if we choose future compatibility, so they're not immediate concerns, but they're good to have.
My next step will be to establish training-time benchmarks and baseline indexing/retrieval/training artifacts so that I can concretely monitor even subtle performance/behavior changes when using torch==2.0 in my development branch.