Batch Size Causes BertModel Forward Pass Divergence Between torch==2.4.1 and torch==2.5.0 for colbert-ai
BertModel forward pass divergence between PyTorch versions 2.4.1 and 2.5.0: certain batch sizes yield different model layer outputs between PyTorch versions, while other batch sizes don’t.
Background
I’ve recently been documenting how PyTorch version changes impact stanford-futuredata/ColBERT (colbert-ai on PyPI) intermediate and final index artifacts. The index artifact I’ll focus on in this blog post is the very important local_sample_embs tensor. This is the sample of token embeddings used to calculate the centroids, which are later used during search. Instead of loading and comparing full document token embeddings, ColBERT’s PLAID index compares centroid IDs (integers) and compressed residuals (low-bit vectors) in the first three stages of the search pipeline, only decompressing residuals in the final stage. This reduces storage footprint and search latency.
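To make the role of local_sample_embs concrete, here is a minimal sketch of the idea (not ColBERT’s actual implementation; the shapes and centroid selection below are made up for illustration): centroids are derived from the sampled token embeddings, and each stored token embedding is then represented as a centroid ID plus a residual that gets compressed to a few bits.

import torch

# Hypothetical shapes, for illustration only.
local_sample_embs = torch.randn(100_000, 96, dtype=torch.float16)  # sampled token embeddings

# Stand-in for the centroids a k-means run over the sample would produce.
centroids = local_sample_embs[:4096].float()

emb = local_sample_embs[123].float()
centroid_id = (centroids @ emb).argmax()  # nearest centroid by dot-product similarity
residual = emb - centroids[centroid_id]   # this is what gets compressed to low-bit codes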
When comparing local_sample_embs (torch.float16) between torch==2.4.1 and torch==2.5.0, using atol=1e-4 and rtol=1e-3 in torch.allclose:
torch.allclose: False
Mean Acc: 0.7978946566581726
MAD: 1.2740434613078833e-05
Max Abs Diff: 0.00115966796875
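For reference, here is a sketch of how such a report could be produced (the compare helper is mine, not from the colbert-ai codebase, and I’m assuming “Mean Acc” means the fraction of bitwise-identical elements):

import torch

def compare(a: torch.Tensor, b: torch.Tensor, atol=1e-4, rtol=1e-3):
    # Assumes a and b have the same shape and dtype (e.g. torch.float16).
    diff = (a.float() - b.float()).abs()
    print("torch.allclose:", torch.allclose(a, b, atol=atol, rtol=rtol))
    print("Mean Acc:", (a == b).float().mean().item())  # fraction of identical elements (assumed meaning)
    print("MAD:", diff.mean().item())
    print("Max Abs Diff:", diff.max().item())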
What’s the Diff?
What’s causing the local_sample_embs to be different across PyTorch versions? Here’s how I explored it:

colbert-ai encodes passages in batches (1600 at a time in my case, for a total of 29 batches across 46107 passages), so I compared model layer outputs for each batch between PyTorch versions using the following script:
checkpoint = Checkpoint("answerdotai/answerai-colbert-small-v1", colbert_config=config)
sample_pids = torch.load(f"{MOUNT}/{project}/{date}-{source}-{nranks}/sample_pids.pt")

idx = 0
for idx in range(29):
    docs = passages['text'][list(sample_pids)[1600*idx:1600*(idx+1)]]
    text_batches, reverse_indices = checkpoint.doc_tokenizer.tensorize(docs, bsize=config.index_bsize)
    input_ids = text_batches[0][0]
    attention_mask = text_batches[0][1]

    with torch.cuda.amp.autocast():
        outputs_dict = {}

        def capture_output(name):
            def hook_fn(module, input, output):
                outputs_dict[name] = output[0].detach()
            return hook_fn

        hooks = []
        for i in range(12): hooks.append(checkpoint.bert.encoder.layer[i].register_forward_hook(capture_output(f"{i}")))

        with torch.no_grad(): D = checkpoint.bert(input_ids, attention_mask=attention_mask)[0]

        for h in hooks: h.remove()

    torch.save(outputs_dict, f"{MOUNT}/{project}/{date}-{source}-{nranks}/lse_outputs_dict_{idx}.pt")
    print(f"lse_outputs_dict_{idx} saved!")
Batch idx 2, 3, 4, 5, 12, 15, 18, 19, 20, and 26 fail the torch.allclose comparison between BertModel layer outputs (using the comparison loop shown further below). Why is that the case?

Here are the tensor shapes for each batch of input_ids:
0 torch.Size([32, 71])
1 torch.Size([32, 72])
2 torch.Size([32, 79]) # fails
3 torch.Size([32, 78]) # fails
4 torch.Size([32, 77]) # fails
5 torch.Size([32, 194]) # fails
6 torch.Size([32, 70])
7 torch.Size([32, 73])
8 torch.Size([32, 71])
9 torch.Size([32, 68])
10 torch.Size([32, 66])
11 torch.Size([32, 115])
12 torch.Size([32, 82]) # fails
13 torch.Size([32, 115])
14 torch.Size([32, 115])
15 torch.Size([32, 80]) # fails
16 torch.Size([32, 72])
17 torch.Size([32, 64])
18 torch.Size([32, 90]) # fails
19 torch.Size([32, 82]) # fails
20 torch.Size([32, 86]) # fails
21 torch.Size([32, 63])
22 torch.Size([32, 71])
23 torch.Size([32, 62])
24 torch.Size([32, 61])
25 torch.Size([32, 67])
26 torch.Size([32, 83]) # fails
27 torch.Size([32, 69])
28 torch.Size([32, 72])
The batches that diverge have a second dimension of 79, 78, 77, 194, 82, 80, 90, 86, and 83.

The batches that do not diverge have a second dimension of 71, 72, 70, 73, 68, 66, 115, 64, 63, 62, 61, 67, and 69.
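A quick sanity check that the two sets of sequence lengths really are disjoint (a trivial snippet, just to make the claim explicit):

diverging = {79, 78, 77, 194, 82, 80, 90, 86, 83}
non_diverging = {71, 72, 70, 73, 68, 66, 115, 64, 63, 62, 61, 67, 69}
assert diverging.isdisjoint(non_diverging)  # no sequence length appears in both sets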
It is interesting to note that these sets do not intersect. To test if batch size is the root cause, I truncate the diverging batches to their first 70 token positions (so every batch has shape [32, 70]) and run the layer output comparison again:
batch_idx = 70
for idx in [2, 3, 4, 5, 12, 15, 18, 19, 20, 26]:
    docs = passages['text'][list(sample_pids)[1600*idx:1600*(idx+1)]]
    text_batches, reverse_indices = checkpoint.doc_tokenizer.tensorize(docs, bsize=config.index_bsize)
    input_ids = text_batches[0][0][:, :batch_idx]
    attention_mask = text_batches[0][1][:, :batch_idx]

    with torch.cuda.amp.autocast():
        outputs_dict = {}

        def capture_output(name):
            def hook_fn(module, input, output):
                outputs_dict[name] = output[0].detach()
            return hook_fn

        hooks = []
        for i in range(12): hooks.append(checkpoint.bert.encoder.layer[i].register_forward_hook(capture_output(f"{i}")))

        with torch.no_grad(): D = checkpoint.bert(input_ids, attention_mask=attention_mask)[0]

        for h in hooks: h.remove()

    torch.save(outputs_dict, f"{MOUNT}/{project}/{date}-{source}-{nranks}/lse_outputs_dict_{idx}.pt")
    print(f"lse_outputs_dict_{idx} saved!")
I then compare the saved layer outputs between the two PyTorch versions:

for idx in range(29):
    a = torch.load(f"{MOUNT}/{root_a}/lse_outputs_dict_{idx}.pt")
    b = torch.load(f"{MOUNT}/{root_b}/lse_outputs_dict_{idx}.pt")

    for i in range(len(a.keys())):
        a_ = a[f"{i}"]
        b_ = b[f"{i}"]
        assert _close(a_, b_)
Where _close is defined as:
def _close(a, b, default=False):
    gtype = a.dtype
    if gtype in [torch.uint8, torch.int32, torch.int64]:
        if a.shape == b.shape: return torch.equal(a, b)
        return False
    if not default:
        if gtype == torch.float32:
            atol, rtol = 1e-6, 1e-5
        elif gtype == torch.bfloat16:
            atol, rtol = 1e-3, 1e-2
        else:
            atol, rtol = 1e-4, 1e-3
    else:
        atol, rtol = 1e-8, 1e-5
    return torch.allclose(a, b, rtol=rtol, atol=atol)
All model layer outputs match between PyTorch versions! Just to be sure, I tried batch_idx of 61, 64, and 68, and all model layer outputs match.
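For the batches that diverge at full length, a per-layer report is a natural follow-up diagnostic. Here is a sketch, assuming the full-length outputs_dict files from the first capture script are still available under root_a and root_b (the truncated run above overwrites them, so you would re-capture first):

import torch

idx = 2  # one of the diverging batches
a = torch.load(f"{MOUNT}/{root_a}/lse_outputs_dict_{idx}.pt")
b = torch.load(f"{MOUNT}/{root_b}/lse_outputs_dict_{idx}.pt")

for i in range(12):
    diff = (a[f"{i}"].float() - b[f"{i}"].float()).abs()
    print(f"layer {i}: mean abs diff {diff.mean().item():.2e}, max abs diff {diff.max().item():.2e}")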
Closing Thoughts
Earlier today I read the Thinking Machines blog post on why LLM inference is non-deterministic. The main cause for non-determinism is that not all tensor ops are batch size invariant:
As it turns out, our request’s output does depend on the parallel user requests. Not because we’re somehow leaking information across batches — instead, it’s because our forward pass lacks “batch invariance”, causing our request’s output to depend on the batch size of our forward pass.
To explain batch invariance, let’s simplify the system and look solely at matmuls. You can assume that all matmul implementations are “run-to-run deterministic.” (This is not totally true, but most common matmul implementations do have this property.) However, they are not “batch-invariant.” In other words, when the batch size changes, each element in the batch can get different results.
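Here’s a minimal way to see that idea in isolation (a sketch, independent of the 2.4.1/2.5.0 question; whether you actually see a non-zero difference depends on your GPU and which kernels get picked): compute one row of a matmul on its own and as part of a larger batch, then compare.

import torch

A = torch.randn(2048, 2048, dtype=torch.float16, device="cuda")
B = torch.randn(2048, 2048, dtype=torch.float16, device="cuda")

full = torch.mm(A, B)[:1]    # first row, computed as part of the full matmul
single = torch.mm(A[:1], B)  # first row, computed on its own

print("bitwise equal:", torch.equal(full, single))
print("max abs diff:", (full.float() - single.float()).abs().max().item())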
While I’m not going to (can’t?) dig into PyTorch to understand what is causing batch size variance between 2.4.1 and 2.5.0, I think there is enough evidence to show that something in PyTorch is causing it. If you disagree with that conclusion, please @ me on Twitter!