ワープデコードによるMoEモデル推論の改善
Cursor Blogは、Blackwell GPU上でのMoEモデル推論において、従来の専門家中心アプローチから出力中心の「warp decode」手法へ転換することで、1.84倍のスループット向上と精度向上を実現した技術革新を報告している。
キーポイント
従来のMoE推論の課題
従来の専門家中心アプローチは、大規模バッチ処理では効率的だが、オート回帰デコード(1トークンずつ生成)ではデータレイアウト管理のための5段階のオーバーヘッドが発生し、非効率だった。
warp decodeの技術革新
並列処理の軸を専門家から出力(ニューロン)に転換し、各ワープが1つの出力値を計算するように再設計した。これにより、中間バッファや同期ポイントが不要になり、計算レイヤーを2つのカーネルに圧縮した。
実証された性能向上
Blackwell GPU上で1.84倍のスループット向上を達成し、同時に精度も向上(FP32リファレンスに1.4倍近い出力)という、性能と精度の両立という稀な成果を実現した。
実用への影響
Composerの研究・トレーニングパイプラインを高速化し、モデル改善サイクルを加速させ、新バージョンの頻繁なリリースを可能にしている。
Warp Decodeの効率化メカニズム
Warp Decodeは、従来のパイプラインで必要だったステージとバッファを削除し、ワープの独立性を確保することで、パフォーマンスを向上させる。
ステージ削除によるオーバーヘッド削減
パディング、スキャッター、結合ステップを排除し、ワープ内でルーティング重みを累積器に折り込むことで、中間結果のメモリ書き込みと読み込みコストを削減する。
Warp Decodeによるメモリ効率化
Warp decodeは中間バッファを排除し、8つの専門家の寄与をレジスタアキュムレータに統合することで、トークンあたり32KB以上のグローバルメモリトラフィックを削減し、L2キャッシュ容量を解放する。
影響分析・編集コメントを表示
影響分析
この技術革新は、MoEモデルの実用化における重要なボトルネックであった推論効率を大幅に改善し、大規模言語モデルの展開コスト削減と性能向上に直接寄与する。特にリアルタイム応用やリソース制約環境でのMoEモデル採用を促進する可能性が高い。
編集コメント
理論的な改善だけでなく、実測値での性能・精度両方の向上を実証した点が説得力があり、MoE実装のベストプラクティスを変える可能性のある重要な技術報告。
ほとんどのMoE推論システムは、トークンの生成パスをエキスパート中心に構成しています。これはルーティングの仕組みを反映したもので、大規模なスケールでは標準的なアプローチとなっています。しかし、Blackwell GPUにおける小バッチデコードでは、エキスパートではなく出力を中心にカーネルを構成する方が、より高い性能を発揮することがわかりました。このアプローチを「ワープデコード」と呼びます。
ワープデコードは、Blackwell上でのMoEデコードにおいて達成可能な最大メモリ帯域幅について考察することで生まれました。それが、並列化の軸を完全に逆転させることにつながりました。ワープをエキスパートに割り当てる代わりに、各ワープを単一の出力値(ニューロン)に割り当てるのです。
性能と精度の両方を向上させるカーネルは稀であり、ワープデコードはその一つです。Blackwellでは、スループットを1.84倍向上させると同時に、出力が完全なFP32リファレンスに1.4倍近づくことで精度も向上します。これにより、Composerの研究とトレーニングパイプラインが高速化され、モデルをより迅速に改善し、新しいバージョンをより頻繁にリリースできるようになります。
#従来のMoEパス
現代のMoEモデルは、各トークンを専門化されたエキスパートネットワークのサブセット(例えば、ある層において128個中8個)にルーティングします。標準的な実装では、各エキスパートが必要とするトークンを集め、計算を実行し、結果を再構成することで、すべての計算をエキスパート中心に構成します。
これは、エキスパートごとの共有処理がデータ配置のオーバーヘッドを相殺するプレフィルや大バッチではうまく機能します。しかし、一度に1トークンしか生成しないオートリグレッシブ(自己回帰)デコードステップでは、それを正当化するだけの十分な共有処理がありません。従来のパスにおける8つのステージのうち、5つはエキスパート中心の視点のためのデータレイアウトを管理するためだけに存在し、実際の計算は行いません。
#我々の変更点
ワープデコードは、エキスパートではなく出力を中心に並列性を再構成することで、これら5つの「簿記」ステップを排除します。
現代のGPUは、ワープと呼ばれる32本の並列処理レーンのグループで命令を実行します。この新しいアプローチでは、各ワープは計算すべき1つの出力値に割り当てられます。ワープは必要な重みデータをメモリから直接ストリーミングし、ルーティングされた8つのエキスパートすべてからの合計を単一の累積合計に集約し、1つの結果を書き込みます。
このワープ独立性により、ワープデコードはステージング、ハンドオフ、クロスワープ同期ポイント、中間バッファを一切必要とせずに実行できます。MoE計算層全体が、moe_gate_up_3d_batched と moe_down_3d_batched の2つのカーネルに圧縮されます。
#2つのカーネルの仕組み
ゲート/アップカーネルでは、各協調スレッド配列(CTA)は8つのワープで構成され、各ワープはトークンとルーティングされたエキスパートの各ペアに対して1つの中間ニューロンを担当します。ワープはルーティングされたエキスパートIDをロードし、そのニューロンに対応するゲートとアップの重み行を読み込み、入力活性化ベクトルをストリーミングします。MXFP8の重みはオンデマンドでFP32に変換され、両方のドット積はプライベートレジスタに累積されます。
2つのカーネルが1パスに融合されているため、活性化ベクトルは一度読み込まれると、共有メモリでのステージングを経ずに、両方の射影ですぐに再利用されます。ワープレベルでのリダクション後、ワープはSiLU(gate) × upを適用して1つの中間値を書き込みます。
ダウンカーネルでは、各ワープは1トークンに対する1つの出力次元を担当します。これはトップkのルーティングされた全エキスパートをループし、関連するダウン射影の重み行をロードして中間活性化をストリーミングしながら、各エキスパートのルーティング重みを単一の累積FP32アキュムレータに折り込みます。
全エキスパートの処理後、__shfl_xor_sync を用いたワープレベルのバタフライリダクションで、32のレーンローカルな部分和をリダクションします。
ここでの利点は、同期がレーンマスクを介して命令に組み込まれているため、L1への往復、バンク競合、明示的なバリアが不要な点です。別個のエピローグではなく、最終的な重み付けトップk結合が射影自体の一部となります。
ワープデコードにおける各ワープは独立しており、そのライフタイムを通じて単一の安定した割り当て、すなわち1つの出力スカラーを生成します。このワープ独立性こそが、従来のパスが必要とした共有メモリステージング、クロスワープ同期、中間バッファを排除するものです。
#パイプラインの簡素化と高速化
ワープデコードは、主に2つのメカニズムを通じて性能向上を達成します。従来のパスが必要としたステージとバッファを削除すること、およびワープ独立性を創出することでより優れたスケジューリングとレイテンシ隠蔽を可能にすることです。
#ステージの削除
ステージ削除がスループット向上の大部分をもたらします。パディング、スキャッタ、結合ステップを排除します。これらのステージを削除するには、従来のパイプラインのステージを単に融合するのではなく、並列性を根本から再構成する必要があります。
#パディングの削除
従来のパス:グループ化カーネルの要件に合わせるため、各エキスパートのトークンリストを2のべき乗、または128バイト境界にパディングします。単一トークンのデコード時、これは償却不可能なオーバーヘッドとなります。
ワープデコードパス:エキスパートごとのバッチを一切形成しないため、このオーバーヘッドを完全に回避します。
#スキャッタと結合の削除
従来のパス:各エキスパートの処理後、8つの中間結果をGPUメモリに書き込み、それらを結合する別個のリダクションステップを実行します。
ワープデコードパス:各エキスパートのルーティング重みは、ワープ内の累積アキュムレータに折り込まれます。8つの中間結果がメモリ上に実体化されることはなく、後続のリダクションパスでの書き込みと読み込みの両方のコストが節約されます。
#バッファの削除
この再構成により、従来のパスがそのエキスパート中心のレイアウトの結果として必要としていた、2つの中間メモリバッファも削除されます。
1つ目は活性化収集バッファで、入力活性化ベクトルをコピーしてエキスパートメジャーレイアウトに再配置したものです。バッチサイズ1では、これは既に存在するデータの完全なコピーです。2つ目はエキスパートごとの出力バッファです。エキスパート数8、隠れ次元2048の場合、これはBF16でトークンあたり8 × 2048 × 2バイト = 32 KBとなり、割り当て、書き込み、即座に1回読み込まれ、破棄されます。
ワープデコードは、8つのエキスパートの寄与を32のワープレーンにわたるレジスタアキュムレータに折り込むことで両方を排除し、最終的な単一スカラー書き込みまでグローバルメモリに何も流出させません。トークンあたり32KB以上の中間バッファトラフィックを削除することで、実際のパフォーマンスを決定する重み行のためのL2キャッシュ容量が解放されます。
#ワープ独立性
この再構成により、残された計算も高速化されます。なぜなら、カーネルは設計上「驚くほど並列」だからです。すべてのワープは互いに完全に独立しています。各ワープが正確に1つの出力スカラーを担当し、必要な重み行のみを読み込むため、ワープ間で共有される可変状態は存在しません。
単一ワープのレベルでは、この独立性は完全です。入力活性化は読み取り専用であり、アキュムレータはプライベートレジスタに存在し、出力書き込みは一意のアドレスに向かいます。ハードウェアスケジューラの観点からは、出力次元全体が独立した作業項目のフラットなプールです。
GPUのワープスケジューラは、正確性の制約なしに、任意の順序で、いつでも任意のワープを発行できます。1つのワープがメモリロード待ちで停止すると、スケジューラは即座に別のワープに切り替えます。B200の148個のストリーミングマルチプロセッサ全体で実行中の数千のワープにより、メモリレイテンシは他のワープによる有用な計算によってほぼ完全に隠蔽されます。
カーネルはまた線形にスケールし、出力次元を2倍にすると、追加の同期なしに独立したワープの数も2倍になります。これはトークンバッチ次元全体でも同様であり、スケジューラはノード間に依存関係のない作業のフラットな名前空間を見ます。これは、エキスパートレベルのGEMMカーネルがブロック内での調整を必要とする従来のパスとは対照的です。
#スケールにおけるエンドツーエンドデコードスループット
NVIDIA B200 GPU上でQwen-3スタイルのモデルを実行する内部推論システムでのテストでは、一貫したスループット向上が確認されました。スループット向上はすべてのコンテキスト長の範囲で均一であり、これはプロンプト長に依存しない、純粋な生成時間の改善であることを示しています。
#精度の向上
中間活性化の量子化ステップを削除することは、測定可能な品質への影響をもたらします。BF16活性化をMXFP8に変換して戻すことは、モデルの層全体に蓄積する丸め誤差の下限を導入します。ワープデコードは活性化をBF16で、アキュムレータをFP32で維持するため、リダクションは劣化した入力で操作されることはありません。その結果、ワープデコードは従来のパスよりも完全な32ビットグラウンドトゥルースに1.4倍近い出力を生成します。
#ハードウェア効率
我々は、ハードウェアの最大スループットにどれだけ近づけるかを問うことからワープデコードの開発を始めました。B200の連続メモリ読み取りの測定ピークは6.8 TB/sです(コピーカーネルを使用して測定)。ワープデコードはバッチサイズ32で3.95 TB/sを維持し、そのピークの58%に達します。残りのギャップは、各トークンが5、8、14、19などの非隣接エキスパートにルーティングされる可能性があるため、エキスパートルーティングが生み出すランダムアクセスパターンによるメモリレイテンシコストを反映している可能性が高いです。
対照的に、ピークスループットは連続的(0,1,2,3)なメモリ読み取りを使用して測定されます。リファレンス実装に対する正確性はすべてのバッチサイズで厳密に保たれました:最小コサイン類似度 > 0.999996、最大絶対差 0.001953。
#ワープデコードとComposerトレーニング
ワープデコードは、エキスパート中心実行の汎用的な代替手段ではありません。プレフィルや大規模バッチ推論のようなより高ボリュームのワークロードでは、多くのトークンが同じエキスパートを共有し、それらを構成するコストが十分な実計算によって相殺されるため、エキスパート中心のパッキングが有利です。
ワープデコードは、MoEデコードでしばしばそうであるように、エキスパートごとにそのオーバーヘッドを正当化するだけの十分な共有処理がない場合に優位性を発揮します。これは、我々がComposerを継続的に改善する方法の重要な部分となっています。事前トレーニングデータと強化学習(RL)への投資がモデル出力の品質を決定する一方で、ワープデコードのような推論への投資は、それらの出力が開発者にどれだけ迅速かつ正確に届くかを決定します。
原文を表示
Most MoE inference systems organize the token generation path around experts. This mirrors how routing works and has been the standard approach at scale. For small-batch decode on Blackwell GPUs, however, we found that organizing the kernel around outputs rather than experts works better. We call this approach “warp decode.”
We arrived at warp decode by thinking about what the maximum achievable memory bandwidth for MoE decode on Blackwell actually is. That led us to flip the parallelism axis entirely. Instead of assigning warps to experts, we assign each warp to a single output value (neuron).
Kernels that improve both performance and accuracy are rare, and warp decode is one of them. On Blackwell, it delivers a 1.84x throughput improvement while also improving accuracy with outputs 1.4x closer to a full FP32 reference. This speeds up the research and training pipeline for Composer, letting us improve the model faster and ship new versions more often.
#The conventional MoE path
Modern MoE models route each token through a subset of specialized expert networks, selecting, for example, 8 out of 128 at a given layer. The standard implementation organizes all computation around those experts by collecting the tokens each expert needs, running the math, and reassembling the results.
This works well for prefill and large batches, where the shared work per expert amortizes the overhead of organizing the data. But during the autoregressive decode step, where we only produce one token at a time, there isn’t enough shared work to justify it. Five of the eight stages in the traditional path exist purely to manage data layout for the expert-centric view and perform no actual computation.
#What we changed
Warp decode eliminates those five “bookkeeping” steps by reorganizing the parallelism around outputs rather than experts.
Modern GPUs execute instructions in groups of 32 parallel processing lanes called a warp. In our new approach, each warp is assigned exactly one output value to compute. The warp streams the weight data it needs directly from memory, aggregates the totals across all eight routed experts in a single running total, and writes one result.
This warp independence lets warp decode run without any staging, handoffs, cross-warp sync points, or intermediate buffers. The entire MoE compute layer is compressed into two kernels, moe_gate_up_3d_batched
moe_down_3d_batched
#How the two kernels work
In the gate/up kernel, each cooperative thread array (CTA) is eight warps, and each warp owns one intermediate neuron for each pairing of a token and a routed expert. The warp loads the routed expert ID, reads the gate and up weight rows for that neuron, and streams over the input activation vector. MXFP8 weights are converted to FP32 on the fly, and both dot products accumulate in private registers.
Because the two kernels are fused into one pass, the activation vector is read once and reused immediately for both projections, without any shared memory staging. After a warp-level reduction, the warp applies SiLU(gate) × up and writes one intermediate value.
In the down kernel, each warp owns one output dimension for one token. It loops over all top-k routed experts, loading the relevant down-projection weight row and streaming over the intermediate activations, while folding each expert's routing weight into a single running FP32 accumulator.
After all experts are processed, we reduce the 32 lane-local partial sums with a warp-level butterfly reduction using __shfl_xor_sync
The payoff here is we don’t need L1 round-trips, bank conflicts, or explicit barriers because synchronization is baked into the instruction via the lane mask. Rather than a separate epilogue, the final weighted top-k combination becomes part of the projection itself.
Each warp in warp decode is independent and gets a single, stable assignment for its entire lifetime: produce one output scalar. This warp independence is what eliminates the shared memory staging, cross-warp synchronization, and intermediate buffers that the traditional path requires.
#Pipeline simplification and acceleration
Warp decode achieves performance improvements through two main mechanisms: removing stages and buffers that the traditional path required, and by creating warp independence which allows for better scheduling and latency hiding.
#Stage Elimination
Stage eliminations provide most of the throughput gain. We eliminate padding, scattering, and the combine step. Removing these stages requires reorganizing the parallelism from the ground up, rather than merely fusing stages of the traditional pipeline.
#Elimination of padding
Traditional path: Pads each expert's token list to power-of-2, or 128 byte, boundaries to conform to grouped kernel requirements. At decode time with a single token this is non-amortizable overhead.
Warp decode path: Avoids this overhead entirely by never forming per-expert batches.
#Elimination of scatter and combine
Traditional path: After each expert finishes, it writes eight intermediate results to GPU memory, then runs a separate reduction step to combine them.
Warp decode path: The routing weight for each expert is folded into the running accumulator within the warp. The eight intermediate results never materialize in memory, saving both the write and read costs of a subsequent reduction pass.
#Buffer elimination
The reorganization also removes two intermediate memory buffers that the traditional path requires as a consequence of its expert-centric layout.
The first is an activation gather buffer, which is the input activation vector copied and rearranged into expert-major layout. At batch size 1, this is a full copy of data that already exists. The second is a per-expert output buffer. With eight experts and hidden dimension 2048, this is 8 × 2048 × 2 bytes = 32 KB per token in BF16, allocated, written, immediately read once, and discarded.
Warp decode eliminates both by folding the eight expert contributions into a register accumulator across 32 warp lanes, where nothing reaches global memory until the final single-scalar write. Removing 32+ KB of intermediate buffer traffic per token frees L2 cache capacity for the weight rows that actually determine performance.
#Warp Independence
The reorganization also makes the retained computation faster because the kernel is “embarrassingly parallel” by design: every warp is completely independent of every other. Because each warp owns exactly one output scalar and reads only the weight rows it needs, there is no shared mutable state between warps.
At the level of a single warp, this independence is total. The input activations are read-only, the accumulator lives in private registers, and the output write lands at a unique address. From the hardware scheduler's perspective, the entire output dimension is a flat pool of independent work items.
The GPU's warp scheduler can issue any warp at any time, in any order, with no correctness constraint. When one warp stalls waiting on a memory load, the scheduler immediately switches to another. With thousands of warps in flight across a B200's 148 streaming multiprocessors, memory latency becomes almost entirely hidden behind useful computation from other warps.
The kernel also scales linearly, such that doubling the output dimension doubles the number of independent warps with no added synchronization. The same holds across the token batch dimension, so the scheduler sees one flat namespace of work with no edges between nodes. This stands in contrast to the traditional path, where expert-level GEMM kernels require intra-block coordination.
#End-to-end decode throughput at scale
Testing on our internal inference system running a Qwen-3 style model on NVIDIA B200 GPUs produced a consistent throughput gain. The throughput gain is flat across all context-length buckets, confirming this is a pure generation-time improvement that does not depend on prompt length.
#Improved accuracy
Removing the intermediate activation quantization step has a measurable quality impact. Converting BF16 activations to MXFP8 and back introduces a rounding error floor that accumulates across the model's layers. Warp decode keeps activations in BF16 throughout and accumulators in FP32, so the reduction never operates on degraded inputs. The result is warp decode produces outputs 1.4x closer to full 32-bit ground truth than the classical path.
#Hardware Efficiency
We started developing warp decode by asking how close we could get to the hardware's maximum throughput. The B200's measured peak for contiguous memory reads is 6.8 TB/s (measured using a copy kernel). Warp decode sustains 3.95 TB/s at B=32, or 58% of that peak. The remaining gap likely reflects the memory latency cost of the random access patterns that expert routing creates, since each token may route to non-adjacent experts like 5, 8, 14, 19 etc.
By contrast, peak throughput is measured using contiguous (0,1,2,3) memory reads. Correctness against the reference implementation was tight across all batch sizes: minimum cosine similarity > 0.999996, maximum absolute difference 0.001953.
#Warp decode and Composer training
Warp decode is not a general replacement for expert-centric execution. Higher-volume workloads like prefill and large-batch inference still benefit from expert-centric packing because many tokens share the same expert, and the cost of organizing them is amortized over enough real computation to be worthwhile.
Warp decode wins when there isn't enough shared work per expert to justify that overhead, as is often the case with MoE decode. This makes it an important part of how we continually improve Composer. While investments in pretraining data and RL determine the quality of the model's outputs, inference investments like warp decode determine how quickly and accurately those outputs reach developers.
関連記事
[AINews] 今日は何も大きな出来事はありませんでした
Anthropic が RSI の兆候を示し、OpenAI の ChatGPT が月間アクティブユーザー数で 10 億人を突破。SpaceX AI は IPO について説明しているが、最も重要なのは AIE WF のチケット確保とイベント参加である。
マイクロソフト、新しい MAI モデルを発表
マイクロソフトは今朝、推論に特化した「MAI-Thinking-1」と GitHub コード生成向けに設計された「MAI-Code-1-Flash」の 2 つの新しいテキスト大規模言語モデルを発表した。
MiniMax、新スパースアテンション機構と15.6倍の長文コンテキスト応答速度向上を備えた次期M3モデルを発表
中国のAI企業MiniMaxは、人気シリーズ「M2」の開発に関する技術報告書を公開し、次期モデル「M3」で採用する新スパースアテンション機構を紹介した。この技術により長文コンテキストでのデコード速度が最大15.6倍向上し、超長文コンテキスト対応AIエージェントの経済的実現が可能になる見込みである。