Understanding Sequence Packing - Initial Musings

python
deep learning
LLM
A hands-on investigation into how sequence packing interacts with Flash Attention in HuggingFace Transformers. Through print statements and code exploration, I discovered that position_ids are crucial for sequence packing to work correctly: without them, the wrong Flash Attention function gets called, leading to incorrect outputs and loss values. This post walks through the debugging process, compares packed sequences with padded batches, and shows why properly constructed position_ids are required for sequence-packed training.
Author

Vishal Bakshi

Published

May 4, 2025

Setup

!pip install -qq -U flash-attn --no-build-isolation
!pip uninstall transformers -y
!pip install git+https://github.com/vishalbakshi/transformers.git -qq

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import inspect
import torch.nn.functional as F

model_name = "HuggingFaceTB/SmolLM2-135M"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto")

tokenizer = AutoTokenizer.from_pretrained(model_name)

Background

In this blog post, I’m walking through transformers code to start exploring how sequence packing and Flash Attention interact. I’m new to both concepts, so this is purely an exploratory exercise.

To assist my exploration, I’ve forked the Transformers library and added print statements at key junctures related to sequence packing and FA2. Referencing the original repo, here’s where I’ve inserted print statements:

print("\n=== FLASH_ATTENTION_FORWARD ENTRY ===")
print(f"kwargs received: {list(kwargs.keys())}")
print("\n=== _FLASH_ATTENTION_FORWARD ENTRY ===")
print(f"kwargs received: {list(kwargs.keys())}")

print("\n attention_mask")
print(attention_mask)

print("\n position_ids")
print(position_ids)

In the same file, later on:

# Contains at least one padding token in the sequence
if attention_mask is not None:
    print("attention_mask is not None")
    ...

And further on, in the _flash_attention_forward function definition:

elif position_ids is not None and (
    max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
    print("position_ids is not None and max_length_q check")
    batch_size = query_states.size(0)

    if cu_seq_lens_q is None or cu_seq_lens_k is None:
        print(f"cu_seq_lens_q is None: {cu_seq_lens_q is None}")
        print(f"cu_seq_lens_k is None: {cu_seq_lens_k is None}")
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
            prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
        )

        print("\n cu_seq_lens")
        print(cu_seq_lens)
        cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
        max_length_q, max_length_k = max_seq_lens

    else:
        ...

I originally identified these functions by using the inspect library, e.g.:

print(inspect.getsource(model.model.layers[0].self_attn.forward))

Initially, the goal of these print statements was to understand how cu_seqlens is used (if at all). After realizing it wasn’t being used, my goal became to understand which function from flash_attn is being called: flash_attn_func or flash_attn_varlen_func.

Initial Example: Passing in input_ids, cu_seqlens and max_seqlen to the SmolLM2-135M Forward Pass

At first, based on a Claude-generated example, I passed in the following fake input data.

torch.manual_seed(42)
test_params = {
    'input_ids': torch.randint(1, 10, size=(1,10)).to("cuda"),
    'cu_seqlens': [torch.tensor([0, 3, 10], dtype=torch.int32).to("cuda")],
    'max_seqlen': [10]
}
test_params
{'input_ids': tensor([[7, 6, 8, 5, 1, 3, 8, 6, 5, 3]], device='cuda:0'),
 'cu_seqlens': [tensor([ 0,  3, 10], device='cuda:0', dtype=torch.int32)],
 'max_seqlen': [10]}
model.eval()
with torch.no_grad(): output = model(**test_params)

The following was printed out for each attention mechanism call in each of the model’s 30 layers:

=== FLASH_ATTENTION_FORWARD ENTRY ===
kwargs received: ['position_ids', 'output_attentions', 'use_cache', 'cu_seqlens', 'max_seqlen']

=== _FLASH_ATTENTION_FORWARD ENTRY ===
kwargs received: ['output_attentions', 'use_cache', 'cu_seqlens', 'max_seqlen']

 attention_mask
None

 position_ids
tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]], device='cuda:0')
flash_attn_func is called
flash_kwargs received: ['deterministic']

I was surprised to see that flash_attn_func was called, because IIUC that doesn’t handle sequence packed inputs. Looking at its function signature, there’s no cu_seqlens or similar parameter:

def flash_attn_func(
    q,
    k,
    v,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    deterministic=False,
    sm_margin=0,
)
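To make the difference concrete, here is a pure-Python sketch (my own illustration, not flash-attn code) of which key positions each query position may attend to. It assumes causal masking is on, as it is for this decoder model. A plain causal mask over one row, which is what flash_attn_func applies, lets the first token of the second sequence attend to the whole first sequence; a causal mask restricted to each sequence by cu_seqlens, which is what flash_attn_varlen_func is meant to enforce, does not. The helper names are hypothetical.

```python
from bisect import bisect_right


def causal_keys(query_idx):
    """Keys visible to a query under a plain causal mask over a single row."""
    return list(range(query_idx + 1))


def varlen_causal_keys(query_idx, cu_seqlens):
    """Keys visible under a causal mask restricted to the query's own sequence.

    cu_seqlens holds cumulative sequence boundaries, e.g. [0, 3, 10] for two
    packed sequences of lengths 3 and 7.
    """
    # Index of the packed sequence that contains query_idx.
    seq = bisect_right(cu_seqlens, query_idx) - 1
    start = cu_seqlens[seq]
    return list(range(start, query_idx + 1))


cu_seqlens = [0, 3, 10]
for q in (2, 3, 5):
    print(q, causal_keys(q), varlen_causal_keys(q, cu_seqlens))
# 2 [0, 1, 2] [0, 1, 2]
# 3 [0, 1, 2, 3] [3]
# 5 [0, 1, 2, 3, 4, 5] [3, 4, 5]
```

The two masks agree inside the first sequence and diverge from the first token of the second sequence onward.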

Additionally, position_ids is defined even though I didn’t pass it in. IIUC, that’s done in the model’s forward pass with the line:

if position_ids is None:
    position_ids = cache_position.unsqueeze(0)

cache_position is defined earlier in that same forward pass, which can be observed by running:

forward_method = inspect.getsource(model.model.forward)
print(forward_method)
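As I read it, that default amounts to counting up from the number of tokens already in the KV cache. A minimal pure-Python sketch, assuming cache_position is an arange that starts at the number of previously cached tokens (past_seen_tokens is my own name for that count, and the function is hypothetical):

```python
def default_position_ids(seq_len, past_seen_tokens=0):
    """Mimic position_ids = cache_position.unsqueeze(0) for a batch of one.

    Assumes cache_position = arange(past_seen_tokens, past_seen_tokens + seq_len).
    """
    cache_position = list(range(past_seen_tokens, past_seen_tokens + seq_len))
    return [cache_position]  # unsqueeze(0): add a leading batch dimension


print(default_position_ids(10))
# [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]]
```

With no cache this reproduces the position_ids printed above: a single increasing run with no sequence boundaries in it.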

Second Attempt: Passing in position_ids to the Forward Pass as Well

Claude helped me understand that the call to flash_attn_varlen_func is triggered by the following conditional in _flash_attention_forward:

elif position_ids is not None and (
    max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
    ...

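In particular, this part of the condition was of interest: torch.diff(position_ids, dim=-1) >= 0. It checks whether each row of position_ids never decreases; a decrease means the positions restart, which is how a boundary between packed sequences shows up.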
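Read as plain Python, the position_ids branch of that conditional boils down to the predicate below. This is a simplified sketch of only that branch (it ignores the attention_mask branch checked before it), and the function name is hypothetical:

```python
def takes_position_ids_varlen_branch(position_ids, query_length, max_length_q=None):
    """Simplified mirror of the position_ids branch in _flash_attention_forward.

    position_ids is a batch given as a list of rows (lists of ints).
    """
    if position_ids is None:
        return False
    # Equivalent of (torch.diff(position_ids, dim=-1) >= 0).all():
    # every row is non-decreasing.
    non_decreasing = all(
        b - a >= 0 for row in position_ids for a, b in zip(row, row[1:])
    )
    return max_length_q is not None or (query_length != 1 and not non_decreasing)


default_ids = [list(range(10))]
packed_ids = [[0, 1, 2, 0, 1, 2, 3, 4, 5, 6]]
print(takes_position_ids_varlen_branch(default_ids, query_length=10))  # False
print(takes_position_ids_varlen_branch(packed_ids, query_length=10))   # True
```

With the default, strictly increasing position_ids, this branch is only entered when max_length_q is supplied; a restart in position_ids is what flips it for packed input.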

In the following contrived example, position_ids is not a single run of consecutive numbers (which seems to be the default value constructed if no position_ids value is passed to the model’s forward pass).

input_ids = torch.tensor([[0, 1, 2, 10, 11, 12, 13, 14, 15, 16]]).to("cuda")
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4, 5, 6]]).to("cuda")
cu_seqlens = [torch.tensor([0, 3, 10], dtype=torch.int32).to("cuda")]
(torch.diff(position_ids, dim=-1) >= 0).all()
tensor(False, device='cuda:0')
torch.diff(position_ids, dim=-1) >= 0
tensor([[ True,  True, False,  True,  True,  True,  True,  True,  True]],
       device='cuda:0')
torch.diff(position_ids, dim=-1)
tensor([[ 1,  1, -2,  1,  1,  1,  1,  1,  1]], device='cuda:0')

One of the diffs between consecutive elements in position_ids is negative (-2), because position_ids restarts at 0 where the second sequence begins.
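Rather than writing position_ids by hand, they can be derived from the per-sequence lengths by restarting at 0 for each sequence, while cu_seqlens is a running sum of the same lengths. A minimal sketch (the helper name is hypothetical):

```python
from itertools import accumulate


def packed_position_ids_and_cu_seqlens(seq_lengths):
    """Build position_ids (restarting at 0 per sequence) and cu_seqlens for one packed row."""
    position_ids = [pos for length in seq_lengths for pos in range(length)]
    # Cumulative boundaries with a leading 0, e.g. [3, 7] -> [0, 3, 10].
    cu_seqlens = [0] + list(accumulate(seq_lengths))
    return [position_ids], cu_seqlens


position_ids, cu_seqlens = packed_position_ids_and_cu_seqlens([3, 7])
print(position_ids)  # [[0, 1, 2, 0, 1, 2, 3, 4, 5, 6]]
print(cu_seqlens)    # [0, 3, 10]
```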

I would now expect flash_attn_varlen_func to be called.

torch.manual_seed(42)
test_params = {
    'input_ids': torch.randint(1, 10, size=(1,10)).to("cuda"),
    'position_ids': torch.tensor([[0, 1, 2, 0, 1, 2, 3, 4, 5, 6]]).to("cuda"),
    'cu_seqlens': [torch.tensor([0, 3, 10], dtype=torch.int32).to("cuda")],
    'max_seqlen': [7]
}
test_params
{'input_ids': tensor([[7, 6, 8, 5, 1, 3, 8, 6, 5, 3]], device='cuda:0'),
 'position_ids': tensor([[0, 1, 2, 0, 1, 2, 3, 4, 5, 6]], device='cuda:0'),
 'cu_seqlens': [tensor([ 0,  3, 10], device='cuda:0', dtype=torch.int32)],
 'max_seqlen': [7]}
model.eval()
with torch.no_grad(): output = model(**test_params)

Passing test_params through the model’s forward pass yields:

=== FLASH_ATTENTION_FORWARD ENTRY ===
kwargs received: ['position_ids', 'output_attentions', 'use_cache', 'cu_seqlens', 'max_seqlen']

=== _FLASH_ATTENTION_FORWARD ENTRY ===
kwargs received: ['output_attentions', 'use_cache', 'cu_seqlens', 'max_seqlen']

 attention_mask
None

 position_ids
tensor([[0, 1, 2, 0, 1, 2, 3, 4, 5, 6]], device='cuda:0')
position_ids is not None and max_length_q check
cu_seq_lens_q is None: True
cu_seq_lens_k is None: True

 cu_seq_lens
(tensor([ 0,  3, 10], device='cuda:0', dtype=torch.int32), tensor([ 0,  3, 10], device='cuda:0', dtype=torch.int32))

The position_ids are as passed in. However, _flash_attention_forward does not use cu_seqlens from kwargs directly. Instead, it builds cu_seq_lens from position_ids in the following line:

query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
    prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
)

The value of cu_seq_lens is the tuple:

(tensor([ 0,  3, 10], device='cuda:0', dtype=torch.int32), tensor([ 0,  3, 10], device='cuda:0', dtype=torch.int32))

This tuple is unpacked into cu_seq_lens_q and cu_seq_lens_k, which are then passed as arguments to flash_attn_varlen_func.
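As far as I can tell from reading it, prepare_fa2_from_position_ids recovers the boundaries by taking every index where the flattened position_ids equals 0 and appending the total token count, and it sets the maximum length to max(position_ids) + 1. A pure-Python sketch of just that bookkeeping (it skips the reshaping of query_states, key_states and value_states; the function name is hypothetical):

```python
def cu_seq_lens_from_position_ids(position_ids):
    """Sketch of the boundary bookkeeping in prepare_fa2_from_position_ids.

    position_ids: list of rows; they are flattened into one packed row first.
    Returns ((cu_q, cu_k), (max_q, max_k)); q and k are identical here.
    """
    flat = [p for row in position_ids for p in row]
    # Every index where the position restarts at 0 starts a new sequence...
    starts = [i for i, p in enumerate(flat) if p == 0]
    # ...and the total token count closes the final sequence.
    cu_seq_lens = starts + [len(flat)]
    max_length = max(flat) + 1
    return (cu_seq_lens, cu_seq_lens), (max_length, max_length)


cu_seq_lens, max_seq_lens = cu_seq_lens_from_position_ids([[0, 1, 2, 0, 1, 2, 3, 4, 5, 6]])
print(cu_seq_lens)   # ([0, 3, 10], [0, 3, 10])
print(max_seq_lens)  # (7, 7)
```

Because the boundaries come only from the zeros in position_ids, position_ids that don't restart at 0 for each sequence would also produce wrong boundaries here.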

The main takeaway from this: Flash Attention will not handle sequence packing correctly unless you pass in position_ids that restart at 0 for each packed sequence.

Packed Sequence Loss

In the remaining sections of this blog post, I’ll explore how to correctly calculate the loss for a packed sequence.

output.logits.shape
torch.Size([1, 10, 49152])

Following how labels are constructed in HuggingFace’s DataCollatorWithFlattening, the first token in each sequence is replaced with -100. This is needed because the HuggingFace CausalLM loss function shifts the labels left by one for next-token prediction: without the -100, the last token of one sequence would be trained to predict the first token of the next sequence.
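A pure-Python sketch of that label construction, modeled on my reading of DataCollatorWithFlattening (which, as I understand it, prepends a separator value, -100 by default, to each sequence's input_ids[1:] before concatenating). The function name is hypothetical:

```python
IGNORE_INDEX = -100


def flatten_labels(sequences, separator_id=IGNORE_INDEX):
    """Concatenate per-sequence labels, replacing each sequence's first token."""
    labels = []
    for input_ids in sequences:
        # The first token of every sequence becomes the ignore index.
        labels += [separator_id] + list(input_ids[1:])
    return labels


print(flatten_labels([[7, 6, 8], [5, 1, 3, 8, 6, 5, 3]]))
# [-100, 6, 8, -100, 1, 3, 8, 6, 5, 3]
```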

labels = torch.tensor([-100, 6, 8, -100, 1, 3, 8, 6, 5, 3]).to("cuda")
labels
tensor([-100,    6,    8, -100,    1,    3,    8,    6,    5,    3],
       device='cuda:0')

The following two lines are taken from the model’s loss function, which can be inspected with print(inspect.getsource(model.loss_function)):

_labels = torch.nn.functional.pad(labels, (0, 1), value=-100)
shift_labels = _labels[..., 1:].contiguous()
shift_labels
tensor([   6,    8, -100,    1,    3,    8,    6,    5,    3, -100],
       device='cuda:0')

We can see that the labels have been shifted to the left by one element, and a -100 ignore index has been appended on the right, which is needed because the last token in the input has no next token to predict.
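The same shift can be sketched in plain Python, which also makes it easy to check what the -100 at each sequence start buys us: after shifting, the last token of the first sequence (index 2) gets an ignore target instead of being asked to predict the first token of the second sequence. The helper is hypothetical and mirrors the pad-then-slice lines above:

```python
IGNORE_INDEX = -100


def shift_labels(labels, ignore_index=IGNORE_INDEX):
    """Pad one ignore_index on the right, then drop the first element."""
    padded = list(labels) + [ignore_index]
    return padded[1:]


masked = [-100, 6, 8, -100, 1, 3, 8, 6, 5, 3]
unmasked = [7, 6, 8, 5, 1, 3, 8, 6, 5, 3]  # raw input_ids, no boundary masking

print(shift_labels(masked))    # [6, 8, -100, 1, 3, 8, 6, 5, 3, -100]
print(shift_labels(unmasked))  # [6, 8, 5, 1, 3, 8, 6, 5, 3, -100]
# Without the mask, index 2 (last token of sequence 1) would target 5,
# the first token of sequence 2.
```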

Calculating the loss using F.cross_entropy directly (with the shifted labels) and using the model’s loss_function (which takes the unshifted labels):

loss = F.cross_entropy(
    output.logits.reshape(-1, output.logits.size(-1)).float(),
    shift_labels.reshape(-1)
)
loss
tensor(20.2832, device='cuda:0')
model.loss_function(output.logits, labels, 49152)
tensor(20.2832, device='cuda:0')
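Both numbers are a mean over only the targets that aren't -100: F.cross_entropy's default ignore_index is -100, and its default "mean" reduction averages over the non-ignored targets. A pure-Python sketch of that computation on toy logits (not the model's; the function name is hypothetical):

```python
import math

IGNORE_INDEX = -100


def mean_cross_entropy(logits, targets, ignore_index=IGNORE_INDEX):
    """Mean of -log softmax(row)[target] over positions whose target is not ignored."""
    losses = []
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored positions contribute neither a term nor a count
        # Numerically stable log-sum-exp.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[target])
    if not losses:
        return float("nan")  # no valid targets: the mean is undefined
    return sum(losses) / len(losses)


toy_logits = [[2.0, 0.0], [0.0, 0.0], [1.0, 3.0]]
toy_targets = [0, IGNORE_INDEX, 1]
# Both kept rows give log(1 + e**-2) ≈ 0.1269; the ignored middle row is skipped.
print(mean_cross_entropy(toy_logits, toy_targets))
```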

Padded Batch Loss

Sequence packing shouldn’t change the loss value of a given input batch. To test this, I’ll construct a padded batch from our fake data and calculate its outputs, labels and loss.

input_ids = test_params['input_ids'][0]
input_ids
tensor([7, 6, 8, 5, 1, 3, 8, 6, 5, 3], device='cuda:0')
cu_seqlens = test_params['cu_seqlens'][0]
cu_seqlens
tensor([ 0,  3, 10], device='cuda:0', dtype=torch.int32)
seq_boundaries = list(zip(cu_seqlens[:-1], cu_seqlens[1:]))
seq_boundaries
[(tensor(0, device='cuda:0', dtype=torch.int32),
  tensor(3, device='cuda:0', dtype=torch.int32)),
 (tensor(3, device='cuda:0', dtype=torch.int32),
  tensor(10, device='cuda:0', dtype=torch.int32))]
seq1 = input_ids[seq_boundaries[0][0]: seq_boundaries[0][1]]
seq2 = input_ids[seq_boundaries[1][0]: seq_boundaries[1][1]]
seq1, seq2
(tensor([7, 6, 8], device='cuda:0'),
 tensor([5, 1, 3, 8, 6, 5, 3], device='cuda:0'))

The first item in the batch has 3 elements, and the second item in the batch has 7 elements. We need to pad the first item so it’s 7 elements long.

seq1 = torch.cat([seq1, torch.tensor([0, 0, 0, 0]).to("cuda")])
seq1
tensor([7, 6, 8, 0, 0, 0, 0], device='cuda:0')
padded_batch = torch.stack([seq1, seq2], dim=0)
padded_batch, padded_batch.shape
(tensor([[7, 6, 8, 0, 0, 0, 0],
         [5, 1, 3, 8, 6, 5, 3]], device='cuda:0'),
 torch.Size([2, 7]))

Similarly, we need to construct labels such that the last four elements in the first batch item are ignored.

seq1 = input_ids[seq_boundaries[0][0]: seq_boundaries[0][1]]
seq2 = input_ids[seq_boundaries[1][0]: seq_boundaries[1][1]]
seq1, seq2
(tensor([7, 6, 8], device='cuda:0'),
 tensor([5, 1, 3, 8, 6, 5, 3], device='cuda:0'))
seq1 = torch.cat([seq1, torch.tensor([-100, -100, -100, -100]).to("cuda")])
seq1
tensor([   7,    6,    8, -100, -100, -100, -100], device='cuda:0')
padded_labels = torch.stack([seq1, seq2], dim=0)
padded_labels
tensor([[   7,    6,    8, -100, -100, -100, -100],
        [   5,    1,    3,    8,    6,    5,    3]], device='cuda:0')
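The slicing and padding above generalize to any number of packed sequences. A sketch with a hypothetical helper that splits a packed row at the cu_seqlens boundaries and right-pads every piece to the longest one, using 0 for input_ids and -100 for labels as above:

```python
def unpack_and_pad(flat, cu_seqlens, pad_value):
    """Split a packed 1-D list at cu_seqlens boundaries and right-pad each piece."""
    pieces = [flat[start:end] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    max_len = max(len(p) for p in pieces)
    return [p + [pad_value] * (max_len - len(p)) for p in pieces]


input_ids = [7, 6, 8, 5, 1, 3, 8, 6, 5, 3]
cu_seqlens = [0, 3, 10]
print(unpack_and_pad(input_ids, cu_seqlens, pad_value=0))
# [[7, 6, 8, 0, 0, 0, 0], [5, 1, 3, 8, 6, 5, 3]]
print(unpack_and_pad(input_ids, cu_seqlens, pad_value=-100))
# [[7, 6, 8, -100, -100, -100, -100], [5, 1, 3, 8, 6, 5, 3]]
```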

Calculating the logits:

model.eval()
with torch.no_grad(): padded_output = model(input_ids=padded_batch)

Note that I haven’t passed any position_ids (or an attention_mask), and the printed output shows that flash_attn_func, the “vanilla” Flash Attention implementation for padded batches, is called:

=== FLASH_ATTENTION_FORWARD ENTRY ===
kwargs received: ['position_ids', 'output_attentions', 'use_cache']

=== _FLASH_ATTENTION_FORWARD ENTRY ===
kwargs received: ['output_attentions', 'use_cache']

 attention_mask
None

 position_ids
tensor([[0, 1, 2, 3, 4, 5, 6]], device='cuda:0')
flash_attn_func is called
flash_kwargs received: ['deterministic']

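Comparing the packed output logits with the padded output logits: the shapes are different, but the values at corresponding positions are the same. For the first batch item this holds despite the padding, because the padding is on the right and causal attention keeps the three real tokens from attending to the pad tokens that follow them.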

output.logits.shape, padded_output.logits.shape
(torch.Size([1, 10, 49152]), torch.Size([2, 7, 49152]))
(output.logits[0, 0:3, :] == padded_output.logits[0, 0:3, :]).float().mean()
tensor(1., device='cuda:0')
(output.logits[0, 3:, :] == padded_output.logits[1, :, :]).float().mean()
tensor(1., device='cuda:0')

Finally, calculating the padded batch’s loss gives us the same value as the sequence packed loss:

padded_loss = model.loss_function(padded_output.logits, padded_labels, vocab_size=49152)
padded_loss
tensor(20.2832, device='cuda:0')
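One way to see why the two losses agree: after shifting, both label layouts keep exactly the same targets in the same order, so each mean is taken over the same 8 terms, and the logits at those positions were identical above. A pure-Python check on the labels used in this post (the helper is hypothetical):

```python
IGNORE_INDEX = -100


def shifted_targets(label_rows, ignore_index=IGNORE_INDEX):
    """Shift each row left by one (padding with ignore_index) and keep the real targets."""
    targets = []
    for row in label_rows:
        shifted = list(row[1:]) + [ignore_index]
        targets += [t for t in shifted if t != ignore_index]
    return targets


packed_labels = [[-100, 6, 8, -100, 1, 3, 8, 6, 5, 3]]
padded_labels = [[7, 6, 8, -100, -100, -100, -100],
                 [5, 1, 3, 8, 6, 5, 3]]

print(shifted_targets(packed_labels))  # [6, 8, 1, 3, 8, 6, 5, 3]
print(shifted_targets(padded_labels))  # [6, 8, 1, 3, 8, 6, 5, 3]
```

In the padded layout, each row's first token is dropped by the shift anyway, which is why the padded labels don't need a -100 at the start of each sequence.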

Not Passing in position_ids With a Packed Sequence

To confirm that not passing in position_ids does indeed make HuggingFace use the wrong Flash Attention implementation for a packed sequence, I’ll compare the logits and loss:

torch.manual_seed(42)
test_params = {
    'input_ids': torch.randint(1, 10, size=(1,10)).to("cuda"),
    'cu_seqlens': [torch.tensor([0, 3, 10], dtype=torch.int32).to("cuda")],
    'max_seqlen': [10]
}
test_params
{'input_ids': tensor([[7, 6, 8, 5, 1, 3, 8, 6, 5, 3]], device='cuda:0'),
 'cu_seqlens': [tensor([ 0,  3, 10], device='cuda:0', dtype=torch.int32)],
 'max_seqlen': [10]}
model.eval()
with torch.no_grad(): output2 = model(**test_params)

The logits are not the same as when flash_attn_varlen_func is used.

(output.logits == output2.logits).float().mean()
tensor(0.3012, device='cuda:0')
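The 0.3012 is close to 3/10. A plausible reading, which I haven't verified beyond this reasoning, is that only the first sequence's three positions come out the same: without position_ids, flash_attn_func applies a plain causal mask over all 10 tokens with the default position_ids 0 through 9, so only positions 0 to 2 see the same keys and positions as in the varlen run. The small remainder above 0.3 is presumably coincidental equality of individual bf16 values. A pure-Python sketch of that position-by-position comparison (the helper name is hypothetical):

```python
from bisect import bisect_right


def same_context(query_idx, cu_seqlens, packed_position_ids):
    """True if a query sees identical keys and positions with and without varlen masking."""
    seq = bisect_right(cu_seqlens, query_idx) - 1
    varlen_keys = list(range(cu_seqlens[seq], query_idx + 1))
    causal_keys = list(range(query_idx + 1))
    # Without position_ids, the default positions are simply 0..n-1.
    same_positions = all(packed_position_ids[k] == k for k in causal_keys)
    return varlen_keys == causal_keys and same_positions


cu_seqlens = [0, 3, 10]
packed_position_ids = [0, 1, 2, 0, 1, 2, 3, 4, 5, 6]
matches = [same_context(q, cu_seqlens, packed_position_ids) for q in range(10)]
print(matches)                       # True for positions 0-2, False for 3-9
print(sum(matches) / len(matches))   # 0.3
```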

It follows that the loss value is not the same either.

labels = torch.tensor([-100, 6, 8, -100, 1, 3, 8, 6, 5, 3]).to("cuda")
model.loss_function(output2.logits, labels, 49152)
tensor(17.4632, device='cuda:0')

Closing Thoughts

I’ll reiterate that I’m not familiar with how sequence packing is implemented (in HuggingFace or ModernBERT), and even less familiar with how Flash Attention is implemented. That being said, this cursory investigation helped me understand, at a high level, how the two interact. My key takeaway is that the correct position_ids need to be passed to the model; otherwise, HuggingFace will not call flash_attn_varlen_func for sequence-packed inputs, which results in incorrect logits and loss values.