Wall Attention(GitHub リポジトリ)
Wall Attention は、チャネルとタイムステップごとに学習可能な減衰を QK 積に組み込むことで、スカラーゲートや RoPE を一般化した新しいアテンション機構であり、トレーニングと推論の両方で効率化を実現する。
キーポイント
チャネル別・タイムステップ別減衰の実装
標準的なアテンションスコアの計算に、各チャネルごとに独立したコンテンツ依存の忘却率(減衰)を埋め込むことで、FoX や RoPE の概念を全チャネル次元へ一般化します。
トレーニング用融合カーネル
Triton を使用した融合フォワード・バックワードカーネル(wall_attn)を提供し、FlashAttention 形式のストリーミングソフトマックスと解析的勾配を両立させています。
推論用高速化カーネル
デコード段階では事前スケーリングされた KV キャッシュを読み取る単一ステップカーネル(wall_attn_decode)を採用し、各トークン生成時の計算コストを GEMV 演算レベルに削減します。
柔軟な設定と互換性
減衰パラメータ g を 0 に設定することで、従来のバニラソフトマックスアテンションを完全に復元可能であり、既存アーキテクチャへの移行が容易です。
GQA (Grouped Query Attention) のサポート
クエリヘッド数 (HQ) とキー/バリューヘッド数 (H) を別々に指定できるため、GQA 構造を効率的に実装できます。
bfloat16 対応と勾配計算
CUDA 環境上で bfloat16 データ型を使用し、requires_grad=True を設定することで学習時の勾配計算に対応しています。
デコード時のキャッシュ構築
事前計算(prefill)段階で再スケーリング済みのキャッシュを一度構築し、その後のトークン生成では1トークンずつ処理する効率的なデコードモードが用意されています。
影響分析・編集コメントを表示
影響分析
この技術は、Transformer のアテンション層における情報保持と忘却のバランスを動的かつ細かく制御できる画期的なアプローチであり、特に長いコンテキストウィンドウを扱う LLM の推論コスト削減と精度向上に寄与する可能性があります。実用的なカーネル実装が公開されているため、研究段階から実際のモデルトレーニングやデプロイへの応用が迅速に進むことが期待されます。
編集コメント
アテンション機構の根本的な改良を、実装レベルまで落とし込んで公開した点は非常に高く評価できます。特に推論時の KV キャッシュ活用によるコスト削減は、長文処理におけるボトルネック解消に直結する重要な進展です。
Wall Attention は、QK 内積に組み込まれた「チャネルごとの、タイムステップごとの乗算減衰」を特徴とするアテンションのバリアントです。標準的なアテンションがペア (i, j) を ∑_n q_{i,n} k_{j,n} でスコアリングするのに対し、Wall Attention は各チャネル n に対して、2 つの位置間で蓄積された学習済みの減衰で重み付けを行います。これにより、各クエリチャネルが独立した、コンテンツ依存性の忘却率を持つようになり、スカラーゲート(FoX)や RoPE スタイルの減衰を全チャネル次元に一般化します。g = 0 と設定すると、バニラのソフトマックスアテンションが復元されます。
詳細はブログをご覧ください:
https://blog.tilderesearch.com/blog/wall-attn
このリポジトリでは、実務で用いられる 2 つのカーネルをそれぞれ個別にパッケージ化しています。
- トレーニング / プレフィル (wall_attn): q, k, v, g に対する解析的勾配を持つ融合フォワード + バックワード Triton カーネル(FlashAttention スタイルのストリーミングソフトマックス)
- デコード (wall_attn_decode): 事前再スケーリングされた KV キャッシュを読み取る単一ステップカーネル。これにより、トークンごとの生成コストはプレフィックスの再計算ではなく、小さな GEMV 類似パス 1 つで済みます。
インストール
uv を使用(推奨)
uv sync
source .venv/bin/activate
または pip で
pip install -e .
使用方法
トレーニング / プレフィル
import torch
from wall_attn import wall_attn
B, T, H, HQ, K, V = 2, 1024, 4, 8, 64, 64 # GQA: HQ クエリヘッド、H KV ヘッド
q = torch.randn(B, T, HQ, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
g = torch.randn(B, T, HQ, K, device="cuda", dtype=torch.bfloat16, requires_grad=True) * 0.02
o = wall_attn(q, k, v, g, scale=K**-0.5) # [B, T, HQ, V]
o.sum().backward()
オプション引数: g_scalar ([B, T, HQ] FoX スタイルの加算ゲート)、sink_bias ([HQ] アテンションシンク)、window_size (スライディングウィンドウ)、cu_seqlens (可変長パッキング、B == 1 を必要とする)。
Decode (キャッシュ生成)
プリフェッチ時に再スケール済みのキャッシュを一度構築し、その後トークンを一つずつデコードする:
import torch
from fla.ops.utils.constant import RCP_LN2
from fla.ops.utils.cumsum import chunk_global_cumsum
from wall_attn import build_wall_kv_cache, wall_attn_decode
C = 64 # キャッシュチャンクサイズ (アンカー粒度)
P = chunk_global_cumsum(g, scale=RCP_LN2) # [B, T, HQ, K] プレフィックス
k_tilde, r_cache = build_wall_kv_cache(k, P, chunk_size=C)
o, _ = wall_attn_decode(
q=q[:, -1:], # 現在のクエリ [B, 1, HQ, K]
v=v, # キャッシュされた値 [B, T_kv, H, V]
p_curr=P[:, -1:], # 現在の行におけるプレフィックス
k_tilde=k_tilde, # 再スケーリング前のキー [B, T_kv, HQ, K]
r_cache=r_cache, # チャンクごとのアンカー [B, ceil(T_kv/C), HQ, K]
sink_bias=None,
scale=K**-0.5,
cache_chunk_size=C,
)
build_wall_kv_cache は、チャンクごとのアンカー R_c を用いて減衰をキーに折りたたみます (k_tilde[j] = k[j] · exp2(R_c − P[j]))。これにより、デコードカーネルはプレフィックスの再累積を行う必要がなくなります。逐次追加によるサービングループの詳細については、tests/test_decode.py::test_decode_streaming_matches_full_forward を参照してください。
コード構造
wall_attn/
├── __init__.py # パブリック API
├── training.py # フォワード/バックワード Triton カーネル + autograd Function + wall_attn()
├── decode.py # 単一ステップデコードカーネル + build_wall_kv_cache()
└── reference.py # イーガーな PyTorch リファレンス (正しさのオラクル)
tests/
├── test_training.py # パリティ + 解析的勾配 (有限差分チェック付き)
└── test_decode.py # デコード == プレフィルフォワード、ストリーミング、キャッシュ形状
機能
- GQA: クエリヘッド HQ は KV ヘッド H を超える可能性があります (HQ % H == 0)。
- 各チャネルごとの減衰 g と正確な解析的勾配に加え、オプションのスカラーゲート g_scalar をサポートします。
- アテンションシンク (sink_bias)、スライディングウィンドウ (window_size)、および可変長パッキング (cu_seqlens) をサポートします。
- 安価な自己回帰生成のための事前再スケーリングされたデコードキャッシュ、数値的に安定した長文コンテキスト(チャンクごとのアンカーにより exp2 が有界に保たれる)。
- BF16/FP32 入力;Hopper / Ampere アーキテクチャ向けに自動調整されたブロックサイズ。
テスト
pytest # CUDA GPU が必要
すべてのカーネルパスは、イージモードの wall_attn_reference と比較され、g および g_scalar の勾配は中心有限差分法に対して検証される。デコードカーネルは、ストリーミング生成ループを含むトレーニングの順方向をトークン単位で再現するかどうかも確認される。
謝辞
Triton カーネルは、flash-linear-attention (MIT) から派生した並列アテンション機構に基づいています。効率的なアテンションに関する優れた研究に尽力された FLA チームに感謝いたします。
ライセンス
MIT、詳細は LICENSE を参照してください。
原文を表示
Wall Attention is an attention variant with a per-channel, per-timestep multiplicative decay baked into the QK inner product. Where standard attention scores a pair
(
i
,
j
)
with
∑
n
q
i
,
n
,
k
j
,
n
, Wall Attention weights each channel
n
by a learned decay accumulated between the two positions. This gives each query channel an independent, content-dependent forgetting rate, generalizing scalar gating (FoX) and RoPE-style decays to the full channel dimension. Setting
g
=
0
recovers vanilla softmax attention.
See the blog for more information:
https://blog.tilderesearch.com/blog/wall-attn
This repo packages the two kernels used in practice, each on its own:
- Training / prefill (wall_attn): a fused forward + backward Triton kernel (FlashAttention-style streaming softmax) with analytic gradients for
q
,
k
,
v
,
g
.
- Decode (wall_attn_decode): a single-step kernel that reads a pre-rescaled KV cache, so per-token generation costs one small GEMV-like pass instead of recomputing the prefix.
Installation
# Using uv (recommended)
uv sync
source .venv/bin/activate
# or with pip
pip install -e .Usage
Training / prefill
import torch
from wall_attn import wall_attn
B, T, H, HQ, K, V = 2, 1024, 4, 8, 64, 64 # GQA: HQ query heads, H kv heads
q = torch.randn(B, T, HQ, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(B, T, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
v = torch.randn(B, T, H, V, device="cuda", dtype=torch.bfloat16, requires_grad=True)
g = torch.randn(B, T, HQ, K, device="cuda", dtype=torch.bfloat16, requires_grad=True) * 0.02
o = wall_attn(q, k, v, g, scale=K**-0.5) # [B, T, HQ, V]
o.sum().backward()Optional arguments: g_scalar ([B, T, HQ] FoX-style additive gate), sink_bias ([HQ] attention sink), window_size (sliding window), and cu_seqlens (varlen packing, requires B == 1).
Decode (cached generation)
Build the pre-rescaled cache once at prefill, then decode one token at a time:
import torch
from fla.ops.utils.constant import RCP_LN2
from fla.ops.utils.cumsum import chunk_global_cumsum
from wall_attn import build_wall_kv_cache, wall_attn_decode
C = 64 # cache chunk size (anchor granularity)
P = chunk_global_cumsum(g, scale=RCP_LN2) # [B, T, HQ, K] prefix
k_tilde, r_cache = build_wall_kv_cache(k, P, chunk_size=C)
o, _ = wall_attn_decode(
q=q[:, -1:], # current query [B, 1, HQ, K]
v=v, # cached values [B, T_kv, H, V]
p_curr=P[:, -1:], # prefix at the current row
k_tilde=k_tilde, # pre-rescaled keys [B, T_kv, HQ, K]
r_cache=r_cache, # per-chunk anchors [B, ceil(T_kv/C), HQ, K]
sink_bias=None,
scale=K**-0.5,
cache_chunk_size=C,
)build_wall_kv_cache folds the decay into the keys (k_tilde[j] = k[j] · exp2(R_c − P[j])) using a per-chunk anchor R_c, so the decode kernel never re-accumulates the prefix. See tests/test_decode.py::test_decode_streaming_matches_full_forward for the full append-as-you-go serving loop.
Code structure
wall_attn/
├── __init__.py # public API
├── training.py # forward/backward Triton kernels + autograd Function + wall_attn()
├── decode.py # single-step decode kernel + build_wall_kv_cache()
└── reference.py # eager PyTorch reference (correctness oracle)
tests/
├── test_training.py # parity + analytic gradients (finite-difference checked)
└── test_decode.py # decode == prefill forward, streaming, cache shapes
Features
- GQA: query heads HQ may exceed kv heads H (HQ % H == 0).
- Per-channel decay g with exact analytic gradient, plus an optional scalar gate g_scalar.
- Attention sink (sink_bias), sliding window (window_size), and varlen packing (cu_seqlens).
- Pre-rescaled decode cache for cheap autoregressive generation, numerically stable to long context (per-chunk anchors keep exp2 bounded).
- BF16/FP32 inputs; autotuned block sizes for Hopper / Ampere.
Testing
pytest # requires a CUDA GPUEvery kernel path is checked against the eager wall_attn_reference, and the g / g_scalar gradients are verified against central finite differences. The decode kernel is checked to reproduce the training forward token-for-token, including a streaming generation loop.
Acknowledgments
The Triton kernels build on the parallel-attention machinery from flash-linear-attention (MIT). We thank the FLA team for their excellent work on efficient attention.
License
MIT, see LICENSE.
関連記事
[AINews] 今日特に大きな出来事はありませんでした
Latent Space は、GLM 5.2 が依然として注目されていると指摘しつつ、AIE WF 2026 の通常チケットが月曜日に完売すると発表しました。同サイト購読者向けに限定割引を提供し、参加者には Warp や Datadog などからのスポンサークレジットも付与されます。
米国がアンソロピックの「Fable 5」発売を禁止、しかし市場は動じず
米国政府は国家安全保障上の懸念から、アマゾンの研究者らがガードレール回避手法を発見したとして、アンソロピックに対し最新モデル「Fable 5」と「Mythos 5」の販売差し止めを命じた。サイバーセキュリティ研究者らはこの措置が危険だとする公開書簡に署名し、同社も他モデルでも同様の抜け道が存在すると指摘している。
社内データ分析エージェントの構築方法について
GitHub は、大規模なデータ組織が直面する自己完結型のデータアクセスと洞察提供の課題に対し、AI を活用した信頼性の高い解決策として、社内でデータ分析エージェントを構築したことを発表した。
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み