A User's Guide to FlexAttention in FlashAttention CuTe DSL

🤔 Curiosity: Can We Unify All Attention Variants Under One Framework?

What if we could implement causal attention, sliding window attention, ALiBi, and dozens of other attention variants using a single, elegant API? What if this unified framework could achieve 95% of FlashAttention 3’s performance while being flexible enough to support novel attention patterns?

[Figure: causal mask visualization]

Attention mechanisms have evolved into many variants, each addressing a specific need: causal attention for autoregressive language modeling, sliding window attention for long-context efficiency, ALiBi for better length extrapolation, and many others. The PyTorch team at Meta recognized that most of these variants can be unified under one elegant framework called FlexAttention.

The question: How does FlexAttention work, and how can developers use it to implement custom attention variants efficiently?

As someone who’s worked with attention mechanisms in production, I’ve seen how different variants require different implementations. FlexAttention offers a unified approach that’s both flexible and performant.


📚 Retrieve: Understanding FlexAttention

The FlexAttention Framework

Core Concept:

FlexAttention adds two options for customization:

  1. `score_mod`: A callable that modifies pre-softmax attention scores
  2. `mask_mod`: A callable that masks out pre-softmax attention scores

The unified formula is:

\[\text{FlexAttention}(Q, K, V) = \text{Softmax}\left(\text{mask\_mod}\left(\text{score\_mod}\left(QK^T\right)\right)\right) V\]

Key Insight:

`mask_mod` is a special case of `score_mod` where masked scores are set to `-inf`. They're kept separate for efficiency reasons, especially when dealing with block sparsity.
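
To make the relationship concrete, here is a hypothetical plain-Python illustration (not the library API) of a causal mask written as a `score_mod`:

```python
# Masking as a score modification: disallowed positions simply get a score of -inf,
# so they vanish after the softmax.
def causal_as_score_mod(score, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    return score if kv_idx <= q_idx else float("-inf")
```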

Performance:

The original FlexAttention implementation in Triton achieved ~90% of FlashAttention 2 performance on Ampere GPUs, but performance on Hopper was significantly worse compared to FlashAttention 3.

The CuTe DSL Implementation:

The implementation in FlashAttention 3 CuTe DSL, done in collaboration with Driss Guessous (Meta) and Tri Dao (Princeton; Together AI), achieves 95% of FlashAttention 3’s performance in the forward pass. This is a roughly 50% speedup over the Triton version in most cases.

```mermaid
graph TB
    subgraph Input["Input"]
        Q[Query Q]
        K[Key K]
        V[Value V]
    end

    subgraph FlexAttention["FlexAttention Pipeline"]
        QK[QK^T]
        SM[score_mod]
        MM[mask_mod]
        S[Softmax]
        OV[Output × V]
    end

    subgraph Output["Output"]
        OUT[Attention Output]
    end

    Q --> QK
    K --> QK
    QK --> SM
    SM --> MM
    MM --> S
    S --> OV
    V --> OV
    OV --> OUT

    style SM fill:#ff6b6b,stroke:#c92a2a,stroke-width:2px,color:#fff
    style MM fill:#4ecdc4,stroke:#0a9396,stroke-width:2px,color:#fff
    style S fill:#ffe66d,stroke:#f4a261,stroke-width:2px,color:#000
```

💡 Innovation: Implementing FlexAttention

Score Modification

The `score_mod` Callable:

The `score_mod` callable modifies pre-softmax attention scores based on position and optional auxiliary tensors. The generic signature is:

```python
def generic_score_mod(
    score: float,
    batch_idx: int,
    head_idx: int,
    q_idx: int,
    kv_idx: int,
    aux_tensors: Optional[list[tensor]],
) -> float:
    # Modify score based on position and aux tensors
    return modified_score
```

Example 1: T5 Relative Positional Bias

```python
def rel_bias_score_mod(score, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    bias_tensor = aux_tensors[0]
    rel_pos = abs(q_idx - kv_idx)  # relative distance between query and key positions
    return score + bias_tensor[batch_idx, head_idx, rel_pos]
```
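
For context, the auxiliary bias tensor here is indexed as `[batch, head, relative_position]`, so one plausible setup looks like the sketch below (the shapes are illustrative assumptions, not prescribed by the API):

```python
import torch

B, H, seqlen = 2, 8, 1024
# One bias value per |q_idx - kv_idx|; shape (B, H, seqlen) matches the
# bias_tensor[batch_idx, head_idx, rel_pos] lookup used above.
rel_bias = torch.randn(B, H, seqlen, device="cuda", dtype=torch.float32)
aux_tensors = [rel_bias]
```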

Example 2: ALiBi (Attention with Linear Biases)

```python
import math

def alibi_score_mod(score, batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    slope = math.exp2(-(head_idx + 1))  # per-head slope 2^-(h+1)
    dist = abs(q_idx - kv_idx)
    return score - slope * dist
```
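
The slope schedule above matches the original ALiBi recipe for 8 heads. For an arbitrary power-of-two head count H, the paper's geometric sequence of slopes can be written as in this small sketch (my generalization, not part of the original post):

```python
import math

def alibi_slope(head_idx: int, num_heads: int) -> float:
    # Geometric sequence 2^(-8/H), 2^(-16/H), ..., 2^(-8), as in the ALiBi paper
    # (exact for power-of-two head counts).
    return math.exp2(-8.0 * (head_idx + 1) / num_heads)
```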

CuTe DSL Implementation:

In the CuTe DSL implementation, `score_mod` must be defined using the TensorSSA abstraction. For example, the T5 relative bias:

```python
@cute.jit
def rel_bias_score_mod_cute(
    tSrS_ssa: cute.TensorSSA,
    batch_idx: cute.TensorSSA,
    head_idx: cute.TensorSSA,
    q_idx: cute.TensorSSA,
    kv_idx: cute.TensorSSA,
    aux_tensors: Optional[list]
) -> cute.TensorSSA:
    bias_tensor = aux_tensors[0]
    rel_pos = cute.TensorSSA(
        mlir_math.absi(q_idx - kv_idx),
        q_idx.shape,
        q_idx.dtype
    )
    bias = bias_tensor[batch_idx[0], head_idx[0], rel_pos[0]].to(cutlass.Float32)
    return tSrS_ssa + bias
```

Vectorization:

Applying `score_mod` is expensive because it requires looping over every entry of the scores matrix. TensorSSA allows for easy vectorized and broadcasted instructions: in the attention mainloop, modified scores are computed in groups of `vec_size`, a tunable hyperparameter.

Note: Without further assumptions, vectorization of the `score_mod` application is not feasible when using `aux_tensors`.

Using Score Modification:

Direct CuTe DSL interface:

```python
from flash_attn.cute.interface import _flash_attn_fwd

out, _ = _flash_attn_fwd(
    q, k, v,  # torch.Tensor
    score_mod=rel_bias_score_mod_cute,
    aux_tensors=aux_tensors,  # Optional[list[torch.Tensor]]
)
```

PyTorch integrated interface:

```python
from torch.nn.attention.flex_attention import flex_attention

compiled_fn = torch.compile(flex_attention)
out = compiled_fn(
    q, k, v,
    score_mod=rel_bias_score_mod,
    kernel_options={"force_flash": True},  # Use CuTe DSL backend
)
```

Mask Modification

The `mask_mod` Callable:

Defining `mask_mod` callables is similar to `score_mod`, but simpler. The mask application logic is contained in the FlashAttention forward kernel, so `mask_mod` need only return a Boolean indicating whether a score participates in attention; positions where it returns False are masked out (set to `-inf`):

```python
def generic_mask_mod(
    batch_idx: cute.TensorSSA,
    head_idx: cute.TensorSSA,
    q_idx: cute.TensorSSA,
    kv_idx: cute.TensorSSA,
    aux_tensors: Optional[list],
) -> cute.TensorSSA:  # dtype == cutlass.Boolean
    # Return True if the element participates in attention;
    # False means the score is masked out (set to -inf).
    return is_unmasked
```

Note: Unlike `score_mod`, the score itself is not passed in; only positional information is needed to determine whether an attention element should be masked.

Example 1: Causal Mask with Offset

To create a causal mask with the proper offset (`seqlen_k - seqlen_q`, or another value as needed):

```python
import flash_attn.cute.utils as utils

def create_causal_mask_with_offset(offset: int):
    @cute.jit
    def _causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        offset_ssa = utils.scalar_to_ssa(val=offset, dtype=cutlass.Int32)
        return kv_idx <= q_idx + offset_ssa

    return _causal_mask_mod
```
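
A quick usage sketch, binding the offset for the current problem shape and handing the resulting `mask_mod` to the forward kernel:

```python
# Build the mask_mod for this problem size and pass it to the kernel.
causal_mask_mod = create_causal_mask_with_offset(seqlen_k - seqlen_q)

out, _ = _flash_attn_fwd(
    q, k, v,
    mask_mod=causal_mask_mod,
)
```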

Note: This mask will require recompilation every time `seqlen_k - seqlen_q` changes. To avoid this, one could pass `offset` in as an additional `aux_tensor`.
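
Following the same pattern, a sliding-window (local, causal) mask is a small variation. The sketch below is mine, with the window size baked in as a Python constant (same recompilation caveat as above); whether Boolean TensorSSA results compose with `&` is an assumption on my part:

```python
def create_sliding_window_mask(window_size: int, offset: int = 0):
    @cute.jit
    def _sliding_window_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
        offset_ssa = utils.scalar_to_ssa(val=offset, dtype=cutlass.Int32)
        window_ssa = utils.scalar_to_ssa(val=window_size, dtype=cutlass.Int32)
        # Keep keys that are causal and within window_size positions of the query.
        causal = kv_idx <= q_idx + offset_ssa
        in_window = kv_idx + window_ssa > q_idx + offset_ssa
        return causal & in_window

    return _sliding_window_mask_mod
```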

Example 2: Document Masking

When sequences from multiple documents have been concatenated, tokens should only attend within their document:

```python
@cute.jit
def document_mask_mod(batch_idx, head_idx, q_idx, kv_idx, aux_tensors):
    doc_ids = aux_tensors[0]
    doc_id_q = doc_ids[batch_idx[0], head_idx[0], q_idx[0]]
    doc_id_kv = doc_ids[batch_idx[0], head_idx[0], kv_idx[0]]
    q_doc = utils.scalar_to_ssa(doc_id_q, cutlass.Int32)
    kv_doc = utils.scalar_to_ssa(doc_id_kv, cutlass.Int32)
    return q_doc == kv_doc
```

Here, `doc_ids` is an `Int32` tensor of shape `(B, H, seqlen)` representing which document each token belongs to, with the assumption that documents are contiguous.

Usage:

```python
out, _ = _flash_attn_fwd(
    q, k, v,
    mask_mod=document_mask_mod,
    aux_tensors=[doc_ids],
)
```
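
For completeness, here is one way to construct such a `doc_ids` tensor from per-document lengths (the shapes and the broadcast across heads are illustrative assumptions):

```python
import torch

B, H = 1, 8
doc_lens = [300, 500, 224]  # contiguous documents packed into one sequence
seqlen = sum(doc_lens)

# Token i gets the index of the document it belongs to.
doc_ids_1d = torch.repeat_interleave(
    torch.arange(len(doc_lens), dtype=torch.int32),
    torch.tensor(doc_lens),
)
# Expand to the (B, H, seqlen) layout expected by document_mask_mod above.
doc_ids = doc_ids_1d.view(1, 1, seqlen).expand(B, H, seqlen).contiguous().cuda()
```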

[Figure: document masking visualization]

Block Sparsity

The Optimization Challenge:

When large portions of the scores matrix are masked, we want to intelligently avoid these regions where possible, skipping unnecessary data movement and computation. FlexAttention implements block sparsity with mask mods.

Example: Causal Masking

Consider a problem with:

  • Batch size: 1
  • One head
  • `seqlen_q = 768`
  • `seqlen_kv = 896`
  • Work tile size: 128×128

There are 42 total blocks to handle:

  1. 6 blocks along the main diagonal are split in half by the causal mask; these need `mask_mod` application.
  2. 21 blocks below the diagonal have no masking; these do not need `mask_mod` application (though they do need `score_mod`), so we skip applying `mask_mod` on these blocks.
  3. 15 blocks are to be skipped entirely; it would be wasteful even to load them.
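
As a quick sanity check on these counts, here is a small standalone sketch (my addition, not from the post) that classifies each 128×128 tile for this causal problem:

```python
tile_m = tile_n = 128
seqlen_q, seqlen_kv = 768, 896
offset = seqlen_kv - seqlen_q          # causal offset
num_q, num_kv = seqlen_q // tile_m, seqlen_kv // tile_n

partial = full = skipped = 0
for qb in range(num_q):
    for kb in range(num_kv):
        q_min, q_max = qb * tile_m, (qb + 1) * tile_m - 1
        k_min, k_max = kb * tile_n, (kb + 1) * tile_n - 1
        if k_max <= q_min + offset:    # every element of the tile is kept
            full += 1
        elif k_min > q_max + offset:   # every element of the tile is masked
            skipped += 1
        else:                          # tile straddles the causal boundary
            partial += 1

print(partial, full, skipped)  # 6 21 15
```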

Block Sparsity Tensors:

Each work tile in the FlashAttention kernel corresponds to one `(batch, head, q_block)` coordinate. To compute only the tiles needed, we need to know:

  • `mask_block_idx`: `[B, H, num_q_blocks, num_kv_blocks]`, the blocks requiring `mask_mod`
  • `full_block_idx`: `[B, H, num_q_blocks, num_kv_blocks]`, the fully-computed blocks

And count tensors:

  • `mask_block_cnt`: `[B, H, num_q_blocks]`, the number of partially-masked `kv_blocks`
  • `full_block_cnt`: `[B, H, num_q_blocks]`, the number of fully-computed `kv_blocks`

Where:

  • `num_q_blocks = ceil_div(seqlen_q, tile_m)`
  • `num_kv_blocks = ceil_div(seqlen_kv, tile_n)`

BlockSparseTensors Class:

```python
class BlockSparseTensors(NamedTuple):
    mask_block_cnt: cute.Tensor
    mask_block_idx: cute.Tensor
    full_block_cnt: Optional[cute.Tensor]
    full_block_idx: Optional[cute.Tensor]
```

Note:

`full_block_cnt` and `full_block_idx` are optional; `mask_mod` will be applied to all blocks in that case.

Example: Causal Masking Block Sparsity

For causal masking with the parameters above:

```python
mask_block_cnt = [[[1, 1, 1, 1, 1, 1]]]
mask_block_idx = [[[[1, 0, 0, 0, 0, 0, 0],
                    [2, 0, 0, 0, 0, 0, 0],
                    [3, 0, 0, 0, 0, 0, 0],
                    [4, 0, 0, 0, 0, 0, 0],
                    [5, 0, 0, 0, 0, 0, 0],
                    [6, 0, 0, 0, 0, 0, 0]]]]
full_block_cnt = [[[1, 2, 3, 4, 5, 6]]]
full_block_idx = [[[[0, 0, 0, 0, 0, 0, 0],
                    [0, 1, 0, 0, 0, 0, 0],
                    [0, 1, 2, 0, 0, 0, 0],
                    [0, 1, 2, 3, 0, 0, 0],
                    [0, 1, 2, 3, 4, 0, 0],
                    [0, 1, 2, 3, 4, 5, 0]]]]
```

Computing Block Sparsity:

Computing `BlockSparseTensors` for a given `mask_mod`, sequence length, and tile size can be computationally expensive, but the cost is generally amortized across all layers of a model.
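
To make the layout of these tensors concrete, here is a deliberately naive reference computation (my sketch, far slower than what the library or PyTorch's `create_block_mask` actually do, but handy for checking small cases):

```python
import torch

def naive_block_sparse_tensors(keep_fn, B, H, seqlen_q, seqlen_kv, tile_m, tile_n):
    # keep_fn(b, h, q_idx, kv_idx) -> True if the element participates in attention.
    num_q = (seqlen_q + tile_m - 1) // tile_m
    num_kv = (seqlen_kv + tile_n - 1) // tile_n
    mask_cnt = torch.zeros(B, H, num_q, dtype=torch.int32)
    full_cnt = torch.zeros(B, H, num_q, dtype=torch.int32)
    mask_idx = torch.zeros(B, H, num_q, num_kv, dtype=torch.int32)
    full_idx = torch.zeros(B, H, num_q, num_kv, dtype=torch.int32)
    for b in range(B):
        for h in range(H):
            for qb in range(num_q):
                for kb in range(num_kv):
                    kept = [keep_fn(b, h, q, k)
                            for q in range(qb * tile_m, min((qb + 1) * tile_m, seqlen_q))
                            for k in range(kb * tile_n, min((kb + 1) * tile_n, seqlen_kv))]
                    if all(kept):      # fully computed tile: no mask_mod needed
                        full_idx[b, h, qb, int(full_cnt[b, h, qb])] = kb
                        full_cnt[b, h, qb] += 1
                    elif any(kept):    # partially masked tile: needs mask_mod
                        mask_idx[b, h, qb, int(mask_cnt[b, h, qb])] = kb
                        mask_cnt[b, h, qb] += 1
                    # else: fully masked tile, skipped entirely
    return mask_cnt, mask_idx, full_cnt, full_idx

# Reproduces the causal example above (keep iff kv_idx <= q_idx + 128):
# naive_block_sparse_tensors(lambda b, h, q, k: k <= q + 128, 1, 1, 768, 896, 128, 128)
```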

PyTorch Integration:

PyTorch has a similar, more robust class, `BlockMask`, that can be converted into `BlockSparseTensors`:

```python
from torch.nn.attention.flex_attention import create_block_mask

block_mask_torch = create_block_mask(
    mask_mod_fn,  # PyTorch mask function
    B, H, seqlen_q, seqlen_kv,
    device="cuda",
    BLOCK_SIZE=(tile_m, tile_n),
)

# Convert to CuTe DSL format
_, _, mask_cnt, mask_idx, full_cnt, full_idx, *_ = block_mask_torch.as_tuple()
block_sparse_tensors = BlockSparseTensorsTorch(
    mask_block_cnt=mask_cnt,
    mask_block_idx=mask_idx,
    full_block_cnt=full_cnt,
    full_block_idx=full_idx,
)
```

⚠️ Warning: The tile size used to compute block sparsity must be the same as the tile size used in the kernel.
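
For reference, the `mask_mod_fn` passed to `create_block_mask` follows the standard PyTorch FlexAttention convention of returning True where attention is allowed; a minimal causal example:

```python
# PyTorch-side mask function: arguments are index tensors, result is a Boolean tensor.
def causal_mask_fn(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx
```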

Complete API Call

Full FlexAttention Call:

```python
from flash_attn.cute.interface import _flash_attn_fwd

out, _ = _flash_attn_fwd(
    q, k, v,  # torch.Tensor, shape (B, H, seqlen_q, d)
    score_mod=score_mod_fn,  # Optional
    mask_mod=mask_mod_fn,  # Optional
    aux_tensors=aux_tensors,  # Optional[list[torch.Tensor]]
    block_sparse_tensors=block_sparse_tensors,  # Optional[BlockSparseTensors]
    # ... other optional arguments
)
```

PyTorch Interface:

```python
from torch.nn.attention.flex_attention import flex_attention

compiled_fn = torch.compile(flex_attention)
out = compiled_fn(
    q, k, v,
    score_mod=score_mod_fn,
    block_mask=block_mask,  # built from mask_mod_fn via create_block_mask
    kernel_options={"force_flash": True},
)
```

🎯 Key Takeaways

| Feature | Description | Use Case |
| --- | --- | --- |
| `score_mod` | Modifies pre-softmax attention scores | T5 bias, ALiBi, relative position encoding |
| `mask_mod` | Masks out attention scores (sets them to `-inf`) | Causal masking, document masking, sliding window |
| Block sparsity | Skips unnecessary computation for masked regions | Performance optimization for sparse attention |
| TensorSSA | Required abstraction for the CuTe DSL | Vectorized and broadcasted operations |
| Performance | 95% of FlashAttention 3, ~50% faster than Triton | Production-ready attention variants |

Why FlexAttention Matters

As someone who’s implemented attention variants, here’s what excites me:

  1. Unified API: One framework for many attention variants
  2. High Performance: 95% of FlashAttention 3 performance
  3. Flexibility: Easy to implement custom attention patterns
  4. Block Sparsity: Intelligent optimization for sparse attention
  5. Production Ready: Integrated into PyTorch and FlashAttention

What I’d Implement First:

  • Causal attention with proper offset handling
  • ALiBi for better long-context extrapolation
  • Document masking for multi-document scenarios
  • Custom score modifications for domain-specific attention patterns

🤔 New Questions This Raises

  1. Performance Trade-offs: When does the flexibility of FlexAttention outweigh the performance cost compared to specialized implementations?

  2. Block Sparsity Optimization: How do we automatically determine optimal block sparsity patterns for novel attention variants?

  3. Vectorization Limits: How can we improve vectorization when using `aux_tensors` in `score_mod`?

  4. Backward Pass: How does FlexAttention perform in the backward pass, and what are the optimization opportunities?

  5. Novel Combinations: What new attention patterns can we create by combining different `score_mod` and `mask_mod` functions?

  6. Hardware Optimization: How does FlexAttention perform on different GPU architectures (Ampere, Hopper, Blackwell)?

Next Steps: Experiment with custom attention variants, benchmark performance on different hardware, and explore novel combinations of score and mask modifications.



This post is licensed under CC BY 4.0 by the author.