!pip install -qq -U flash-attn --no-build-isolation
HuggingFace’s Default KV Cache and the flash_attn_varlen_func Docstring

Understanding flash_attn_varlen_func’s docstring’s causal masks (for seqlen_q != seqlen_k) by exploring Hugging Face’s KV Cache (DynamicCache) in model.generate() with hands-on Q/K shape inspection. Unravels “bottom-right alignment” and why flash_attn_func gets called.
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch
import torch.nn as nn
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16).to("cuda")
Background
I have recently been working on a research project that has required me to better understand sequence packing, Flash Attention, and eager attention. However, both in terms of my interest and for practical understanding, that hasn’t been enough!
Tangentially, I also recently started using Gemini Pro 2.5 (the chat UI) and have been thoroughly enjoying it after a year of daily Claude Pro use. This seemed like a great opportunity to test out Gemini’s functionality to learn about a topic that is complex and currently outside of my comfort zone.
I fed Gemini the Flash Attention GitHub repo and explained that I wanted to thoroughly understand flash_attn_varlen_func
to the point where I could make a detailed video walkthrough with visuals. It provided me with this condensed timeline:
- Phase 1: Python Layer & Sequence Packing Concepts
  - Tasks: Understand flash_attn_varlen_func’s Python call stack, the role of cu_seqlens, max_seqlen, and torch.autograd.Function.
  - Time: 1-2 Weeks
  - Hours: Approximately 15-30 hours
- Phase 2: Core CUDA Kernel Deep Dive
  - Tasks: Study the FlashAttention research paper(s). Analyze the C++/CUDA dispatcher code (e.g., in csrc/flash_attn/flash_api.cpp). Dissect the core CUDA kernels for the variable length forward and backward passes (e.g., in csrc/flash_attn/src/ like flash_fwd_kernel.h and flash_bwd_kernel.h).
  - Time: 6-12 Weeks
  - Hours: Approximately 125-240 hours
- Phase 3: Content Creation (Video/Blog)
  - Tasks: Plan the structure for your content, create visuals, draft explanations, and prepare code snippets.
  - Time: 2-3 Weeks
  - Hours: Approximately 30-50 hours
- Total Estimated for CUDA Path
  - Overall Timeline: Roughly 2.5 - 4.5 months
  - Total Focused Hours: Approximately 170 - 320 hours
This is obviously an ambitious goal and timeline, especially because of my limited C++/CUDA knowledge and experience. However, I do believe this is a case of aim-for-the-stars-land-on-the-moon, as I’ve already experienced growth and learning in the first steps of Phase 1.
As I was reading through the flash_attn_varlen_func source code, I got stuck on the following piece of the docstring:
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
I didn’t have hands-on experience working with this concept, where the query length is different from the key and value length. Gemini helped me realize that this happens in the extremely common case of autoregressive generation—the next token (query length of 1) attends to the previous tokens (key/value length > 1). The concept of KV cache also came up in our conversation.
I don’t tend to understand things until I see them in code, so in this notebook, I’ll inspect the shapes of Q, K and V during the HuggingFace model.generate call. I’ll also peel back a couple of layers and understand how HuggingFace uses KV cache. After that exploration, I’ll return to the flash_attn_varlen_func docstring and walk through the logic behind how the causal mask is shaped.
Understanding HuggingFace’s Default KV Cache
I’ll start by understanding how HuggingFace uses KV cache (I was surprised to find that it uses it by default!).
Inspecting model_kwargs for the Caching Method
Looking at the generate source code, the first method call of interest when it comes to KV cache seems to be _prepare_cache_for_generation, which takes the following arguments: generation_config, model_kwargs, assistant_model, batch_size, max_cache_length, device. Going down the different elif statements, _prepare_cache_for_generation sets the following model_kwargs value:
model_kwargs[cache_name] = (
    DynamicCache()
    if not requires_cross_attention_cache
    else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
Where cache_name
is defined earlier in that method as:
= "past_key_values" if not is_hybrid_cache else "cache_params" cache_name
I want to inspect what model_kwargs['past_key_values']
is.
_prepare_generation_config
is used in generate
to produce generation_config
and model_kwargs
.
generation_config, model_kwargs = model._prepare_generation_config(None)
generation_config, model_kwargs
(GenerationConfig {
"bos_token_id": 0,
"eos_token_id": 0
},
{})
I can now pass those on to _prepare_cache_for_generation
, which will internally modify model_kwargs
.
model._prepare_cache_for_generation(generation_config, model_kwargs, None, 1, 8192, "cuda")
model_kwargs
{'past_key_values': <transformers.cache_utils.DynamicCache at 0x78c3b80d9850>}
I can see now that model_kwargs
has a 'past_key_values'
key which has a DynamicCache
value.
How is past_key_values
Used?
I think it makes sense to start by looking at the forward pass of the LlamaAttention module:
...
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
...
if past_key_value is not None:
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
The hidden_states
pass through k_proj
and v_proj
to produce key_states
and value_states
, respectively, which are then passed to past_key_value.update
to produce a new set of key_states
and value_states
. Looking at DynamicCache.update
:
# Update the cache
if key_states is not None:
    if len(self.key_cache) <= layer_idx:
        # There may be skipped layers, fill them with empty lists
        for _ in range(len(self.key_cache), layer_idx):
            self.key_cache.append(torch.tensor([]))
            self.value_cache.append(torch.tensor([]))
        self.key_cache.append(key_states)
        self.value_cache.append(value_states)
    elif (
        not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
    ):  # fills previously skipped layers; checking for tensor causes errors
        self.key_cache[layer_idx] = key_states
        self.value_cache[layer_idx] = value_states
    else:
        self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
        self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

return self.key_cache[layer_idx], self.value_cache[layer_idx]
Let’s walk through each condition in the if-else block.
if len(self.key_cache) <= layer_idx
A full key_cache has n_layers elements. If its number of elements is less than or equal to layer_idx, that means it does not yet contain key_states for that layer_idx (because Python counts from 0). For example, suppose layer_idx is 0, our first layer. If len(self.key_cache) <= layer_idx is True, that means len(self.key_cache) is 0 and the cache doesn’t contain key_states for the first layer, as would be the case if you were generating the first token of a response. In this case you simply append the key_states to the cache.
If layer_idx
is greater than len(self.key_cache)
then it appends an empty tensor for the “skipped” layers. This would be a scenario where you were generating the first token of a response (len(self.key_cache)
is 0
) but starting with layer_idx
of 2
.
elif not self.key_cache[layer_idx].numel()
If a layer was skipped and it has an empty tensor as its key_cache, this condition is triggered and key_states is simply assigned to that layer’s key_cache.
else
I think this is the most common case, used for autoregressive next-token generation. The key_cache contains a non-empty value for this layer, so it concatenates the current value with the new key_states. In this way, the key_cache for this layer grows over the course of next-token generation. Specifically, its second-to-last dimension (sequence length) increases by 1 for each token processed.
return self.key_cache[layer_idx], self.value_cache[layer_idx]
Finally, the concatenated key_cache
and value_cache
for the given layer are returned. The update
step is complete.
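To make that growth concrete, here’s a small standalone sketch (my own, not from the transformers source) that calls DynamicCache.update directly with toy tensors shaped like SmolLM2-135M’s KV states. The printed shapes are what I’d expect from the append-then-concatenate logic above.

```python
# A minimal sketch (not from the original notebook) exercising DynamicCache.update
# directly with toy tensors, assuming the transformers DynamicCache API shown above.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()

# "Prefill": 3 prompt tokens for layer 0 -> the branch that appends to the cache.
k = torch.randn(1, 3, 3, 64)  # (batch_size, num_kv_heads, seq_len, head_dim)
v = torch.randn(1, 3, 3, 64)
k_out, v_out = cache.update(k, v, layer_idx=0)
print(k_out.shape)  # torch.Size([1, 3, 3, 64])

# "Decode": 1 new token -> the else branch concatenates along dim=-2.
k_new = torch.randn(1, 3, 1, 64)
v_new = torch.randn(1, 3, 1, 64)
k_out, v_out = cache.update(k_new, v_new, layer_idx=0)
print(k_out.shape)  # torch.Size([1, 3, 4, 64])
```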
key_states
and value_states
after the past_key_values.update
step are passed onto the attention_interface
which we’ll look at later in this blog post.
attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    dropout=0.0 if not self.training else self.attention_dropout,
    scaling=self.scaling,
    **kwargs,
)
Visualizing the DynamicCache.update
To see how the cache update takes place during autoregressive language generation, I’ll monkey-patch a debug_update
method.
Show debug_update
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
def debug_update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states to cache.
        value_states (`torch.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        cache_kwargs (`Dict[str, Any]`, `optional`):
            Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

    Return:
        A tuple containing the updated key and value states.
    """
    # Update the number of seen tokens
    if layer_idx == 0:
        self._seen_tokens += key_states.shape[-2]

    # Update the cache
    if key_states is not None:
        if len(self.key_cache) <= layer_idx:
            print(f"DEBUG: initializing cache for layer_idx {layer_idx}")
            # There may be skipped layers, fill them with empty lists
            for _ in range(len(self.key_cache), layer_idx):
                self.key_cache.append(torch.tensor([]))
                self.value_cache.append(torch.tensor([]))
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        elif (
            not self.key_cache[layer_idx].numel()  # prefers not t.numel() to len(t) == 0 to export the model
        ):  # fills previously skipped layers; checking for tensor causes errors
            print(f"DEBUG: filling empty cache for layer_idx {layer_idx}")
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            print(f"DEBUG: updating/concatenating cache for layer_idx {layer_idx}")
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

    return self.key_cache[layer_idx], self.value_cache[layer_idx]
from transformers.cache_utils import DynamicCache

if 'ORIGINAL_DYNAMIC_CACHE_UPDATE' not in globals():
    ORIGINAL_DYNAMIC_CACHE_UPDATE = DynamicCache.update
    print("Stored original DynamicCache.update.")

DynamicCache.update = debug_update
= "The quick brown"
prompt = tokenizer(prompt, return_tensors="pt").to("cuda").values()
input_ids, attention_mask = model.generate(input_ids, pad_token_id=tokenizer.eos_token_id, do_sample=False, return_dict_in_generate=True, max_new_tokens=2) outputs
DEBUG: initializing cache for layer_idx 0
DEBUG: initializing cache for layer_idx 1
DEBUG: initializing cache for layer_idx 2
DEBUG: initializing cache for layer_idx 3
DEBUG: initializing cache for layer_idx 4
DEBUG: initializing cache for layer_idx 5
DEBUG: initializing cache for layer_idx 6
DEBUG: initializing cache for layer_idx 7
DEBUG: initializing cache for layer_idx 8
DEBUG: initializing cache for layer_idx 9
DEBUG: initializing cache for layer_idx 10
DEBUG: initializing cache for layer_idx 11
DEBUG: initializing cache for layer_idx 12
DEBUG: initializing cache for layer_idx 13
DEBUG: initializing cache for layer_idx 14
DEBUG: initializing cache for layer_idx 15
DEBUG: initializing cache for layer_idx 16
DEBUG: initializing cache for layer_idx 17
DEBUG: initializing cache for layer_idx 18
DEBUG: initializing cache for layer_idx 19
DEBUG: initializing cache for layer_idx 20
DEBUG: initializing cache for layer_idx 21
DEBUG: initializing cache for layer_idx 22
DEBUG: initializing cache for layer_idx 23
DEBUG: initializing cache for layer_idx 24
DEBUG: initializing cache for layer_idx 25
DEBUG: initializing cache for layer_idx 26
DEBUG: initializing cache for layer_idx 27
DEBUG: initializing cache for layer_idx 28
DEBUG: initializing cache for layer_idx 29
DEBUG: updating/concatenating cache for layer_idx 0
DEBUG: updating/concatenating cache for layer_idx 1
DEBUG: updating/concatenating cache for layer_idx 2
DEBUG: updating/concatenating cache for layer_idx 3
DEBUG: updating/concatenating cache for layer_idx 4
DEBUG: updating/concatenating cache for layer_idx 5
DEBUG: updating/concatenating cache for layer_idx 6
DEBUG: updating/concatenating cache for layer_idx 7
DEBUG: updating/concatenating cache for layer_idx 8
DEBUG: updating/concatenating cache for layer_idx 9
DEBUG: updating/concatenating cache for layer_idx 10
DEBUG: updating/concatenating cache for layer_idx 11
DEBUG: updating/concatenating cache for layer_idx 12
DEBUG: updating/concatenating cache for layer_idx 13
DEBUG: updating/concatenating cache for layer_idx 14
DEBUG: updating/concatenating cache for layer_idx 15
DEBUG: updating/concatenating cache for layer_idx 16
DEBUG: updating/concatenating cache for layer_idx 17
DEBUG: updating/concatenating cache for layer_idx 18
DEBUG: updating/concatenating cache for layer_idx 19
DEBUG: updating/concatenating cache for layer_idx 20
DEBUG: updating/concatenating cache for layer_idx 21
DEBUG: updating/concatenating cache for layer_idx 22
DEBUG: updating/concatenating cache for layer_idx 23
DEBUG: updating/concatenating cache for layer_idx 24
DEBUG: updating/concatenating cache for layer_idx 25
DEBUG: updating/concatenating cache for layer_idx 26
DEBUG: updating/concatenating cache for layer_idx 27
DEBUG: updating/concatenating cache for layer_idx 28
DEBUG: updating/concatenating cache for layer_idx 29
As we can see from the printed output, for the first generated token, update initializes the cache with self.key_cache.append(key_states) and self.value_cache.append(value_states). For the subsequent tokens, it updates the cache with torch.cat.
I’ll re-assign the original update to DynamicCache to avoid cluttering the output with print statements.
DynamicCache.update = ORIGINAL_DYNAMIC_CACHE_UPDATE
Inspecting past_key_values
During model.generate
With an understanding of how KV cache is updated, I’ll now turn my attention to the key and value cache contents during autoregressive generation.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16).to("cuda")
= "The quick brown"
prompt prompt
'The quick brown'
input_ids, attention_mask = tokenizer(prompt, return_tensors="pt").to("cuda").values()
input_ids.shape, attention_mask.shape
(torch.Size([1, 3]), torch.Size([1, 3]))
By setting return_dict_in_generate=True
we can retrieve past_key_values
.
outputs = model.generate(input_ids, pad_token_id=tokenizer.eos_token_id, do_sample=False, return_dict_in_generate=True, max_new_tokens=5)
outputs
GenerateDecoderOnlyOutput(sequences=tensor([[ 504, 2365, 6354, 16438, 27003, 690, 260, 23790]],
device='cuda:0'), scores=None, logits=None, attentions=None, hidden_states=None, past_key_values=<transformers.cache_utils.DynamicCache object at 0x78c4f80c4f50>)
tokenizer.decode(outputs.sequences[0])
'The quick brown fox jumps over the lazy'
We have 8 total tokens—3 from the original prompt and 5 new tokens generated.
outputs.sequences.shape
torch.Size([1, 8])
Inspecting the values in the KV cache: there are 30 items in key_cache
and value_cache
, corresponding to the 30 layers in the model. For the last generated token (the 8th token) there were 7
seen_tokens
.
len(outputs.past_key_values.key_cache), len(outputs.past_key_values.value_cache)
(30, 30)
outputs.past_key_values.seen_tokens
7
The key_cache tensors all have the same shape: (batch_size, num_key_value_heads, seen_tokens, head_dim).
for k in outputs.past_key_values.key_cache: print(k.shape)
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
model.config.num_attention_heads, \
model.config.num_hidden_layers, \
model.config.num_key_value_heads
(9, 30, 3)
model.model.layers[0].self_attn.k_proj.out_features
192
3*64
192
The value_cache
is similarly structured.
for v in outputs.past_key_values.value_cache: print(v.shape)
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 7, 64])
init_k = outputs.past_key_values.key_cache[0]
init_k.shape
torch.Size([1, 3, 7, 64])
While the shapes of the key_cache
tensors across layers are the same, their contents are not. This is because each layer has its own self attention module with its own k_proj
and v_proj
layers with their own learned weights.
for k in outputs.past_key_values.key_cache[1:]: assert torch.allclose(init_k, k)
AssertionError:
Inspecting Intermediate Key/Value Cache Tensors During Generation
Now to understand how KV cache is used during generation: I want to inspect the shape of the key and value cache tensors as the prompt increases by one token at a time.
To achieve this, I’ll add a hook to the first layer’s self attention module’s forward pass using register_forward_hook
. I came to an incorrect conclusion in a previous video and blog post that you can’t use register_forward_hook
for the Llama attention module because it doesn’t capture keyword arguments. What I didn’t realize is that you can capture kwargs with register_forward_hook
by setting with_kwargs=True
, which I have done below.
I wrapped hook_fn
in create_hook_fn
because I wanted to print out the count
of total generated tokens.
def create_hook_fn():
    count = 1
    def hook_fn(module, args, kwargs, output):
        nonlocal count
        print(count, kwargs['past_key_value'].key_cache[0].shape)
        count += 1
    return hook_fn
_hook_fn = create_hook_fn()
attn_layer = model.model.layers[0].self_attn
hook_handle = attn_layer.register_forward_hook(_hook_fn, with_kwargs=True)

outputs = model.generate(input_ids, pad_token_id=tokenizer.eos_token_id, do_sample=False, return_dict_in_generate=True, max_new_tokens=5)
hook_handle.remove()
1 torch.Size([1, 3, 3, 64])
2 torch.Size([1, 3, 4, 64])
3 torch.Size([1, 3, 5, 64])
4 torch.Size([1, 3, 6, 64])
5 torch.Size([1, 3, 7, 64])
Let’s parse this output:

- The first new token generated sees only the 3 tokens in the prompt. The KV cache subsequently has a third dimension of 3.
- Each new token generated sees one more token, so the third dimension (seen tokens) of key_cache and value_cache increases by 1.
I’ll slightly modify hook_fn
so it prints out the first few shapes of key_cache
, allowing us to see what all layers’ cache is storing from the perspective of layer_idx=0
.
def create_hook_fn():
    count = 1
    def hook_fn(module, args, kwargs, output):
        nonlocal count
        print(count)
        for k in kwargs['past_key_value'].key_cache[:5]: print(k.shape)
        count += 1
    return hook_fn
_hook_fn = create_hook_fn()
attn_layer = model.model.layers[0].self_attn
hook_handle = attn_layer.register_forward_hook(_hook_fn, with_kwargs=True)

outputs = model.generate(input_ids, pad_token_id=tokenizer.eos_token_id, do_sample=False, return_dict_in_generate=True, max_new_tokens=5)
hook_handle.remove()
1
torch.Size([1, 3, 3, 64])
2
torch.Size([1, 3, 4, 64])
torch.Size([1, 3, 3, 64])
torch.Size([1, 3, 3, 64])
torch.Size([1, 3, 3, 64])
torch.Size([1, 3, 3, 64])
3
torch.Size([1, 3, 5, 64])
torch.Size([1, 3, 4, 64])
torch.Size([1, 3, 4, 64])
torch.Size([1, 3, 4, 64])
torch.Size([1, 3, 4, 64])
4
torch.Size([1, 3, 6, 64])
torch.Size([1, 3, 5, 64])
torch.Size([1, 3, 5, 64])
torch.Size([1, 3, 5, 64])
torch.Size([1, 3, 5, 64])
5
torch.Size([1, 3, 7, 64])
torch.Size([1, 3, 6, 64])
torch.Size([1, 3, 6, 64])
torch.Size([1, 3, 6, 64])
torch.Size([1, 3, 6, 64])
Since we are capturing the key_cache shapes from the first layer (layer_idx=0), the subsequent layers’ cache tensors are 1 token “behind”, since the new token’s hidden states have not yet passed through those layers.
Ultimately, I want to tie this all back to the flash_attn_varlen_func docstring’s causal mask example, so I’ll take a look at the query_states shape, copying code from the LlamaAttention forward pass. I’ll also inspect the length of the key_cache and its shape, and the shape of the value_cache.
def create_hook_fn():
    count = 1
    def hook_fn(module, args, kwargs, output):
        nonlocal count
        input_shape = kwargs['hidden_states'].shape[:-1]
        hidden_shape = (*input_shape, -1, module.head_dim)
        query_states = module.q_proj(kwargs['hidden_states']).view(hidden_shape).transpose(1, 2)
        print(count, f"len(past_key_value): {len(kwargs['past_key_value'].key_cache)},", f"query_states.shape: {query_states.shape},", f"k.shape: {kwargs['past_key_value'].key_cache[0].shape},", f"v.shape: {kwargs['past_key_value'].value_cache[0].shape}")
        count += 1
    return hook_fn
_hook_fn = create_hook_fn()
attn_layer = model.model.layers[0].self_attn
hook_handle = attn_layer.register_forward_hook(_hook_fn, with_kwargs=True)

outputs = model.generate(input_ids, pad_token_id=tokenizer.eos_token_id, do_sample=False, return_dict_in_generate=True, max_new_tokens=5)
hook_handle.remove()
1 len(past_key_value): 1, query_states.shape: torch.Size([1, 9, 3, 64]), k.shape: torch.Size([1, 3, 3, 64]), v.shape: torch.Size([1, 3, 3, 64])
2 len(past_key_value): 30, query_states.shape: torch.Size([1, 9, 1, 64]), k.shape: torch.Size([1, 3, 4, 64]), v.shape: torch.Size([1, 3, 4, 64])
3 len(past_key_value): 30, query_states.shape: torch.Size([1, 9, 1, 64]), k.shape: torch.Size([1, 3, 5, 64]), v.shape: torch.Size([1, 3, 5, 64])
4 len(past_key_value): 30, query_states.shape: torch.Size([1, 9, 1, 64]), k.shape: torch.Size([1, 3, 6, 64]), v.shape: torch.Size([1, 3, 6, 64])
5 len(past_key_value): 30, query_states.shape: torch.Size([1, 9, 1, 64]), k.shape: torch.Size([1, 3, 7, 64]), v.shape: torch.Size([1, 3, 7, 64])
We see that there are 9 query heads and 3 KV heads. The total output dimension of the K and V projections is 3 x 64 = 192, while the Q projection outputs 9 x 64 = 576.
When the first token is being generated, the length of the key_cache for layer_idx=0 is 1, because this is the first attention module’s first forward pass. For subsequent tokens (2, 3, 4, 5) the length of the key_cache is 30, as the cache has been instantiated for all 30 layers after the first token is generated.
Finally, we see that the key_cache
and value_cache
shapes are equal, as expected.
Which Flash Attention Interface is Used?
Since this exercise is part of my journey to understand the flash_attn_varlen_func
, I was curious to confirm by visual inspection which Flash Attention interface function was being used. To achieve this, I wrote a “debug” version for the following three functions:
How did I know which functions to modify? Well, largely because I have done this exercise before when I was trying to understand what triggered the use of flash_attn_varlen_func
.
More concretely, I first inspected the forward pass of the attention module:
from inspect import getsource
print(getsource(model.model.layers[0].self_attn.forward))
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: Tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, self.head_dim)
query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
logger.warning_once(
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
)
else:
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
attn_output, attn_weights = attention_interface(
self,
query_states,
key_states,
value_states,
attention_mask,
dropout=0.0 if not self.training else self.attention_dropout,
scaling=self.scaling,
**kwargs,
)
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights
In there I saw the following lines of interest:
if self.config._attn_implementation != "eager":
    if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
        logger.warning_once(
            "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
            'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
        )
    else:
        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
In our case, the else block would trigger and ALL_ATTENTION_FUNCTIONS would be accessed. Looking at that constant directly, we can see that for our model’s _attn_implementation ('flash_attention_2') the attention interface function is flash_attention_forward.
model.config._attn_implementation
'flash_attention_2'
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS[model.config._attn_implementation]
transformers.integrations.flash_attention.flash_attention_forward
def flash_attention_forward(module: torch.nn.Module, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], dropout: float=0.0, scaling: Optional[float]=None, sliding_window: Optional[int]=None, softcap: Optional[float]=None, **kwargs) -> Tuple[torch.Tensor, None]
<no docstring>
Show _debug_flash_attention_forward
from typing import Optional, Tuple
import inspect
from flash_attn import flash_attn_func, flash_attn_varlen_func
import torch
import os
from transformers.modeling_flash_attention_utils import is_flash_attn_greater_or_equal, fa_peft_integration_check, _upad_input, pad_input, prepare_fa2_from_position_ids
= "window_size" in list(inspect.signature(flash_attn_func).parameters)
_flash_supports_window_size = is_flash_attn_greater_or_equal("2.4.1")
flash_241 = None
deterministic_g
def _debug_flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
    cu_seq_lens_q: Optional[torch.LongTensor] = None,
    cu_seq_lens_k: Optional[torch.LongTensor] = None,
    max_length_q: Optional[int] = None,
    max_length_k: Optional[int] = None,
    target_dtype: Optional[torch.dtype] = None,
    **kwargs,
):
    if not use_top_left_mask:
        causal = is_causal
    else:
        # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1.
        causal = is_causal and query_length != 1

    # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

    if flash_241:
        if deterministic is None:
            global deterministic_g
            if deterministic_g is None:
                deterministic_g = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
            deterministic = deterministic_g
        flash_kwargs["deterministic"] = deterministic

    if softcap is not None:
        flash_kwargs["softcap"] = softcap

    # PEFT possibly silently casts tensors to fp32, this potentially reconverts to correct dtype or is a no op
    query_states, key_states, value_states = fa_peft_integration_check(
        query_states, key_states, value_states, target_dtype
    )

    # Contains at least one padding token in the sequence
    if attention_mask is not None:
        batch_size = query_states.shape[0]
        query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
            query_states, key_states, value_states, attention_mask, query_length
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        print("if attention_mask is not None: flash_attn_varlen_func is being used")
        attn_output_unpad = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )
        attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

    # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
    # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
    # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
    elif position_ids is not None and (
        max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
    ):
        batch_size = query_states.size(0)

        if cu_seq_lens_q is None or cu_seq_lens_k is None:
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = (
                prepare_fa2_from_position_ids(query_states, key_states, value_states, position_ids)
            )

            cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
            max_length_q, max_length_k = max_seq_lens

        else:
            query_states = query_states.reshape(-1, query_states.size(-2), query_states.size(-1))
            key_states = key_states.reshape(-1, key_states.size(-2), key_states.size(-1))
            value_states = value_states.reshape(-1, value_states.size(-2), value_states.size(-1))

        print("position_ids is not None: flash_attn_varlen_func is being used")
        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seq_lens_q,
            cu_seqlens_k=cu_seq_lens_k,
            max_seqlen_q=max_length_q,
            max_seqlen_k=max_length_k,
            dropout_p=dropout,
            softmax_scale=softmax_scale,
            causal=causal,
            **flash_kwargs,
        )

        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

    else:
        print("flash_attn_func is being used")
        attn_output = flash_attn_func(
            query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
        )

    return attn_output
Show debug_flash_attention_forward
from typing import Optional, Tuple
import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
_use_top_left_mask = flash_attn_supports_top_left_mask()
def debug_flash_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    sliding_window: Optional[int] = None,
    softcap: Optional[float] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None]:
    if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None:
        print(
            "`flash_attention_2` does not support `output_attentions=True` or `head_mask`."
            " Please set your attention to `eager` if you want any of these features."
        )

    # This is before the transpose
    seq_len = query.shape[2]

    # FA2 uses non-transposed inputs
    query = query.transpose(1, 2)
    key = key.transpose(1, 2)
    value = value.transpose(1, 2)

    # In PEFT, usually we cast the layer norms in float32 for training stability reasons
    # therefore the input hidden states gets silently casted in float32. Hence, we need
    # cast them back in the correct dtype just to be sure everything works as expected.
    # This might slowdown training & inference so it is recommended to not cast the LayerNorms
    # in fp32. (usually our RMSNorm modules handle it correctly)
    target_dtype = None
    if query.dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        # Handle the case where the model is quantized
        elif hasattr(module.config, "_pre_quantization_dtype"):
            target_dtype = module.config._pre_quantization_dtype
        else:
            target_dtype = next(layer for layer in module.modules() if isinstance(layer, torch.nn.Linear)).weight.dtype

    # FA2 always relies on the value set in the module, so remove it if present in kwargs to avoid passing it twice
    kwargs.pop("is_causal", None)

    print("DEBUG: calling _flash_attention_forward")
    attn_output = _debug_flash_attention_forward(
        query,
        key,
        value,
        attention_mask,
        query_length=seq_len,
        is_causal=module.is_causal,
        dropout=dropout,
        softmax_scale=scaling,
        sliding_window=sliding_window,
        softcap=softcap,
        use_top_left_mask=_use_top_left_mask,
        target_dtype=target_dtype,
        **kwargs,
    )

    return attn_output, None
Show debug_forward
from typing import Callable, Optional, Tuple, Union
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.cache_utils import Cache, DynamicCache
from transformers.processing_utils import Unpack
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.llama.modeling_llama import eager_attention_forward
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
def debug_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    input_shape = hidden_states.shape[:-1]
    hidden_shape = (*input_shape, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        # sin and cos are specific to RoPE models; cache_position needed for the static cache
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

    attention_interface: Callable = eager_attention_forward

    if self.config._attn_implementation != "eager":
        if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
            print(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
                'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
        else:
            # attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
            attention_interface = debug_flash_attention_forward

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        **kwargs,
    )

    attn_output = attn_output.reshape(*input_shape, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M", attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16).to("cuda")
types.MethodType binds a function (debug_forward) as a method on an instance (attn_layer_instance).
import types
types.MethodType??
Init signature: types.MethodType(self, /, *args, **kwargs)
Docstring: Create a bound instance method object.
Type: type
Subclasses:
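As a tiny illustration (my own toy example, unrelated to transformers), MethodType attaches a plain function to a single instance so that self is passed automatically:

```python
# Toy example of types.MethodType: bind a plain function to one instance.
import types

class Greeter:
    def __init__(self, name):
        self.name = name

def shout(self):
    return self.name.upper()

g = Greeter("flash")
g.shout = types.MethodType(shout, g)  # bound to this instance only
print(g.shout())  # FLASH
```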
attn_layer_instance = model.model.layers[0].self_attn
original_layer_forward = attn_layer_instance.forward
attn_layer_instance.forward = types.MethodType(debug_forward, attn_layer_instance)
def create_hook_fn():
    count = 1
    def hook_fn(module, args, kwargs, output):
        nonlocal count
        input_shape = kwargs['hidden_states'].shape[:-1]
        hidden_shape = (*input_shape, -1, module.head_dim)
        query_states = module.q_proj(kwargs['hidden_states']).view(hidden_shape).transpose(1, 2)
        print(count, len(kwargs['past_key_value'].key_cache), f"query_states.shape: {query_states.shape}", f"k.shape: {kwargs['past_key_value'].key_cache[0].shape}")
        # for k in kwargs['past_key_value'].key_cache: print(k.shape)  # do this for v as well
        count += 1
    return hook_fn
_hook_fn = create_hook_fn()
attn_layer = model.model.layers[0].self_attn
hook_handle = attn_layer.register_forward_hook(_hook_fn, with_kwargs=True)

outputs = model.generate(input_ids, pad_token_id=tokenizer.eos_token_id, do_sample=False, return_dict_in_generate=True, max_new_tokens=5)
print(outputs.sequences[0].shape)
hook_handle.remove()
DEBUG: calling _flash_attention_forward
flash_attn_func is being used
1 1 query_states.shape: torch.Size([1, 9, 3, 64]) k.shape: torch.Size([1, 3, 3, 64])
DEBUG: calling _flash_attention_forward
flash_attn_func is being used
2 30 query_states.shape: torch.Size([1, 9, 1, 64]) k.shape: torch.Size([1, 3, 4, 64])
DEBUG: calling _flash_attention_forward
flash_attn_func is being used
3 30 query_states.shape: torch.Size([1, 9, 1, 64]) k.shape: torch.Size([1, 3, 5, 64])
DEBUG: calling _flash_attention_forward
flash_attn_func is being used
4 30 query_states.shape: torch.Size([1, 9, 1, 64]) k.shape: torch.Size([1, 3, 6, 64])
DEBUG: calling _flash_attention_forward
flash_attn_func is being used
5 30 query_states.shape: torch.Size([1, 9, 1, 64]) k.shape: torch.Size([1, 3, 7, 64])
torch.Size([8])
From the print statements in my _debug_flash_attention_forward function, I can see that flash_attn_func, the non-variable-length interface, is being used for this generation. That makes sense: with a single sequence in the batch and no padding, the attention_mask passed down is None and the position_ids condition isn’t met, so the else branch calls flash_attn_func.
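To double-check my understanding of that code path, here’s a standalone sketch (my own, not part of the original notebook) that calls flash_attn_func directly with decode-like shapes: a single query token attending to 7 cached key/value positions. flash_attn_func expects (batch, seqlen, nheads, headdim) tensors in fp16/bf16 on CUDA, and supports GQA when the number of query heads is a multiple of the number of KV heads.

```python
# A standalone sketch (my own) mirroring the decode step observed above.
import torch
from flash_attn import flash_attn_func

q = torch.randn(1, 1, 9, 64, device="cuda", dtype=torch.bfloat16)  # seqlen_q = 1, 9 query heads
k = torch.randn(1, 7, 3, 64, device="cuda", dtype=torch.bfloat16)  # seqlen_k = 7, 3 KV heads (GQA)
v = torch.randn(1, 7, 3, 64, device="cuda", dtype=torch.bfloat16)

# With causal=True and seqlen_q=1, bottom-right alignment means the single query
# row sees all 7 keys -- exactly what the next-token prediction step needs.
out = flash_attn_func(q, k, v, causal=True)
print(out.shape)  # torch.Size([1, 1, 9, 64])
```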
Understanding the flash_attn_varlen_func
Causal Mask Docstring
A quick recap of what we’ve learned so far:
- HuggingFace’s model.generate uses KV cache by default (DynamicCache), stored as past_key_values.
- For most scenarios, the DynamicCache is updated by concatenating the existing key_cache and value_cache with the key_states and value_states generated for the current new token.
- As the next token is generated for a given prompt, query_states has a sequence length of 1, whereas the key_cache and value_cache tensors’ sequence dimension increases by 1. This relates directly to the flash_attn_varlen_func causal mask docstring example.
- model.generate utilized the flash_attn_func interface.
Let’s look at the flash_attn_varlen_func
docstring snippet again:
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
I’ll annotate the causal mask examples a bit:
seqlen_q=2 and seqlen_k=5
|     | k_0 | k_1 | k_2 | k_3 | k_4 |
| --- | --- | --- | --- | --- | --- |
| q_0 | 1 | 1 | 1 | 1 | 0 |
| q_1 | 1 | 1 | 1 | 1 | 1 |
The final query token (q_1
) sees all 5 key tokens. The first query token (q_0
) only sees the first four key tokens.
seqlen_q=5 and seqlen_k=2
|     | k_0 | k_1 |
| --- | --- | --- |
| q_0 | 0 | 0 |
| q_1 | 0 | 0 |
| q_2 | 0 | 0 |
| q_3 | 1 | 0 |
| q_4 | 1 | 1 |
Again, the final query token (q_4
) sees all key tokens. As a consequence, since there are only two key tokens, the first three query tokens do not see any key tokens.
In each example, we are offsetting the shorter sequence so that its last token aligns with the other sequence’s last token. This is what the flash_attn_varlen_func docstring means by
the causal mask is aligned to the bottom right corner of the attention matrix
In the first case, the query sequence is shorter so we offset it by 3 positions to align with the last two tokens of the key sequence. The “offset” positions are 1s (this satisfies the rule of causality j <= i
, query tokens can look back). In the second case, the key sequence is shorter so we offset it by 3 positions to align with the last two tokens of the query sequence. The offset positions are 0s (again, this satisfies causality, the query tokens have nothing to look back to).
The model.generate
examples above are like the first case, where there are more key positions than query positions. The query token (the next-token being predicted) can look back at all key tokens.
A Math-y Way to think About It
For those of you who like to think through things with math.
Causality (in language modeling) means that a query token vector at the i-th position can only see its own and previous tokens’ key vectors. Having different sequence lengths for Q and K (5 and 2 or 2 and 5 in the flash_attn_varlen_func docstring example, or 1 and 3-7 in my inspections above) requires you to pick how Q and K are aligned. In the case of flash_attn_varlen_func, they choose to align Q and K such that the last Q token vector is aligned with the last K token vector. This becomes our “present moment” P, with causality allowing access to previous tokens only.
Let’s define i
as the position of query tokens and j
as the position of key tokens. Causality is defined as token pairs that follow the inequality: j <= i + (seqlen_k - seqlen_q)
.
For the first causal mask example:
| j | i | j <= i + (seqlen_k - seqlen_q) |
| --- | --- | --- |
| 0 | 0 | 0 <= 0 + 3 (True) |
| 1 | 0 | 1 <= 0 + 3 (True) |
| 2 | 0 | 2 <= 0 + 3 (True) |
| 3 | 0 | 3 <= 0 + 3 (True) |
| 4 | 0 | 4 <= 0 + 3 (False) |
| 0 | 1 | 0 <= 1 + 3 (True) |
| 1 | 1 | 1 <= 1 + 3 (True) |
| 2 | 1 | 2 <= 1 + 3 (True) |
| 3 | 1 | 3 <= 1 + 3 (True) |
| 4 | 1 | 4 <= 1 + 3 (True) |
Where does j <= i + (seqlen_k - seqlen_q)
come from?
Let q_i
be a query that is seqlen_q - 1 - i
steps before the end of the query sequence, and k_j
be a key that is seqlen_k - 1 - j
steps before the end of the key sequence. More concretely, for the example where seqlen_q = 2
and seqlen_k=5
:
| q_i | Steps before end | seqlen_q - 1 - i |
| --- | --- | --- |
| q_0 | 1 | 2 - 1 - 0 |
| q_1 | 0 | 2 - 1 - 1 |
| k_j | Steps before end | seqlen_k - 1 - j |
| --- | --- | --- |
| k_0 | 4 | 5 - 1 - 0 |
| k_1 | 3 | 5 - 1 - 1 |
| k_2 | 2 | 5 - 1 - 2 |
| k_3 | 1 | 5 - 1 - 3 |
| k_4 | 0 | 5 - 1 - 4 |
By picking a “present moment” P (the last token in each sequence), we get a unified timeline p such that causality is defined as p_j <= p_i. k_j has a position on the timeline p_j = P - (seqlen_k - 1 - j) and q_i has a position on the timeline p_i = P - (seqlen_q - 1 - i). Causality requires that p_j <= p_i on our “unified timeline”. Writing that out:
P - (seqlen_k - 1 - j) <= P - (seqlen_q - 1 - i)
Cancelling out the P
s and distributing the minus sign:
-seqlen_k + 1 + j <= -seqlen_q + 1 + i
Isolating j
on the lefthand side:
j <= -seqlen_q + 1 + i + seqlen_k - 1
Simplifying + reordering:
j <= i + (seqlen_k - seqlen_q)
We can think of (seqlen_k - seqlen_q) as an “offset” term between the two sequences.
Looking at this concretely for the second causal mask:
|     | k_0 | k_1 |
| --- | --- | --- |
| q_0 | 0 | 0 |
| q_1 | 0 | 0 |
| q_2 | 0 | 0 |
| q_3 | 1 | 0 |
| q_4 | 1 | 1 |
| j | i | j <= i + (seqlen_k - seqlen_q) |
| --- | --- | --- |
| 0 | 0 | 0 <= 0 - 3 (False) |
| 0 | 1 | 0 <= 1 - 3 (False) |
| 0 | 2 | 0 <= 2 - 3 (False) |
| 0 | 3 | 0 <= 3 - 3 (True) |
| 0 | 4 | 0 <= 4 - 3 (True) |
| 1 | 0 | 1 <= 0 - 3 (False) |
| 1 | 1 | 1 <= 1 - 3 (False) |
| 1 | 2 | 1 <= 2 - 3 (False) |
| 1 | 3 | 1 <= 3 - 3 (False) |
| 1 | 4 | 1 <= 4 - 3 (True) |
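As a sanity check, here’s a small helper (my own sketch, not from the flash-attn repo) that builds the bottom-right aligned causal mask directly from the inequality j <= i + (seqlen_k - seqlen_q) and reproduces both docstring examples, plus the decode case (seqlen_q=1) from my model.generate inspections:

```python
import torch

def bottom_right_causal_mask(seqlen_q: int, seqlen_k: int) -> torch.Tensor:
    # 1 = keep, 0 = masked out, following the flash_attn_varlen_func docstring convention.
    i = torch.arange(seqlen_q).unsqueeze(1)  # query positions, shape (seqlen_q, 1)
    j = torch.arange(seqlen_k).unsqueeze(0)  # key positions, shape (1, seqlen_k)
    return (j <= i + (seqlen_k - seqlen_q)).int()

print(bottom_right_causal_mask(2, 5))
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]], dtype=torch.int32)

print(bottom_right_causal_mask(5, 2))
# tensor([[0, 0],
#         [0, 0],
#         [0, 0],
#         [1, 0],
#         [1, 1]], dtype=torch.int32)

print(bottom_right_causal_mask(1, 7))  # decode: the single query row sees every key
# tensor([[1, 1, 1, 1, 1, 1, 1]], dtype=torch.int32)
```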
Closing Thoughts
Understanding flash_attn_varlen_func
is going to require a sequence (pun intended) of such deep dives. It took me hours just to get through the docstring!! I’m also working on understanding ModernBERT’s sequence packing implementation (to the point of explaining it with visuals) and I expect it to interweave with my Flash Attention study, especially when understanding how ModernBERT prepares and packs sequences and related artifacts before passing them through the attention mechanism via flash_attn_varlen_func. It’s an exciting one-two punch for sure! I’m glad I’m working on them together.
I’ll end with listing out again what I’ve learned in this notebook/exercise, with a couple points added about the causal mask:
- HuggingFace’s model.generate uses KV cache by default (DynamicCache), stored as past_key_values.
- For most scenarios, the DynamicCache is updated by concatenating the existing key_cache and value_cache with the key_states and value_states generated for the current new token.
- As the next token is generated for a given prompt, query_states has a sequence length of 1, whereas the key_cache and value_cache tensors’ sequence dimension increases by 1. This relates directly to the flash_attn_varlen_func causal mask docstring example.
- model.generate utilized the flash_attn_func interface.
- The causal mask is aligned to the bottom-right of the attention matrix (the last tokens of the Q and K sequences are aligned).
- Causality, when the \(Q\) and \(K\) sequences are of different lengths, is satisfied by the inequality j <= i + (seqlen_k - seqlen_q).
- When there are more query tokens than key tokens, the “offset” (needed to align the last token of each sequence) results in 0s in the mask, as there are no key tokens to “look back at”.
- When there are more key tokens than query tokens, the “offset” results in 1s, as the query tokens can look back at more key tokens.
I’m trying to grow my YouTube channel this year so if you enjoyed this blog post, please subscribe!