パッキングシーケンス、GQA、ALiBi、SwiGLU、因果アテンションを用いたメモリ効率的なTransformerの構築方法(xFormers活用)
MarkTechPost は、xFormers を用いてメモリ効率を劇的に向上させるための実装チュートリアルを提供し、GQA や SwiGLU などの最新技術を組み合わせた大規模言語モデル構築の具体的な手順を示している。
キーポイント
メモリエフィシアントランスフォーマーの実装検証
標準的なアテンション実装と比較し、xFormers のメモリ効率と速度を異なるシーケンス長で定量的に評価する手法を示している。
高度な最適化技術の統合
因果マスク、パッキングされた可変長シーケンス、グループドクエリアテンション(GQA)、およびカスタム ALiBi 位置バイアスの実装方法を解説している。
高性能 GPT スタイルモデルの構築
xFormers アテンション、SwiGLU フィードフォワード層、自動混合精度トレーニングを組み合わせた、訓練可能な GPT 型モデルの完成形を提示している。
メモリ効率と精度の両立
xFormers の `memory_efficient_attention` は、従来の注意機構が生成する巨大な M×M スコア行列をメモリ上に展開せず計算を行うため、VRAM 使用量を劇的に削減します。
ベンチマーク機能の実装
CUDA イベントを用いた正確な実行時間計測関数や、GPU メモリピーク値を取得するユーティリティ関数を定義し、パフォーマンス評価の基盤を整えています。
検証環境の確認
コードは GPU の利用可否を確認し、xFormers が正しくビルドされたカーネルを備えているか `xformers.info` を実行して環境設定を検証します。
メモリ効率の劇的な向上
xFormers はシーケンス長が増加してもメモリ使用量が線形に増加するのみであり、従来の注意機構が示す二次関数的な爆発的成長(M が 2 倍になるとメモリが約 4 倍)を防ぎます。
影響分析・編集コメントを表示
影響分析
この記事は、大規模言語モデル(LLM)のトレーニングにおけるメモリ制約という現実的な課題に対し、xFormers を活用した具体的な解決策を提供する点で極めて重要です。実践的なコード例と最新技術の統合方法を提示することで、開発者がリソースを節約しつつ高性能なモデルを構築できる道筋を示しており、業界全体の効率化に貢献します。
編集コメント
実務で大規模モデルを扱うエンジニアにとって、メモリ効率化の具体的なコード例が得られる非常に価値のある技術記事です。最新の最適化技術を即座に適用するための指針として活用できます。
このチュートリアルでは、GPU で高速かつメモリ効率的な Transformer モデルを構築するための実用的なツールキットである xFormers の実装を行います。まず、標準的なアテンション実装に対してメモリ効率的なアテンションの有効性を検証し、その後、異なるシーケンス長にわたって両者の速度とメモリー消費量を比較します。次に、因果マスク(causal masking)、パッキングされた可変長のシーケンス(packed variable-length sequences)、グループ化クエリアテンション(grouped-query attention)、およびカスタム ALiBi 位置バイアスについて検討します。最後に、これらすべての技術を組み合わせて、xFormers アテンション、SwiGLU フィードフォワード層、自動混合精度トレーニングを活用した学習可能な GPT スタイルモデルを構築します。
xFormers のセットアップとメモリ効率的なアテンションの検証
コードをコピーしました。別のブラウザを使用してください
import subprocess, sys
def _pip(*a): subprocess.run([sys.executable, "-m", "pip", "install", *a], check=False)
try:
import xformers
except Exception:
_pip("-q", "-U", "xformers")
import math, time
import torch, torch.nn as nn, torch.nn.functional as F
import xformers, xformers.ops as xops
from xformers.ops import fmha
ab = fmha.attn_bias
assert torch.cuda.is_available(), (
"No GPU detected. In Colab: Runtime → Change runtime type → GPU, then re-run.")
device = "cuda"
torch.manual_seed(0)
print("torch :", torch.__version__)
print("xformers :", xformers.__version__)
print("GPU :", torch.cuda.get_device_name(0))
print("\n--- xformers.info (which kernels are built/available) ---")
try:
subprocess.run([sys.executable, "-m", "xformers.info"], check=False)
except Exception as e:
print("xformers.info unavailable:", e)
def cuda_time(fn, iters=20, warmup=5):
for _ in range(warmup): fn()
torch.cuda.synchronize()
s, e = (torch.cuda.Event(enable_timing=True) for _ in range(2))
s.record()
for _ in range(iters): fn()
e.record(); torch.cuda.synchronize()
return s.elapsed_time(e) / iters
def peak_mem_mb(fn):
torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
fn(); torch.cuda.synchronize()
return torch.cuda.max_memory_allocated() / 1e6
def vanilla_attention(q, k, v, causal=False):
"""Reference attention that MATERIALIZES the [B,H,M,M] score matrix.
Inputs are xformers-layout [B, M, H, K]."""
q, k, v = (t.transpose(1, 2).float() for t in (q, k, v))
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
if causal:
M = scores.shape[-1]
m = torch.triu(torch.ones(M, M, device=q.device, dtype=torch.bool), 1)
scores = scores.masked_fill(m, float("-inf"))
out = scores.softmax(-1) @ v
return out.transpose(1, 2)
print("\n" + "="*70 + "\n1. memory_efficient_attention basics + correctness\n" + "="*70)
B, M, H, K = 2, 512, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_xf = xops.memory_efficient_attention(q, k, v)
out_ref = vanilla_attention(q, k, v).half()
print("output shape :", tuple(out_xf.shape), "(layout B, M, H, K)")
print("max abs diff vs ref : {:.2e}".format((out_xf - out_ref).abs().max().item()))
print("-> it's EXACT attention (fp16 rounding only), just computed without")
print(" ever storing the full MxM score matrix.")
xFormers のインストールとインポートを行い、GPU の利用可能性を確認し、環境でサポートされているアテンションカーネルを検証します。CUDA 実行時間とピークメモリ消費量を測定するためのヘルパー関数を定義します。その後、メモリエフィシェントなアテンションが標準的なアテンションと同等の結果を生成することを確認するために両者を比較検証します。
Naive Causal Attention(単純因果アテンション)に対するメモリと速度のベンチマーク
Copy CodeCopiedUse a different Browser
print("\n" + "="*70 + "\n2. Memory & speed vs naive attention (fwd+bwd)\n" + "="*70)
print(f"{'seqlen':>8} | {'naive MB':>10} | {'xformers MB':>12} | {'naive ms':>9} | {'xf ms':>7}")
print("-"*60)
for M in [512, 1024, 2048, 4096]:
q, k, v = (torch.randn(2, M, 8, 64, device=device, dtype=torch.float16,
requires_grad=True) for _ in range(3))
def run_xf():
o = xops.memory_efficient_attention(q, k, v); o.sum().backward()
def run_naive():
o = vanilla_attention(q, k, v); o.sum().backward()
try:
nm = peak_mem_mb(run_naive); nt = cuda_time(run_naive, 8, 3)
except RuntimeError:
nm, nt = float("nan"), float("nan"); torch.cuda.empty_cache()
xm = peak_mem_mb(run_xf); xt = cuda_time(run_xf, 8, 3)
print(f"{M:>8} | {nm:>10.0f} | {xm:>12.0f} | {nt:>9.2f} | {xt:>7.2f}")
print("-> naive memory grows ~4x per doubling of M (it stores BxHxMxM);")
print(" xformers grows ~linearly and stays fast.")
print("\n" + "="*70 + "\n3. Causal attention via LowerTriangularMask\n" + "="*70)
B, M, H, K = 2, 256, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_causal = xops.memory_efficient_attention(q, k, v, attn_bias=ab.LowerTriangularMask())
ref_causal = vanilla_attention(q, k, v, causal=True).half()
print("causal max abs diff : {:.2e}".format((out_causal - ref_causal).abs().max().item()))
print("-> the mask is implicit; no MxM boolean tensor is allocated.")
進化する長さのシーケンスに対して、フォワードパスとバックワードパスを用いて、単純なアテンションと xFormers アテンションをベンチマークします。実行時間と GPU メモリのピーク使用量を比較し、xFormers が二次的なメモリ増加をどのように回避するかを観察します。また、暗黙的な下三角マスクを適用し、参照実装に対して因果アテンション(causal attention)を検証します。
可変長シーケンスのパッキングとグループ化クエリアテンションの実行
コードをコピーしました
別のブラウザを使用してください
print("\n" + "="*70 + "\n4. 可変長のパケット化されたバッチ — パディングによる無駄なし\n" + "="*70)
seqlens = [37, 120, 8, 200]
total = sum(seqlens)
H, K = 8, 64
q = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
k = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
v = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
try:
bias = ab.BlockDiagonalMask.from_seqlens(seqlens)
out_packed = xops.memory_efficient_attention(q, k, v, attn_bias=bias)
s0 = seqlens[0]
ref0 = vanilla_attention(q[:, :s0], k[:, :s0], v[:, :s0]).half()
print("packed shape :", tuple(out_packed.shape), "(all", total, "tokens, no pad)")
print("segment-0 max diff : {:.2e}".format((out_packed[:, :s0] - ref0).abs().max().item()))
cbias = ab.BlockDiagonalCausalMask.from_seqlens(seqlens)
_ = xops.memory_efficient_attention(q, k, v, attn_bias=cbias)
print("-> also did a packed CAUSAL pass. This is how vLLM-style engines")
print(" batch requests of different lengths with zero padding overhead.")
splits = bias.split(out_packed)
print("recovered segments :", [tuple(t.shape) for t in splits])
except Exception as e:
print("BlockDiagonalMask path skipped on this version/backend:", repr(e))
print("\n" + "="*70 + "\n5. グループ化クエリアテンション (5-D BMGHK レイアウト)\n" + "="*70)
B, M, K = 2, 256, 64
n_q_heads, n_kv_heads = 8, 2
G, Hq = n_kv_heads, n_q_heads // n_kv_heads
try:
qg = torch.randn(B, M, G, Hq, K, device=device, dtype=torch.float16)
kg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16)
vg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16)
out_gqa = xops.memory_efficient_attention(qg, kg, vg)
print("GQA output shape :", tuple(out_gqa.shape), "= [B, M, G, Hq, K]")
print(f"-> {n_q_heads} query heads, only {n_kv_heads} KV heads: smaller KV-cache,")
print(" which is exactly what Llama-/Mistral-class models use at inference.")
except Exception as e:
print("GQA 5-D path skipped on this version/backend:", repr(e))
可変長のシーケンスを結合し、パディングなしでアテンションがシーケンス境界を超えないように BlockDiagonalMask を使用します。個々の出力を復元するとともに、デコーダスタイルのワークロードに対してパッキングされた因果アテンション(causal attention)も実行します。その後、グループ化クエリアテンション(grouped-query attention)を実証し、複数のクエリヘッドがより少ないキーバリューヘッドを共有することで KV キャッシュの要件を削減します。
カスタム ALiBi 加算位置バイアスの追加
print("\n" + "="*70 + "\n6. Custom ALiBi additive bias\n" + "="*70)
B, M, H, K = 1, 128, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
try:
slopes = (2.0 (-8.0 / H)) torch.arange(1, H + 1, device=device)
pos = torch.arange(M, device=device)
rel = (pos[None, :] - pos[:, None]).clamp(max=0).float()
alibi = slopes[:, None, None] * rel[None]
alibi = alibi[None].expand(B, H, M, M).to(torch.float16).contiguous()
causal = torch.triu(torch.ones(M, M, device=device, dtype=torch.bool), 1)
alibi = alibi.masked_fill(causal[None, None], float("-inf"))
out_alibi = xops.memory_efficient_attention(q, k, v, attn_bias=alibi)
print("ALiBi output shape :", tuple(out_alibi.shape))
print("-> any per-(head,query,key) additive bias works the same way.")
except Exception as e:
print("Custom-bias path skipped (some backends restrict bias shapes):", repr(e))
各アテンションヘッドに異なる線形位置ペナルティを適用するカスタム ALiBi テンソルを構築します。この加算バイアスを因果マスク(causal mask)と組み合わせることで、トークンが有効な過去の位置のみを参照できるようにします。得られたバイアスを xFormers アテンションに直接渡して、その出力の形状を検証します。
xFormers アテンションと SwiGLU を用いた GPT ブロックのトレーニング
コピーコード コピー済み別のブラウザを使用
print("\n" + "="*70 + "\n7. Train a small GPT block (xformers attn + SwiGLU)\n" + "="*70)
def make_swiglu(d, hidden):
"""Fused xformers SwiGLU if available, else a clean manual fallback."""
try:
m = xops.SwiGLU(in_features=d, hidden_features=hidden, out_features=d, bias=True)
return m, "fused xops.SwiGLU"
except Exception:
class SwiGLU(nn.Module):
def __init__(s):
super().__init__()
s.w12 = nn.Linear(d, 2 * hidden); s.w3 = nn.Linear(hidden, d)
def forward(s, x):
a, b = s.w12(x).chunk(2, -1)
return s.w3(F.silu(a) * b)
return SwiGLU(), "manual SwiGLU fallback"
class Block(nn.Module):
def __init__(self, d, n_heads, mlp_mult=4):
super().__init__()
self.h, self.k = n_heads, d // n_heads
self.n1, self.n2 = nn.LayerNorm(d), nn.LayerNorm(d)
self.qkv, self.proj = nn.Linear(d, 3 * d), nn.Linear(d, d)
self.ff, self.ff_kind = make_swiglu(d, mlp_mult * d)
def forward(self, x):
B, M, d = x.shape
qkv = self.qkv(self.n1(x)).reshape(B, M, 3, self.h, self.k)
q, kk, vv = qkv.unbind(2)
a = xops.memory_efficient_attention(q, kk, vv, attn_bias=ab.LowerTriangularMask())
x = x + self.proj(a.reshape(B, M, d))
return x + self.ff(self.n2(x))
class TinyGPT(nn.Module):
def __init__(self, vocab, d=128, n_layers=3, n_heads=8, maxlen=64):
super().__init__()
self.tok = nn.Embedding(vocab, d); self.pos = nn.Embedding(maxlen, d)
self.blocks = nn.ModuleList(Block(d, n_heads) for _ in range(n_layers))
self.nf, self.head = nn.LayerNorm(d), nn.Linear(d, vocab)
def forward(self, idx):
B, M = idx.shape
x = self.tok(idx) + self.pos(torch.arange(M, device=idx.device))[None]
for b in self.blocks: x = b(x)
return self.head(self.nf(x))
VOCAB, SEQ = 64, 64
def make_batch(B):
start = torch.randint(0, VOCAB, (B, 1), device=device)
return (start + torch.arange(SEQ, device=device)[None]) % VOCAB
model = TinyGPT(VOCAB).to(device)
print("FFN type :", model.blocks[0].ff_kind)
opt = torch.optim.AdamW(model.parameters(), lr=3e-3)
scaler = torch.amp.GradScaler("cuda")
for step in range(400):
seq = make_batch(64); inp, tgt = seq[:, :-1], seq[:, 1:]
with torch.autocast("cuda", dtype=torch.float16):
logits = model(inp)
loss = F.cross_entropy(logits.reshape(-1, VOCAB), tgt.reshape(-1))
opt.zero_grad(); scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
if step % 80 == 0 or step == 399:
acc = (logits.argmax(-1) == tgt).float().mean().item()
print(f"step {step:4d} | loss {loss.item():.4f} | next-token acc {acc*100:5.1f}%")
print("-> a full causal transformer running on memory-efficient attention,")
print(" trained end-to-end with AMP. Swap in real data/tokenizer to scale up.")
print("\nDone. Sections 1-3 are core; 4-6 are the advanced bits worth keeping.")
因果的な xFormers アテンション、残差接続、正規化、そして SwiGLU フィードフォワード層を用いたコンパクトな GPT スタイルの Transformer を構築します。このモデルは、語彙サイズを法とする加算計算を行う合成された次トークン予測タスクに対して、自動混合精度でトレーニングを行います。完全なメモリ効率型 Transformer がエンドツーエンドで正常に学習できることを確認するために、その損失と精度を監視します。
結論として、xFormers が基本的なアテンション計算を変更することなく Transformer の効率性をどのように向上させるかについて、私たちは実践的な理解を深めました。長系列のコストを削減するメモリ効率的なカーネルの仕組みや、因果マスク、パケット化されたシーケンス、グループ化クエリアテンション(Grouped-Query Attention)、加算バイアスが現実的なトレーニングおよび推論ワークフローをどのように支えるかを確認しました。これらの機能をコンパクトな GPT モデルに統合し、エンドツーエンドでトレーニングを行うことで、より大規模な言語モデルや過酷なデータセットに対して xFormers を適用するための堅固な基盤を得ることができました。
ノートブック付きの完全なコードをチェックしてください。また、Twitter でフォローすることも歓迎します。15 万人以上の ML サブレッド(SubReddit)に参加し、ニュースレターを購読することを忘れないでください。待ってください!Telegram を利用していますか?今なら Telegram でも私たちに参加できます。
GitHub リポジトリの宣伝や Hugging Face ページ、製品リリース、ウェビナーなどのプロモーションのためにパートナーシップをご希望ですか?私たちに連絡してください。
メモリ効率の高いトランスフォーマーを、パッキングされたシーケンス、GQA(グループ化クエリアテンション)、ALiBi、SwiGLU、因果的注意機構を用いて xFormers で構築する方法という記事は、MarkTechPost に最初に掲載されました。
原文を表示
In this tutorial, we implement xFormers: a practical toolkit for building fast, memory-efficient Transformer models on GPUs. We begin by validating memory-efficient attention against a standard attention implementation, then compare their speed and memory consumption across different sequence lengths. We then examine causal masking, packed variable-length sequences, grouped-query attention, and custom ALiBi positional biases. Finally, we combine these techniques into a trainable GPT-style model that uses xFormers attention, SwiGLU feed-forward layers, and automatic mixed-precision training.
Setting Up xFormers and Validating Memory-Efficient Attention
Copy CodeCopiedUse a different Browser
import subprocess, sys
def _pip(*a): subprocess.run([sys.executable, "-m", "pip", "install", *a], check=False)
try:
import xformers
except Exception:
_pip("-q", "-U", "xformers")
import math, time
import torch, torch.nn as nn, torch.nn.functional as F
import xformers, xformers.ops as xops
from xformers.ops import fmha
ab = fmha.attn_bias
assert torch.cuda.is_available(), (
"No GPU detected. In Colab: Runtime → Change runtime type → GPU, then re-run.")
device = "cuda"
torch.manual_seed(0)
print("torch :", torch.__version__)
print("xformers :", xformers.__version__)
print("GPU :", torch.cuda.get_device_name(0))
print("\n--- xformers.info (which kernels are built/available) ---")
try:
subprocess.run([sys.executable, "-m", "xformers.info"], check=False)
except Exception as e:
print("xformers.info unavailable:", e)
def cuda_time(fn, iters=20, warmup=5):
for _ in range(warmup): fn()
torch.cuda.synchronize()
s, e = (torch.cuda.Event(enable_timing=True) for _ in range(2))
s.record()
for _ in range(iters): fn()
e.record(); torch.cuda.synchronize()
return s.elapsed_time(e) / iters
def peak_mem_mb(fn):
torch.cuda.empty_cache(); torch.cuda.reset_peak_memory_stats()
fn(); torch.cuda.synchronize()
return torch.cuda.max_memory_allocated() / 1e6
def vanilla_attention(q, k, v, causal=False):
"""Reference attention that MATERIALIZES the [B,H,M,M] score matrix.
Inputs are xformers-layout [B, M, H, K]."""
q, k, v = (t.transpose(1, 2).float() for t in (q, k, v))
scores = (q @ k.transpose(-2, -1)) / math.sqrt(q.shape[-1])
if causal:
M = scores.shape[-1]
m = torch.triu(torch.ones(M, M, device=q.device, dtype=torch.bool), 1)
scores = scores.masked_fill(m, float("-inf"))
out = scores.softmax(-1) @ v
return out.transpose(1, 2)
print("\n" + "="*70 + "\n1. memory_efficient_attention basics + correctness\n" + "="*70)
B, M, H, K = 2, 512, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_xf = xops.memory_efficient_attention(q, k, v)
out_ref = vanilla_attention(q, k, v).half()
print("output shape :", tuple(out_xf.shape), "(layout B, M, H, K)")
print("max abs diff vs ref : {:.2e}".format((out_xf - out_ref).abs().max().item()))
print("-> it's EXACT attention (fp16 rounding only), just computed without")
print(" ever storing the full MxM score matrix.")
We install and import xFormers, verify GPU availability, and inspect the attention kernels supported by the environment. We define helper functions for measuring CUDA execution time and peak memory consumption. We then validate memory-efficient attention against standard attention to confirm that both produce results that closely match each other.
Benchmarking Memory and Speed Against Naive Causal Attention
Copy CodeCopiedUse a different Browser
print("\n" + "="*70 + "\n2. Memory & speed vs naive attention (fwd+bwd)\n" + "="*70)
print(f"{'seqlen':>8} | {'naive MB':>10} | {'xformers MB':>12} | {'naive ms':>9} | {'xf ms':>7}")
print("-"*60)
for M in [512, 1024, 2048, 4096]:
q, k, v = (torch.randn(2, M, 8, 64, device=device, dtype=torch.float16,
requires_grad=True) for _ in range(3))
def run_xf():
o = xops.memory_efficient_attention(q, k, v); o.sum().backward()
def run_naive():
o = vanilla_attention(q, k, v); o.sum().backward()
try:
nm = peak_mem_mb(run_naive); nt = cuda_time(run_naive, 8, 3)
except RuntimeError:
nm, nt = float("nan"), float("nan"); torch.cuda.empty_cache()
xm = peak_mem_mb(run_xf); xt = cuda_time(run_xf, 8, 3)
print(f"{M:>8} | {nm:>10.0f} | {xm:>12.0f} | {nt:>9.2f} | {xt:>7.2f}")
print("-> naive memory grows ~4x per doubling of M (it stores BxHxMxM);")
print(" xformers grows ~linearly and stays fast.")
print("\n" + "="*70 + "\n3. Causal attention via LowerTriangularMask\n" + "="*70)
B, M, H, K = 2, 256, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
out_causal = xops.memory_efficient_attention(q, k, v, attn_bias=ab.LowerTriangularMask())
ref_causal = vanilla_attention(q, k, v, causal=True).half()
print("causal max abs diff : {:.2e}".format((out_causal - ref_causal).abs().max().item()))
print("-> the mask is implicit; no MxM boolean tensor is allocated.")
We benchmark naive attention and xFormers attention across progressively longer sequences using forward and backward passes. We compare their execution times and peak GPU memory usage to observe how xFormers avoids quadratic memory growth. We also apply an implicit lower-triangular mask and verify causal attention against the reference implementation.
Packing Variable-Length Sequences and Running Grouped-Query Attention
Copy CodeCopiedUse a different Browser
print("\n" + "="*70 + "\n4. Variable-length packed batch — no padding waste\n" + "="*70)
seqlens = [37, 120, 8, 200]
total = sum(seqlens)
H, K = 8, 64
q = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
k = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
v = torch.randn(1, total, H, K, device=device, dtype=torch.float16)
try:
bias = ab.BlockDiagonalMask.from_seqlens(seqlens)
out_packed = xops.memory_efficient_attention(q, k, v, attn_bias=bias)
s0 = seqlens[0]
ref0 = vanilla_attention(q[:, :s0], k[:, :s0], v[:, :s0]).half()
print("packed shape :", tuple(out_packed.shape), "(all", total, "tokens, no pad)")
print("segment-0 max diff : {:.2e}".format((out_packed[:, :s0] - ref0).abs().max().item()))
cbias = ab.BlockDiagonalCausalMask.from_seqlens(seqlens)
_ = xops.memory_efficient_attention(q, k, v, attn_bias=cbias)
print("-> also did a packed CAUSAL pass. This is how vLLM-style engines")
print(" batch requests of different lengths with zero padding overhead.")
splits = bias.split(out_packed)
print("recovered segments :", [tuple(t.shape) for t in splits])
except Exception as e:
print("BlockDiagonalMask path skipped on this version/backend:", repr(e))
print("\n" + "="*70 + "\n5. Grouped-query attention (5-D BMGHK layout)\n" + "="*70)
B, M, K = 2, 256, 64
n_q_heads, n_kv_heads = 8, 2
G, Hq = n_kv_heads, n_q_heads // n_kv_heads
try:
qg = torch.randn(B, M, G, Hq, K, device=device, dtype=torch.float16)
kg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16)
vg = torch.randn(B, M, G, 1, K, device=device, dtype=torch.float16)
out_gqa = xops.memory_efficient_attention(qg, kg, vg)
print("GQA output shape :", tuple(out_gqa.shape), "= [B, M, G, Hq, K]")
print(f"-> {n_q_heads} query heads, only {n_kv_heads} KV heads: smaller KV-cache,")
print(" which is exactly what Llama-/Mistral-class models use at inference.")
except Exception as e:
print("GQA 5-D path skipped on this version/backend:", repr(e))
We concatenate variable-length sequences and use BlockDiagonalMask to prevent attention from crossing sequence boundaries without padding. We recover the individual outputs and also perform packed causal attention for decoder-style workloads. We then demonstrate grouped-query attention, where multiple query heads share fewer key-value heads to reduce KV-cache requirements.
Adding a Custom ALiBi Additive Positional Bias
Copy CodeCopiedUse a different Browser
print("\n" + "="*70 + "\n6. Custom ALiBi additive bias\n" + "="*70)
B, M, H, K = 1, 128, 8, 64
q, k, v = (torch.randn(B, M, H, K, device=device, dtype=torch.float16) for _ in range(3))
try:
slopes = (2.0 (-8.0 / H)) torch.arange(1, H + 1, device=device)
pos = torch.arange(M, device=device)
rel = (pos[None, :] - pos[:, None]).clamp(max=0).float()
alibi = slopes[:, None, None] * rel[None]
alibi = alibi[None].expand(B, H, M, M).to(torch.float16).contiguous()
causal = torch.triu(torch.ones(M, M, device=device, dtype=torch.bool), 1)
alibi = alibi.masked_fill(causal[None, None], float("-inf"))
out_alibi = xops.memory_efficient_attention(q, k, v, attn_bias=alibi)
print("ALiBi output shape :", tuple(out_alibi.shape))
print("-> any per-(head,query,key) additive bias works the same way.")
except Exception as e:
print("Custom-bias path skipped (some backends restrict bias shapes):", repr(e))
We construct a custom ALiBi tensor that applies a different linear positional penalty to each attention head. We combine this additive bias with a causal mask so that tokens attend only to valid previous positions. We pass the resulting bias directly to xFormers attention and verify the shape of its output.
Training a GPT Block with xFormers Attention and SwiGLU
Copy CodeCopiedUse a different Browser
print("\n" + "="*70 + "\n7. Train a small GPT block (xformers attn + SwiGLU)\n" + "="*70)
def make_swiglu(d, hidden):
"""Fused xformers SwiGLU if available, else a clean manual fallback."""
try:
m = xops.SwiGLU(in_features=d, hidden_features=hidden, out_features=d, bias=True)
return m, "fused xops.SwiGLU"
except Exception:
class SwiGLU(nn.Module):
def __init__(s):
super().__init__()
s.w12 = nn.Linear(d, 2 * hidden); s.w3 = nn.Linear(hidden, d)
def forward(s, x):
a, b = s.w12(x).chunk(2, -1)
return s.w3(F.silu(a) * b)
return SwiGLU(), "manual SwiGLU fallback"
class Block(nn.Module):
def __init__(self, d, n_heads, mlp_mult=4):
super().__init__()
self.h, self.k = n_heads, d // n_heads
self.n1, self.n2 = nn.LayerNorm(d), nn.LayerNorm(d)
self.qkv, self.proj = nn.Linear(d, 3 * d), nn.Linear(d, d)
self.ff, self.ff_kind = make_swiglu(d, mlp_mult * d)
def forward(self, x):
B, M, d = x.shape
qkv = self.qkv(self.n1(x)).reshape(B, M, 3, self.h, self.k)
q, kk, vv = qkv.unbind(2)
a = xops.memory_efficient_attention(q, kk, vv, attn_bias=ab.LowerTriangularMask())
x = x + self.proj(a.reshape(B, M, d))
return x + self.ff(self.n2(x))
class TinyGPT(nn.Module):
def __init__(self, vocab, d=128, n_layers=3, n_heads=8, maxlen=64):
super().__init__()
self.tok = nn.Embedding(vocab, d); self.pos = nn.Embedding(maxlen, d)
self.blocks = nn.ModuleList(Block(d, n_heads) for _ in range(n_layers))
self.nf, self.head = nn.LayerNorm(d), nn.Linear(d, vocab)
def forward(self, idx):
B, M = idx.shape
x = self.tok(idx) + self.pos(torch.arange(M, device=idx.device))[None]
for b in self.blocks: x = b(x)
return self.head(self.nf(x))
VOCAB, SEQ = 64, 64
def make_batch(B):
start = torch.randint(0, VOCAB, (B, 1), device=device)
return (start + torch.arange(SEQ, device=device)[None]) % VOCAB
model = TinyGPT(VOCAB).to(device)
print("FFN type :", model.blocks[0].ff_kind)
opt = torch.optim.AdamW(model.parameters(), lr=3e-3)
scaler = torch.amp.GradScaler("cuda")
for step in range(400):
seq = make_batch(64); inp, tgt = seq[:, :-1], seq[:, 1:]
with torch.autocast("cuda", dtype=torch.float16):
logits = model(inp)
loss = F.cross_entropy(logits.reshape(-1, VOCAB), tgt.reshape(-1))
opt.zero_grad(); scaler.scale(loss).backward(); scaler.step(opt); scaler.update()
if step % 80 == 0 or step == 399:
acc = (logits.argmax(-1) == tgt).float().mean().item()
print(f"step {step:4d} | loss {loss.item():.4f} | next-token acc {acc*100:5.1f}%")
print("-> a full causal transformer running on memory-efficient attention,")
print(" trained end-to-end with AMP. Swap in real data/tokenizer to scale up.")
print("\nDone. Sections 1-3 are core; 4-6 are the advanced bits worth keeping.")
We build a compact GPT-style Transformer using causal xFormers attention, residual connections, normalization, and SwiGLU feed-forward layers. We train the model with automatic mixed precision on a synthetic next-token prediction task that counts upward modulo the vocabulary size. We monitor its loss and accuracy to confirm that the complete memory-efficient Transformer learns successfully end-to-end.
Conclusion
In conclusion, we developed a practical understanding of how xFormers improves Transformer efficiency without changing the fundamental attention calculation. We saw how memory-efficient kernels reduce the cost of long sequences, while causal masks, packed sequences, grouped-query attention, and additive biases support realistic training and inference workflows. We concluded by integrating these capabilities into a compact GPT model and training it end-to-end, giving us a strong foundation for applying xFormers to larger language models and more demanding datasets.
Check out the Full Codes with Notebook. Also, feel free to follow us on Twitter and don’t forget to join our 150k+ML SubReddit and Subscribe to our Newsletter. Wait! are you on telegram? now you can join us on telegram as well.
Need to partner with us for promoting your GitHub Repo OR Hugging Face Page OR Product Release OR Webinar etc.? Connect with us
The post How to Build Memory-Efficient Transformers with xFormers Using Packed Sequences, GQA, ALiBi, SwiGLU, and Causal Attention appeared first on MarkTechPost.
関連記事
[AINews] GLM は GPT より優れているか?GLM-5.2 が実用性を証明、Z.ai が 12 月までに「Open Fable」を公開予定
Latent Space のニュースでは、中国のモデル「GLM-5.2」がベンチマークで優れた結果を示し実用性があると評価されたことと、Z.ai が 12 月までにオープンソースプロジェクト「Open Fable」を発表する見込みについて報じられています。
Salesforce CodeGen チュートリアル:ユニットテストと安全性チェック付きの Python 関数の生成・検証・再ランク付け
Salesforce は Hugging Face からモデルを読み込み、自然言語から Python 関数を生成するエンドツーエンドワークフローを公開した。この手法には構文チェックや静的解析、ユニットテストによる検証が含まれる。
CloudWatch の SageMaker メトリクスとインサイトダッシュボードを用いた生成 AI 推論の監視・デバッグ
AWS は、大規模な生成 AI 推論エンドポイントの P99 レイテンシ急上昇などのトラブルを GPU メモリ圧力や KV キャッシュ飽和などから特定できるよう、CloudWatch に SageMaker の詳細メトリクスとインサイトダッシュボードを追加した。
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み