A hands-on investigation into how sequence packing interacts with Flash Attention in HuggingFace Transformers. Through print statements and code exploration, I discovered that position_ids are crucial for sequence packing to work correctly—without them, the wrong Flash Attention function gets called, leading to incorrect outputs and loss values. This post walks through the debugging process, comparing packed sequences with padded batches, and reveals the critical requirement for properly constructed position_ids in sequence-packed training.
In this blog post, I’m walking through transformers code to start exploring how sequence packing and Flash Attention interact. I’m new to both concepts, so this is purely an exploratory exercise.
To assist my exploration, I’ve forked the Transformers library and added print statements at key junctures related to sequence packing and FA2. Referencing the original repo, here’s where I’ve inserted print statements:
The goal of these print statements was initially to understand how cu_seqlens is utilized (if at all); after realizing it wasn’t being used, my goal became to understand which function from flash_attn is being called: flash_attn_func or flash_attn_varlen_func?
Initial Example: Passing in input_ids, cu_seqlens and max_seqlen to the SmolLM2-135M Forward Pass
At first, based on a Claude-generated example, I passed in the following fake input data.
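Here’s a sketch of what that call looked like (the concrete token ids, sequence lengths, and model-loading arguments below are my own placeholders, not the exact values from the original example):

import torch
from transformers import AutoModelForCausalLM

# Load SmolLM2-135M with Flash Attention 2 (the loading arguments here are
# assumptions of this sketch).
model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceTB/SmolLM2-135M",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Three fake sequences of lengths 4, 3 and 3 packed into a single row of
# 10 tokens; cu_seqlens marks the boundaries, max_seqlen the longest sequence.
input_ids = torch.tensor([[101, 102, 103, 104, 201, 202, 203, 301, 302, 303]], device="cuda")
cu_seqlens = torch.tensor([0, 4, 7, 10], dtype=torch.int32, device="cuda")
max_seqlen = 4

# Per the print output described below, these extra kwargs do not trigger the
# varlen path: flash_attn_func still gets called.
output = model(input_ids=input_ids, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)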
I was surprised to see that flash_attn_func was called, because IIUC that doesn’t handle sequence packed inputs. Looking at its function signature, there’s no cu_seqlens or similar parameter:
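For reference, this is roughly what the signature looks like in flash-attn 2.x (defaults vary a bit across versions, but none of them include a cu_seqlens argument):

# flash_attn.flash_attn_func, roughly as defined in flash-attn 2.x
def flash_attn_func(
    q,                      # (batch_size, seqlen, nheads, headdim)
    k,                      # (batch_size, seqlen_k, nheads_k, headdim)
    v,                      # (batch_size, seqlen_k, nheads_k, headdim)
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),   # sliding-window attention; (-1, -1) disables it
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
): ...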
In particular, this line was of interest: torch.diff(position_ids, dim=-1) >= 0
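To make the check concrete, here’s a small illustration (the packed position_ids values are my own made-up example, matching the three fake sequences above). For the default, monotonically increasing position_ids every difference is non-negative, so transformers treats the input as a single (or padded) sequence; packed position_ids restart at 0 at each boundary, the diff goes negative, and the varlen path is taken.

import torch

default_pos = torch.arange(10).unsqueeze(0)                 # [[0, 1, ..., 9]]
packed_pos = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2]])

print((torch.diff(default_pos, dim=-1) >= 0).all())  # tensor(True)  -> flash_attn_func
print((torch.diff(packed_pos, dim=-1) >= 0).all())   # tensor(False) -> flash_attn_varlen_func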
In the following contrived example, position_ids is not a list of consecutive numbers (which seems to be the default value constructed if no position_ids value is passed to the model’s forward pass).
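A sketch of that forward pass, reusing the model and packed input_ids from the first snippet (the concrete values are again my own illustration):

# position_ids restart at 0 at every packed-sequence boundary.
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2]], device="cuda")
output = model(input_ids=input_ids, position_ids=position_ids)
# With these position_ids the diff check above is no longer all non-negative,
# and the print statements show that flash_attn_varlen_func is called.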
The position_ids tensor is then deconstructed into cu_seq_lens_q and cu_seq_lens_k, which are passed as arguments to flash_attn_varlen_func.
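As a rough mental model of that deconstruction (this is my own simplified re-implementation, not the library’s exact code): every position equal to 0 marks the start of a new packed sequence, and the cumulative sequence lengths follow from those boundaries.

import torch

# Simplified, flattened version of the packed position_ids used above.
position_ids = torch.tensor([0, 1, 2, 3, 0, 1, 2, 0, 1, 2])

seq_starts = (position_ids == 0).nonzero().flatten()
cu_seqlens = torch.cat([seq_starts, torch.tensor([position_ids.numel()])]).to(torch.int32)
max_seqlen = int(torch.diff(cu_seqlens).max())

print(cu_seqlens)  # tensor([ 0,  4,  7, 10], dtype=torch.int32)
print(max_seqlen)  # 4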
The main takeaway from this: Flash Attention will not handle sequence packing correctly unless you pass in position_ids.
Packed Sequence Loss
In the remaining sections of this blog post, I’ll explore how to correctly calculate the loss for a packed sequence.
output.logits.shape
torch.Size([1, 10, 49152])
Following how labels are constructed in HuggingFace’s DataCollatorWithFlattening, the first token in each sequence is replaced with -100. Since the HuggingFace CausalLM loss function shifts the labels to set up next-token prediction, this -100 ensures that the last token of one packed sequence is never trained to predict the first token of the next.
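A sketch of that label construction, in the spirit of DataCollatorWithFlattening and reusing the packed input_ids from the first snippet:

seq_start_positions = [0, 4, 7]  # first token of each packed sequence

labels = input_ids.clone()
labels[0, seq_start_positions] = -100
# labels: [[-100, 102, 103, 104, -100, 202, 203, -100, 302, 303]]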
We can see that the labels have been shifted to the left by 1 element, and a -100 ignore index has been added to the right, which is needed because the last token in the input doesn’t predict anything.
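The shift itself looks roughly like this (assuming the labels tensor from the previous snippet):

# Drop the first label and append -100 on the right, so the logit at position
# t is scored against the token at position t + 1.
shift_labels = torch.cat(
    [labels[:, 1:], torch.full((labels.size(0), 1), -100, device=labels.device)],
    dim=-1,
)
# shift_labels: [[102, 103, 104, -100, 202, 203, -100, 302, 303, -100]]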
Calculating the loss using F.cross_entropy directly and the model’s loss_function (providing it unshifted labels):
loss = F.cross_entropy(
    output.logits.reshape(-1, output.logits.size(-1)).float(),
    shift_labels.reshape(-1),
)
loss
tensor(20.2832, device='cuda:0')
model.loss_function(output.logits, labels, 49152)
tensor(20.2832, device='cuda:0')
Padded Batch Loss
Sequence packing shouldn’t change the loss value of a given input batch. To test this, I’ll construct a padded batch from our fake data and calculate its outputs, labels and loss.
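A sketch of that padded batch (the pad token id and the choice not to pass an attention_mask are assumptions of this sketch; with right padding and a causal model, the pad positions sit after the real tokens and their labels are set to -100, so they don’t affect the loss on the real tokens):

pad_id = 0  # assumed pad token id for this illustration

# One row per fake sequence (lengths 4, 3, 3), right-padded to length 4.
padded_input_ids = torch.tensor([
    [101, 102, 103, 104],
    [201, 202, 203, pad_id],
    [301, 302, 303, pad_id],
], device="cuda")

padded_labels = padded_input_ids.clone()
padded_labels[1, 3] = -100
padded_labels[2, 3] = -100

padded_output = model(input_ids=padded_input_ids)
padded_loss = model.loss_function(padded_output.logits, padded_labels, 49152)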
Note that I haven’t passed any position_ids, and the printed output shows that flash_attn_func, the “vanilla” implementation of Flash Attention for padded batches, is indeed what gets called:
To confirm that not passing in position_ids does indeed make HuggingFace use the wrong Flash Attention implementation for a packed sequence, I’ll compare the logits and loss:
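A sketch of the comparison, reusing the packed input_ids and the padded_output from the snippets above (the tolerance and the specific slice compared are my own choices):

# Without position_ids the packed forward pass goes through flash_attn_func,
# so tokens keep the default 0..9 positions and can attend across sequence
# boundaries; per the takeaway above, the logits for the later packed
# sequences then no longer match the padded baseline.
packed_no_pos = model(input_ids=input_ids)  # no position_ids -> flash_attn_func

# The second packed sequence lives at positions 4:7 of the packed row and at
# the first 3 positions of the second padded row.
print(torch.allclose(
    packed_no_pos.logits[0, 4:7].float(),
    padded_output.logits[1, :3].float(),
    atol=1e-3,
))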
I’ll reiterate that I’m not familiar with how sequence packing is implemented (in HuggingFace or ModernBERT) and even less familiar with how Flash Attention is implemented. That being said, this cursory investigation allowed me to understand, at a high level, how the two interact. My key takeaway is that the correct position_ids need to be passed to the model; otherwise HuggingFace will not use flash_attn_varlen_func for sequence-packed inputs, and that will result in incorrect logits and loss values.