NSA Modules and Branches

A closer look at token compression, selection, gating, and branch mechanisms in NSA

Part 3 of a series on the DeepSeek NSA paper.

In this section, we cover:

  • Token compression
  • Token selection and importance scoring
  • Sliding window/local context attention
  • Gating and branch integration

3.1 Token Compression Module (Section 3.1)

The Token Compression module implements a coarse-grained reduction of the sequence by aggregating tokens into blocks and representing each block with a single compression token. In practice, the sequence of keys (and similarly values) up to the current position $t$ is divided into contiguous blocks of length $l$, with a stride $d$ between block start positions (typically $d < l$, giving overlapping blocks). A learnable function $\phi$ (a small MLP with a positional encoding for the block) compresses each block into one vector, effectively summarizing that block's information. Formally, if $k_{1:t}$ denotes the sequence of key vectors from 1 to $t$, the compressed keys at step $t$ are:

$$
\tilde{K}^{\text{cmp}}_t \;=\; f^{\text{cmp}}_K(k_{1:t}) \;=\; \Big[\,\phi\big(k_{(i-1)d+1 : (i-1)d+l}\big)\,\Big]_{i=1}^{\left\lfloor\frac{t-l}{d}\right\rfloor},
$$

where each $\phi\big(k_{(i-1)d+1:(i-1)d+l}\big)$ produces a single compressed key for the $i$-th block of length $l$. Equivalently, the procedure can be sketched as follows.
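The sketch below is a minimal NumPy version under simplifying assumptions: `phi` is a mean-pool stand-in for the paper's learnable MLP with intra-block position encoding, and the names and shapes are illustrative rather than the authors' implementation.

```python
import numpy as np

def compress_keys(keys, l=32, d=16, phi=None):
    """Coarse-grained compression sketch: one token per block of length l, stride d.

    keys : (t, dim) array of key vectors k_1..k_t
    phi  : callable mapping an (l, dim) block to a single (dim,) vector;
           here a mean-pool placeholder for the learnable MLP.
    """
    t, dim = keys.shape
    if phi is None:
        phi = lambda block: block.mean(axis=0)
    num_blocks = (t - l) // d  # matches the floor((t - l) / d) limit in the formula
    blocks = [phi(keys[i * d : i * d + l]) for i in range(num_blocks)]
    return np.stack(blocks) if blocks else np.zeros((0, dim))

# Example: 1,024 tokens with 64-dim keys -> (1024 - 32) // 16 = 62 compressed keys
K_cmp = compress_keys(np.random.randn(1024, 64))
print(K_cmp.shape)  # (62, 64)
```

Applying the same construction to the value vectors yields the compressed values $\tilde{V}^{\text{cmp}}_t$.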

This compressed representation captures higher-level, coarser-grained semantics of the sequence while greatly reducing the number of tokens the query has to attend to. In other words, it provides a global summary of the context with far fewer tokens, which significantly lowers the computational cost of attention on long sequences. Figure 2 (left) illustrates this branch: the input sequence is processed into compressed attention covering coarse-grained patterns of the sequence. By using overlapping blocks (when $d < l$), the compression module preserves context continuity between blocks (mitigating hard information cutoffs) at the cost of minor redundancy. The compressed keys and values from this module (denoted $\tilde{K}^{\text{cmp}}_t, \tilde{V}^{\text{cmp}}_t$) will be used as one branch of the attention mechanism.

Figure 2, highlighting the Compression branch which groups tokens into coarse blocks and produces one representation per block.

3.2 Token Selection Module – Blockwise Selection (Section 3.2)

The Token Selection module aims to preserve important tokens at full fidelity, complementing the coarse compression branch. Relying only on compressed tokens could lose critical fine-grained information, so this module selects a subset of original tokens (keys/values) deemed most relevant to the current query. To make this efficient, NSA adopts a blockwise selection strategy – instead of selecting arbitrary individual tokens, it selects contiguous blocks of tokens as units. This strategy is motivated by two considerations: attention scores tend to be spatially continuous, so neighboring keys usually share similar importance, and contiguous blocks map onto the kind of memory access and matrix computation that modern GPU hardware handles efficiently.

Blockwise Selection Implementation: To apply this strategy, the sequence of keys/values is first partitioned into equal-length selection blocks (of length $l'$). These blocks are typically non-overlapping contiguous segments of the sequence (for example, block 0 covers tokens $1$ to $l'$, block 1 covers $l'+1$ to $2l'$, and so on). The goal is to identify which of these blocks contain the most relevant tokens for the current query. To do so, NSA computes an importance score for each block (described in Section 3.3 below) and then selects the top-ranking blocks for inclusion in the attention.

Once each block $j$ has been assigned an importance score $p^{\text{slc}}_t[j]$, the selection module chooses the top-$n$ blocks with the highest scores as the active blocks for the query. Formally, letting $p^{\text{slc}}_t[j]$ be the importance of block $j$, the set of selected block indices is:

$$
I_t \;=\; \{\, i \mid \operatorname{rank}\big(p^{\text{slc}}_t[i]\big) \le n \,\}, \tag{11}
$$

where $\operatorname{rank}\big(p^{\text{slc}}_t[i]\big)=1$ for the block with the highest importance, so the condition $\le n$ keeps the top $n$ blocks. All tokens within those top-$n$ blocks are retained as fine-grained tokens for attention. The keys from the selected blocks are concatenated to form the selected key set $\tilde{K}^{\text{slc}}_t$ for the query (and likewise for values $\tilde{V}^{\text{slc}}_t$). This can be written as:

$$
\tilde{K}^{\text{slc}}_t \;=\; \operatorname{Concat}\Big\{\, k_{\,i l' + 1 : (i+1) l'} \;\big|\; i \in I_t \,\Big\}, \tag{12}
$$

meaning $\tilde{K}^{\text{slc}}_t$ is formed by concatenating the original key vectors from each selected block $i$. The resulting set $\tilde{K}^{\text{slc}}_t$ has size $n \cdot l'$ (i.e., the module keeps $n$ blocks' worth of tokens). An analogous construction produces $\tilde{V}^{\text{slc}}_t$ from the corresponding value vectors. These selected tokens are then used as a second branch of attention, providing high-resolution information from the most relevant parts of the sequence. In Figure 2 (left), this corresponds to the Selected attention branch, which focuses on "important token blocks" from the input.
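As a concrete illustration of Equations (11) and (12), here is a small NumPy sketch of the top-$n$ block selection; the function name and shapes are hypothetical, and in practice the scores come from the procedure in Section 3.3 rather than being random.

```python
import numpy as np

def select_blocks(keys, values, p_slc, n=16, block_size=64):
    """Blockwise top-n selection sketch (Eqs. 11-12).

    p_slc      : importance score per selection block
    n          : number of blocks to keep
    block_size : the selection block length l'
    """
    # Eq. (11): indices of the n highest-scoring blocks (sorted to keep sequence order)
    I_t = np.sort(np.argsort(p_slc)[::-1][:n])
    # Eq. (12): concatenate the original tokens of each selected block
    K_slc = np.concatenate([keys[i * block_size : (i + 1) * block_size] for i in I_t])
    V_slc = np.concatenate([values[i * block_size : (i + 1) * block_size] for i in I_t])
    return K_slc, V_slc, I_t

# Example: 4,096 tokens, 64 blocks of length 64, keep the top 4 blocks
K, V = np.random.randn(4096, 64), np.random.randn(4096, 64)
K_slc, V_slc, I_t = select_blocks(K, V, np.random.rand(64), n=4)
print(K_slc.shape, I_t)  # (256, 64) and the 4 selected block indices
```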

After selection, the keys/values in the chosen blocks are treated just like in standard attention (except limited to those tokens). The attention computation for this branch operates over $\tilde{K}^{\text{slc}}_t$ and $\tilde{V}^{\text{slc}}_t$, and is later combined with the other branches via gating (see Section 3.5). In summary, the token selection module dynamically zooms in on salient regions of the sequence with fine granularity while ignoring less important tokens, all in a hardware-efficient blockwise manner.

Figure 2, highlighting the Selection branch which picks out top-ranked blocks of tokens. The figure shows that only a few blocks (green highlights) are attended, corresponding to important token regions.

3.3 Importance Score Computation for Token Selection (Section 3.3)

A key challenge in the selection module is scoring the blocks by importance without introducing too much overhead. NSA addresses this by leveraging the computation already done by the compression branch to derive importance scores essentially for free. The idea is that when the query attends to the compressed tokens, the attention weights on those compressed tokens indicate which portions of the sequence are important. In effect, the compression branch provides a rough importance distribution over the sequence, which can guide the selection branch.

Derivation from Compression Attention: Recall that the compression branch yields a set of compressed keys $\tilde{K}^{\text{cmp}}_t$ (one per block of the original sequence). When computing attention for query $q_t$ on these compressed keys, we get a vector of attention scores $p^{\text{cmp}}_t$ corresponding to those blocks. In particular, let $q_t$ be the query vector at position $t$. The attention of $q_t$ over the compressed keys is:

$$
p^{\text{cmp}}_t \;=\; \operatorname{Softmax}\!\big(q_t^\top\,\tilde{K}^{\text{cmp}}_t\big), \tag{8}
$$

which yields $p^{\text{cmp}}_t \in \mathbb{R}^M$, where $M = \big\lfloor\frac{t - l}{d}\big\rfloor$ is the number of compression tokens (blocks). In other words, $p^{\text{cmp}}_t[i]$ is the normalized attention weight on the $i$-th compressed block for the current query. If $p^{\text{cmp}}_t[i]$ is high, it means the query found block $i$ (covering a certain segment of the sequence) very relevant.

NSA uses $p^{\text{cmp}}_t$ to induce the selection-block importance scores $p^{\text{slc}}_t$. If the selection blocks were chosen to align exactly with the compression blocks (i.e., if the compression block length $l$ equals the selection block size $l'$ and the stride $d$ equals $l$, so there is a one-to-one correspondence), then we can simply take $p^{\text{slc}}_t = p^{\text{cmp}}_t$. In general, the blocking for selection might differ (e.g., $l'$ could be larger or smaller than $l$, or compression blocks might overlap while selection blocks do not). In that case, each selection block spans one or more compression blocks (or fractions thereof). The importance score for a selection block should then be the aggregate of the compression-based scores for all compression tokens that lie within that selection block's span. Formally, suppose both $l$ and $l'$ are multiples of the compression stride $d$ (as assumed in the paper). Then the importance of selection block $j$ can be obtained by summing the appropriate entries of $p^{\text{cmp}}_t$ that fall into block $j$'s range:

$$
p^{\text{slc}}_t[j] \;=\; \sum_{m=0}^{\frac{l'}{d}-1} \;\sum_{n=0}^{\frac{l}{d}-1} \; p^{\text{cmp}}_t\!\Big[\tfrac{l'}{d}\, j + m + n\Big], \tag{9}
$$

which accumulates the softmax weights of all compression tokens whose indices map into selection block $j$. (The double sum accounts for the possibly overlapping compression blocks within the span of the $j$-th selection block.) In simpler terms, $p^{\text{slc}}_t[j]$ is the total attention mass that the query $q_t$ assigned to the portion of the sequence covered by block $j$ (according to the compression branch's view).
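The mapping in Equations (8) and (9) can be sketched as follows. This is an illustrative rendering under the paper's assumption that both $l$ and $l'$ are multiples of $d$; the bare dot-product attention and all names are simplifications, not the authors' code.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def block_importance(q, K_cmp, l=32, d=16, l_sel=64):
    """Reuse compression-attention weights (Eq. 8) as selection-block scores (Eq. 9).

    q     : (dim,) query vector
    K_cmp : (M, dim) compressed keys, one per compression block
    """
    p_cmp = softmax(K_cmp @ q)          # Eq. (8): attention over compressed tokens
    M = p_cmp.shape[0]
    num_sel_blocks = max((M * d) // l_sel, 1)
    p_slc = np.zeros(num_sel_blocks)
    for j in range(num_sel_blocks):     # Eq. (9): sum the score of every compression
        for m in range(l_sel // d):     # token whose span falls inside selection block j
            for n in range(l // d):
                idx = (l_sel // d) * j + m + n
                if idx < M:
                    p_slc[j] += p_cmp[idx]
    return p_slc

p_slc = block_importance(np.random.randn(64), np.random.randn(62, 64))
print(p_slc.shape)  # one score per selection block
```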

After computing $p^{\text{slc}}_t[j]$ for every block $j$, the selection module has an importance score for each selection block. These scores are then used to rank the blocks and choose the top $n$ as described in Section 3.2. Because $p^{\text{slc}}_t$ is derived largely from the existing attention computation (the compression branch's attention), this approach adds minimal computational overhead. Essentially, NSA recycles the coarse attention results to guide fine-grained token selection.

Multi-Head Considerations: In transformers, multiple attention heads may share the same key-value memory (for example, in grouped-query attention or multi-query attention, several heads use one shared set of keys/values). In such cases, NSA ensures consistent block selection across heads so that all heads in a group pick the same blocks, avoiding redundant memory fetches. It does this by aggregating the importance scores across those heads. If $h=1,\dots,H$ indexes the heads in a group, the shared importance for block $j$ can be computed as a sum over the heads' importance scores:

$$
p'^{\,\text{slc}}_t[j] \;=\; \sum_{h=1}^{H} p^{\text{slc},(h)}_t[j], \tag{10}
$$

where $p^{\text{slc},(h)}_t$ is the importance vector from head $h$. The selected blocks are then chosen based on this combined score $p'^{\,\text{slc}}_t$, ensuring all heads in the group attend to the same set of top-$n$ blocks. This synchronization minimizes divergent memory access patterns between heads during decoding (when reading from the key-value cache).
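A toy NumPy example of this aggregation (Eq. 10), with hypothetical head and block counts:

```python
import numpy as np

H, num_blocks, n = 4, 32, 8
p_slc_per_head = np.random.rand(H, num_blocks)   # per-head importance scores p_t^{slc,(h)}

p_shared = p_slc_per_head.sum(axis=0)            # Eq. (10): combined score p'_t^{slc}
I_t = np.sort(np.argsort(p_shared)[::-1][:n])    # every head in the KV group uses these blocks
print(I_t)
```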

Summary: The importance score computation module provides the bridge between the compression and selection branches. By computing a query-dependent importance distribution over blocks (using the compression attention output), it enables the token selection module to efficiently pick out the most relevant tokens. These importance scores $p^{\text{slc}}_t$ directly determine which blocks of tokens are fed into the fine-grained selection attention branch (as formalized by Equation (11) in the previous section). This design preserves crucial information for the query without heavy computation of its own, thereby maintaining the efficiency of NSA's sparse attention approach.

3.4 Sliding Window Module for Local Context (Section 3.4)

NSA's third branch is the Sliding Window module, which explicitly handles local context. This branch ensures that each query always has direct access to the most recent tokens (a fixed-length window of prior tokens), thereby capturing short-range dependencies with high fidelity. The motivation for a dedicated local branch is that local patterns tend to dominate learning in attention networks. Nearby tokens often have very high relevance (especially in causal language modeling, where the immediate past tokens are strong predictors of the next token). If local context is mixed with global context in a single attention computation, the model may overly focus on these easy-to-use local cues, hindering it from learning long-range patterns. By isolating local attention in its own branch, NSA allows the model to learn long-range attention in the other branches without being "short-circuited" by always relying on local tokens.

Mechanism: For each query position $t$, the sliding window branch simply takes the most recent $w$ tokens (keys and values), up to and including position $t$, as its context. Formally, the set of keys and values for the window branch at time $t$ is:

$$
\tilde{K}^{\text{win}}_t = k_{\,t-w:t}, \qquad \tilde{V}^{\text{win}}_t = v_{\,t-w:t},
$$

meaning $\tilde{K}^{\text{win}}_t$ consists of the $w$ most recent key vectors (and similarly for $\tilde{V}^{\text{win}}_t$). In practice, $w$ is a fixed window size (e.g., 512 tokens in the authors' experiments) defining how much immediate context is always attended with full resolution.
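The window branch itself is just a slice of the most recent keys/values; a trivial sketch with 0-based indexing and illustrative names:

```python
import numpy as np

def window_kv(keys, values, t, w=512):
    """Return the last w keys/values up to and including position t (0-based)."""
    start = max(0, t - w + 1)
    return keys[start : t + 1], values[start : t + 1]

K, V = np.random.randn(4096, 64), np.random.randn(4096, 64)
K_win, V_win = window_kv(K, V, t=4095)
print(K_win.shape)  # (512, 64)
```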

This module essentially performs a standard local (sliding-window) attention: the query $q_t$ attends to the last $w$ keys/values without any sparsification inside this window. What makes it part of NSA's sparse design is that $w \ll t$ for long sequences, so at any point the query only considers a small local subset in this branch. Crucially, the sliding window branch is separate from the other two branches (compression and selection). Its attention output is combined with the others at the end (via gating), rather than competing directly with them during the attention softmax. This separation means that the presence of strong local correlations does not suppress the learning of long-range attention in the other branches. The authors note that local patterns "adapt faster and can dominate the learning process" if not isolated, so by giving local context its own branch, the other branches can focus on global context without interference.

Another implementation detail is that NSA uses independent key and value projections for each branch. In other words, even though the sliding window branch and the global branches might ultimately use some of the same raw tokens (e.g., the most recent tokens could also be part of a compression block or a selected block), the keys/values fed into each branch are obtained from separate learnable transformations. This prevents gradient interference between branches. For example, the model can adjust the key representations used for long-range compression independently from those used for local window attention. This design choice ensures stable end-to-end training of NSA, as gradients from the local branch (which has a very strong signal) do not distort the representations needed for the global sparse branches. The overhead of maintaining separate projections is minimal compared to the overall model size, but it greatly helps in preventing one branch from overpowering the others during training.
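A schematic way to picture the separate projections, with purely illustrative shapes and names (the actual model uses learned projection layers per branch):

```python
import numpy as np

dim = 64
rng = np.random.default_rng(0)
branches = ("cmp", "slc", "win")
W_K = {b: rng.standard_normal((dim, dim)) for b in branches}   # one key projection per branch
W_V = {b: rng.standard_normal((dim, dim)) for b in branches}   # one value projection per branch

x = rng.standard_normal((4096, dim))          # token hidden states
keys   = {b: x @ W_K[b] for b in branches}    # each branch sees its own key space,
values = {b: x @ W_V[b] for b in branches}    # so branch gradients update separate parameters
```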

In Figure 2 (left), the sliding window module corresponds to the Sliding attention branch that always covers the recent context window for each query. On the right side of Figure 2, this is visualized as a band along the diagonal of the attention matrix (green squares in a contiguous region near the query token), indicating that each query attends to its neighboring tokens in the sequence. By including this branch, NSA guarantees that no matter how aggressive the global sparsification (compression/selection) is, the model will not lose track of local context and short-range correlations.

Figure 2 featuring the Sliding Window branch, showing that each query attends to a fixed window of previous tokens. Note how this branch covers the diagonal band of the attention pattern, ensuring local context is attended.

3.5 Gating and Branch Integration Mechanism (Section 3.5)

NSA combines the outputs of the three parallel branches – compression, selection, and sliding window – using a learned gating mechanism. Rather than simply concatenating or averaging the branch outputs, NSA computes a weighted sum where the weight (gate) for each branch is dynamically determined by the model for each query. This allows the model to adaptively emphasize or de-emphasize each source of information (coarse global context, selected important tokens, or local window) depending on the needs of the current token.

Mathematically, let $o^{\text{cmp}}_t = \operatorname{Attn}(q_t, \tilde{K}^{\text{cmp}}_t, \tilde{V}^{\text{cmp}}_t)$ be the attention output from the compression branch alone (i.e., attending $q_t$ to the compressed tokens), and similarly $o^{\text{slc}}_t$ from the selection branch and $o^{\text{win}}_t$ from the sliding window branch. The gated output $o^*_t$ is given by a sum over branches $c \in \{\text{cmp}, \text{slc}, \text{win}\}$, each scaled by a gate coefficient $g^c_t$:

$$
o^*_t \;=\; \sum_{c \in \{\text{cmp},\,\text{slc},\,\text{win}\}} g^c_t \cdot \operatorname{Attn}\big(q_t,\, \tilde{K}^c_t,\, \tilde{V}^c_t\big). \tag{5}
$$

Here $g^c_t \in [0,1]$ is the gate score for branch $c$ at query $t$. These gating values are produced by a small neural network (an MLP with a sigmoid activation) that takes as input the query's features (e.g., the transformer's hidden state or $q_t$ itself). Intuitively, $g^c_t$ tells the model how much attention to pay to branch $c$ for this token. For example, if $g^{\text{win}}_t$ is high (close to 1) and $g^{\text{cmp}}_t$ is low (near 0), then the model is relying mostly on local context for token $t$ and largely ignoring the compressed global context for that token. The gating values are continuous, so the model can blend branches or even use all three if needed (they are not a mutually exclusive softmax; each gate is independently in $[0,1]$). In practice, the model learns to set these gates in a way that optimally combines the information sources for each scenario.
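Putting the pieces together, here is a minimal sketch of Equation (5): a sigmoid gate per branch scales that branch's attention output, and the results are summed. The single-layer gate and the unmasked dot-product attention are simplifications for illustration, not the paper's exact architecture.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def attn(q, K, V):
    """Plain softmax attention for a single query (illustrative, no masking)."""
    s = K @ q / np.sqrt(q.shape[0])
    w = np.exp(s - s.max()); w /= w.sum()
    return w @ V

def gated_output(q, branch_kv, W_gate, b_gate):
    """Eq. (5) sketch: o*_t = sum_c g^c_t * Attn(q_t, K^c_t, V^c_t)."""
    gates = sigmoid(W_gate @ q + b_gate)      # one gate in [0, 1] per branch
    out = np.zeros_like(q)
    for g, (K, V) in zip(gates, branch_kv.values()):
        out += g * attn(q, K, V)
    return out

# Toy example: three branches over a 64-dim model
dim = 64
q = np.random.randn(dim)
branch_kv = {c: (np.random.randn(m, dim), np.random.randn(m, dim))
             for c, m in (("cmp", 62), ("slc", 256), ("win", 512))}
o_star = gated_output(q, branch_kv, np.random.randn(3, dim), np.zeros(3))
print(o_star.shape)  # (64,)
```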

The gating mechanism thus performs a learned fusion of the branch outputs. As illustrated in Figure 2 (left), after computing attention in each branch, NSA produces a Gated Output by weighting each branch's output and summing them. This gating not only merges the information but can also serve to selectively activate branches: if a gate value is near 0, that branch's contribution is negligible for that token. (For instance, in a situation where the local context provides all necessary information, the model might gate down the other two branches.) Because gating is learned end-to-end, the model can flexibly adjust the importance of global vs. local context for different tokens or tasks.

To maintain the efficiency of the sparse attention, the total number of keys/values that the query actually attends to remains low. In fact, the total effective key set size is $N_t = |\tilde{K}^{\text{cmp}}_t| + |\tilde{K}^{\text{slc}}_t| + |\tilde{K}^{\text{win}}_t|$. By design of the previous modules, each of these sets is much smaller than $t$, so $N_t \ll t$ even though there are three branches. This $N_t$ is the sum of compressed tokens, selected tokens, and window tokens considered at step $t$ (see Equation (6) in the paper). NSA keeps $N_t$ low (high sparsity) through appropriate choices of block sizes, $n$ (number of selected blocks), and $w$ (window size), so that the combined attention still operates on a sparse set of context tokens per query. In summary, the gating mechanism integrates the three sparse attention branches to produce the final attention output $o^*_t$ (which then goes into the transformer's next layers), while preserving the computational gains of sparsity.
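For a rough sense of scale (assuming hyperparameters on the order of those reported in the paper, e.g. $l=32$, $d=16$, $l'=64$, $n=16$, $w=512$): at position $t = 65{,}536$ the query would attend to about $\lfloor(t-l)/d\rfloor \approx 4{,}094$ compressed tokens, $n \cdot l' = 1{,}024$ selected tokens, and $w = 512$ window tokens, so $N_t \approx 5{,}600$ keys rather than $65{,}536$ – well under a tenth of full attention.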

Figure 2 provides an overview of this integration. The left panel shows the three parallel attention branches (Compression, Selection, Sliding Window) processing the input sequence, and a gating module that combines their outputs into the final result. The right panel visualizes the attention mask patterns for each branch (green indicates attended positions, white indicates skipped). The gating ensures that these patterns are fused appropriately: effectively, the final attention for $q_t$ covers coarse global context (from compression), important specific tokens (from selection), and the immediate neighborhood (from the sliding window), with each part weighted by learned gates. This branch integration via gating is one of NSA's core innovations, enabling it to balance local and global context in a trainable manner.

Figure 2 shows the gating mechanism: an illustration of the gating MLP taking $q_t$ as input and outputting $g^{\text{cmp}}_t$, $g^{\text{slc}}_t$, $g^{\text{win}}_t$, which then weight the branch outputs. The outputs of the three branches are summed to produce the final $o^*_t$.