エピローグ融合による効率的なカーネルの作成
fal.ai のブログ記事は、CUTLASS を用いた Epilogue Fusion 技術の解説を通じて、現代の ML ワークロードにおけるメモリアクセスボトルネックを解消し、Hopper や Blackwell アーキテクチャでの計算効率を最大化する実装手法を詳述している。
キーポイント
Epilogue Fusion の基本原理と効果
GEMM 演算後のバイアス、活性化関数(ReLU, Sigmoid など)、型変換などの操作を、結果をグローバルメモリに書き込む前にレジスタ内で実行することで、不要なメモリの読み書きを排除し、帯域幅ボトルネックを解消する。
CUTLASS と EVT(Epilogue Visitor)の活用
CUTLASS の EVT コールバック機能を利用することで、ユーザーは GEMM エピローグに独自のロジック(例:Gated-SiLU/SwiGLU など)をフックし、要素ごとの独立した変換を効率的に実装できる。
Warp-Specialized 設計とアーキテクチャ最適化
Hopper や Blackwell アーキテクチャでは、ワープをプロデューサー(ロード)とコンシューマー(ストア・変換)に分割し、Tensor Core の計算と他の命令のオーバーラップを最大化することで、さらなる性能向上を図る設計が推奨される。
実装の現実性と将来展望
記事は完全な最適化版ではなく概念実証として BF16 GEMM を例に挙げるが、本番環境では NVFP4 ブロックスケーリングやより堅牢なビジター、TMA ストアパスの活用が必要であると指摘している。
EVT を用いた効率的な Epilogue の構築
Sm90EVT は Sm90AccFetch と計算ノード(Identity, Multiply, Add など)を連結し、アキュムレータの型変換やスケーリング・ReLU などの複雑な処理を最小限のオーバーヘッドで実行するビジターツリーとして機能します。
Collective Epilogue と Mainloop の連携
CollectiveBuilder はアーキテクチャやスケジュールに応じて最適化された Epilogue と Mainloop を生成し、Warp-Specialized 環境では各コンシューマー Warp がサブタイルを反復処理して TMA ストアを実行します。
GemmUniversal による統合
CollectiveMainloop と CollectiveEpilogue は GemmUniversal ラッパーによって結合され、入力データのロード(TMA/LDG)、パイプラインステージ数、MMA 発行などの詳細を自動調整しながら GEMM カーネルとして動作します。
影響分析・編集コメントを表示
影響分析
この技術解説は、大規模言語モデルや生成 AI の推論・学習において頻出する GEMM ベースの計算ボトルネックを解決するための具体的な手法を提供しており、CUDA カーネル開発者にとって即座に適用可能な知見となる。特に Blackwell アーキテクチャへの対応を示唆している点は、次世代ハードウェアでのパフォーマンス最大化を目指すエンジニアリングチームにとって極めて重要な指針である。
編集コメント
CUTLASS の高度な機能である Epilogue Fusion を、具体的な実装フロー(EVT)を通じて解説しており、単なる理論ではなくコードレベルでの適用可能性が高い記事です。
多くの機械学習ワークロードでは、GEMM の後にバイアス、活性化関数、スケーリング、または型変換などの小規模な演算が続きます。これらの演算は数学的な計算コストは低いものの、グローバルメモリへの追加のトラフィック(GEMM 結果の保存、再読み取り、再度書き込み)を必要とすることが多くあります。
エピローグ融合(Epilogue fusion)はこの問題を回避する手法であり、GEMM の結果がレジスタ内に残っている間に、最終的なグローバルメモリへのストア直前にこれらの追加演算を適用することができます。Hopper および Blackwell アーキテクチャでは、Tensor Core の作業と他の命令のオーバーラップにさらに余裕があるため、エピローグで追加の計算を行うことがさらに魅力的になる場合があります。
image
エピローグ融合は、活性化関数を GEMM のエピローグ内で直接適用することで、中間的なグローバルメモリの読み書きを排除します。
CUTLASS は、バイアス加算や活性化関数(ReLU、Sigmoid など)、型変換などの演算を、結果をグローバルメモリに書き込む前にアキュムレータフラグメントに対して直接適用するエピローグ融合(epilogue fusion)を可能にすることで、この利点を活用しています。これらの演算を融合させることで、現代のワークロードにおいて支配的なコストとなることが多い追加のメモリアクセス(読み取り・書き込み)を回避できます。このような演算は要素ごとの操作であるため、比較的容易に融合可能です。行や列をまたぐ縮約(reduction)や通信を必要とせず、各出力要素に対して独立して適用することができます。
本ブログ記事では、基本概念の解説を行い、事前に用意された演算を示した上で、GEMM ゲート付き SiLU(別名 SwiGLU とも呼ばれる)のカスタムビジターを作成するアイデアについて説明します。目標は完全な網羅性を追求することではなく、ビジターのフローがどのように動作し、ロジックをエピローグにどのように「フック」するかを示すことにあります。
ここでは完全に最適化されたバージョンを構築するわけでも、カスタム TMA ストアパスやパイプラインの詳細については示しません(これらはすぐに複雑化するためです)。シンプルさを保つため、これらの例では BF16 GEMM を使用します。本番レベルの Blackwell 実装においては、より堅牢な NVFP4 ブロックスケーリング GEMM、より頑健なビジター、および必要に応じて追加の融合を使用すべきです。
ワープ特化型エピローグにおける重要な詳細は、エピローグ作業が生産者(ロード)と消費者(ストア)の役割に分割される点です(ワープ/ワープグループ)。EVT コールバックは、アキュムレータ断片が変換され、最終的に保存される消費者側で実行されます。
CUTLASS GEMM + EVT の基礎
まずは CUTLASS における最小限の GEMM 定義から始めましょう。
CUTLASS で GEMM カーネルを定義する際、通常は以下の3つの主要コンポーネントを指定する必要があります:
メインループ。これは入力タイルのロード方法、タイルおよびクラスタ形状、オペランドレイアウト、Tensor Core MMA 演算(Matrix Multiply-Accumulate)の発行方法を定義します。
エピローグ。アキュムレータ断片の変換方法とグローバルメモリへの書き戻し方法を記述します。
カーネルラッパー。メインループとエピローグを結合して、完全な GEMM カーネルを作成します。
以下は、アイデンティティ(恒等)エピローグを使用する単純な例です。
// EVT: acc -> cast
using EVT = Sm90EVT,
Sm90AccFetch>;
// Collective epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder::CollectiveOp;
// Collective mainloop
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter;
各部分が何を行っているかを示します:
Sm90AccFetch: メインループから生のアキュムレータフラグメント(acc)を提供するリーフノードです。
Sm90Compute: 単純な「計算」ノードです。Identity を使用すると、これは選択された丸めモードを用いて、アキュムレータ要素型(ElementAcc)を出力要素型(ElementOutput)に変換する効果を持ちます。
Sm90EVT: これらのノードを組み合わせて、acc -> cast という小さなビジターツリーを作成します。
次に、この EVT をエピローグの「コレッティブ」に接続します:
CollectiveEpilogue: 選択されたアーキテクチャ/スケジューリングに対するエピローグ実装です。SM90 ワープ特別化カーネルの場合、これは通常、各コンシューマーワープ/グループがエピローグサブタイルを反復処理し、EVT コールバック(visit, reduce など)を呼び出し、その後実際のストアパス(多くの場合、レジスタ → シェアードメモリ → TMA を用いたグローバルメモリへのストア)を実行することを意味します。
そして最後にメインループの「コレッティブ」に接続します:
CollectiveMainloop: A/B タイルのロード方法(TMA vs LDG など)、使用されるパイプラインステージ数、および TileShape/ClusterShape に対する MMA の発行方法を定義します。
GemmUniversal ラッパーはメインループとエピローグを結びつけます。
EVT チェーンの構築(スケール + バイアス + ReLU)
以下のように、GEMM のアキュムレータに対してスケールを適用し、バイアスを加算し、最後に ReLU を適用するやや複雑なエピローグを構築できます。
// EVT ノード演算
using NodeMultiply = Sm90Compute;
using NodeAdd = Sm90Compute;
// EVT: (global_scale * acc) + 行ごとのバイアス -> ReLU -> BF16 へキャスト
using EVT0 = Sm90EVT, Sm90AccFetch>;
using EVT1 = Sm90EVT>, EVT0>;
using EVT2 = Sm90EVT,
EVT1>;
// Collective epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder::CollectiveOp;
このエピローグを以下のように視覚化できます:
image
詳細については、こちらのブログが非常に参考になります: https://research.colfax-intl.com/epilogue_visitor_tree
いくつかの事前構築されたエピローグ演算子 (op) があります。
CUTLASS にはすでにいくつかの定義済みのエピローグスタイルが含まれており、例えば以下のようなものがあります:
線形結合: LinearCombination (D = alpha*AB + beta*C)、ScaledAcc` (D = alpha*acc)
活性化関数: LinCombEltAct with ReLU / GELU / SiLU / Sigmoid / Tanh / HardSwish / LeakyReLU / Clamp
バイアス: LinCombPerRowBias、LinCombPerColBias (活性化関数と併用可能)
ブロードキャスト: Sm90ScalarBroadcast、Sm90RowBroadcast、Sm90ColBroadcast
リダクション: Sm90ScalarReduction、Sm90RowReduction、Sm90ColReduction
ゲート付き SiLU の融合
ゲート付き SiLU パターン (Flux、Flux2、および LLaMA などの大規模言語モデルで使用) は以下のようになります:
C = A @ B
M, N = C.shape
output = SiLU(C[:, :N//2]) * C[:, N//2:]
これにより出力次元が 2 倍に削減されます (2 つの半分を取り出して乗算します)。CUTLASS にはこの特定の「ペア + N/2 リダクション」パターンに対応する組み込みのエピローグがないため、これを単一のパスで融合したい場合はカスタムエピローグビジターが必要です。
CUTLASS のカスタムエピローグビジターを使用すると、スレッドごとのフラグメントをインターセプトして変換し、その保存方法や場所を制御できます。引数/パラメータを持つビジター構造体を定義し、get_consumer_store_callbacks() を実装して、ゲート付き SiLU 演算をペアに対して適用し、縮小された出力を書き込む visit() または postreduce()/end_loop() などのコールバックを実装します。その後、このビジターを EVT ツリー(例:Sm90EVT)に挿入することで、GEMM は追加のカーネルなしで1回のパスで融合された出力を生成します。
カスタムエピローグビジターを実装する前に、重要な課題に直面します。ゲート付き SiLU 演算では、silu(gate) を対応する up 要素と乗算する必要がありますが、単純な出力レイアウトではこれらの値は N 次元の異なる半分に存在します。具体的には、出力列 n が [0, N_out) の範囲にある場合、ペアは gate = C[:, n] および up = C[:, n + N_out) に位置します。これは、対応する値が N_out 列分離れており、容易に異なるフラグメント/サブタイル(あるいは異なる CTAs)に属してしまうことを意味し、その場合、クロスタイル通信が必要になります。
タイル間で同期を試みると、パフォーマンスに著しい悪影響を与え、GEMM エピローグの並列実行モデルを破綻させてしまいます。代わりに、重みパッキング時に B の列(モデル重み行列)を一度だけ並べ替えることで、GEMM 出力 C = A @ B のレイアウトを変更します。これにより、[gate(0..N_out-1), up(0..N_out-1)] を2つの分離した半分として生成するのではなく、インターリーブされたペア [gate0, up0, gate1, up1, ...] を生成します。具体的には、出力列 n に対するペアは、C 内で gate -> 2n および up -> 2n + 1 の位置に隣接して配置されます。この隣接性が保証されることで、ビジター(visitor)はゲート付き SiLU(gated-SiLU)乗算を融合し、タイル間の調整を行わずにエピローグで縮小された N_out 結果を書き込むことができます。
このレイアウト変更の後、同じスレッド内の隣接要素を使用してゲート付き SiLU を適用するビジターを容易に記述できます。CUTLASS には、エピローグの異なる段階でデータをインターセプト(傍受)して変換するためにオーバーライドできるいくつかのフックが用意されています:
begin() -> ストアループ開始前に1回呼び出されます
begin_loop(epi_m, epi_n) -> 各サブタイルの開始時に呼び出されます
previsit(...) -> ビジット(訪問)の前に呼び出され、共有メモリでのブロードキャストに使用されます
visit(...) -> フラグメントごとに呼び出され、計算された値を受け取ります
reduce(...) -> フラグメント間での縮小ステップです
postreduce(...) -> 縮小後、メモリーフェンス前に呼び出されます
tma_store(...) -> 補助テンソルに対して TMA ストア(TMA stores)を発行します
end_loop(epi_m, epi_n) -> 各サブタイルの終了時に呼び出されます
end() -> store ループ完了後に一度呼び出されます
基本的な gated-SiLU ビジターでは、visit() と end_loop() の 2 つのフックだけで十分です。「完全に統合された」TMA エピローグでは、通常、最終結果を reduce()/postreduce() を通して流し込み、集合演算がレジスタ→smem→TMA ストアパスを処理させます。このポストをシンプルに保つため、ここでは end_loop() で集約された出力を直接グローバルメモリへ書き込むことにします。
データフローの理解
エピローグは階層的な方法でデータを処理します。最上位レベルでは TileShape によって定義されるタイルがあります。各タイルはさらに subtiles(epi_m, epi_n でインデックス付け)に分割され、各 subtile は複数のフラグメントとして処理されます(epi_v でインデックス付け)。
Tile (M × N)
└── Subtile (epi_m, epi_n)
└── Fragment 0 (epi_v = 0)
└── Fragment 1 (epi_v = 1)
└── ...
重要な洞察は、visit() が各フラグメントごとに 1 回ずつ、subtile あたり複数回呼び出される一方で、end_loop() は subtile 内のすべてのフラグメントが訪問された後に一度だけ呼び出されることです。
visit() フック:フラグメントの集約
visit() の実装を見てみましょう。
template
CUTLASS_DEVICE auto visit(Array const&, int epi_v, int, int,
Array const& frg_input) {
Tensor tC_rOut_frg = recast>(coalesce(tC_rOut));
tC_rOut_frg(epi_v) = frg_input;
return frg_input;
}
関数のシグネチャから、受け取るデータがわかります:
Array const& -> 生のアキュムレータ値(これは無視します)
epi_v -> 現在の subtile 内のフラグメントインデックス
Array const& frg_input -> 前の EVT ノードからの入力(すでに変換済み)
当方のビジターは EVT チェーンの末尾に位置しています。データが私たちに到達する時点では、すでにアキュムレータからフェッチされ、作業精度に変換されています。frg_input には、このスレッドが責任を負う FragmentSize 個の要素が含まれています。
ここで重要な操作は、フラグメントを tC_rOut に格納することです:
tC_rOut_frg(epi_v) = frg_input;
レジスタテンソルをフラグメントの配列として参照するために recast を使用し、epi_v でインデックス指定して各フラグメントを正しい位置に格納します。これにより、複数の visit() 呼び出し全体でフラグメントが累積され、レジスタ内に完全なサブタイルデータが構築されます。
end_loop() フック:ゲート付き SiLU の計算。
サブタイルのすべてのフラグメントが収集されると、end_loop() が呼び出されます:
CUTLASS_DEVICE void end_loop(int epi_m, int epi_n) {
if constexpr (EnableNullptr) {
if (params_ptr->ptr_out == nullptr) {
return;
}
}
auto [M, N_full, L] = problem_shape_mnl_full;
int N_out = N_full / 2;
// N 次元を半分にした出力テンソルビューの作成
Tensor gOut = make_tensor(
make_gmem_ptr(params_ptr->ptr_out),
make_shape(M, N_out, L),
params_ptr->dOut
);
// 簡単な反復処理のためにテンソルをフラット化
Tensor tC_rOut_flat = coalesce(tC_rOut);
Tensor tC_cOut_full_flat = coalesce(tC_cOut_full(_,_,_,epi_m,epi_n));
using ConvertOutput = NumericConverter;
ConvertOutput convert_output{};
int total_elements = size(tC_cOut_full_flat);
// Process pairs: (gate, up) -> silu(gate) * up
CUTLASS_PRAGMA_UNROLL
for (int flat_idx = 0; flat_idx >= 1; // Divide column by 2
gOut(m, n, l) = convert_output(out);
}
}
Let's break down what happens:
出力テンソルのセットアップ: N_out = N_full / 2 の形状 (M, N_out, L) を持つ gOut を作成します。ペアを融合しているため、出力の列数は半分になります。
座標の追跡: tC_cOut_full は各要素の元の (m, n, l) 座標を保持しています。これを (epi_m, epi_n) でスライスして、現在のサブタイルの座標を取得します。
ペア処理: 要素を 2 つずつ反復処理します。行列 B の列再配置により、フラット化されたテンソル内の隣接する要素は必ず (gate, up) ペアであることが保証されます。
SiLU 計算: 各ペアに対して silu(gate) * up を計算します。このコードは明確さを保つために単純な expf 形式を示していますが、CUTLASS の Sigmoid/SiLu はビルド/設定に応じて高速な tanh ベースの近似(例:sigmoid(x) ≈ 0.5fast_tanh(0.5x) + 0.5**)を使用できます。
座標マッピング: 元の列インデックス n は広大な行列に対応します。2 つの入力列が 1 つの出力列にマップされるため、右シフト (n >>= 1) して出力列インデックスを取得します。
コールバックの設定。
get_consumer_store_callbacks() 関数は、必要なテンソルを使用してコールバックオブジェクトを初期化します:
template
CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto problem_shape_mnl_full = make_shape(M, N, L); // "wide" output shape (before N/2 reduction)
// Allocate a per-thread register tile to accumulate fragments for the current subtile.
// The exact shape comes from CUTLASS's epilogue partitioning (omitted here).
Tensor tC_rOut = make_tensor(/* same (CPY,CPY_M,CPY_N) shape as the thread's output tile */);
// Also build a matching coordinate tensor for the wide output so we can map each register element
// back to (m, n, l) and then apply n >>= 1 when storing the reduced output.
Tensor coordOut_full = make_identity_tensor(make_shape(M, N, L));
Tensor tC_cOut_full = sm90_partition_for_epilogue(coordOut_full, /* ... */);
return ConsumerStoreCallbacks(
cute::move(tC_rOut),
cute::move(tC_cOut_full),
problem_shape_mnl_full,
params_ptr
);
}
Key points:
tC_rOut is a register tensor that accumulates fragments during visit() calls
tC_cOut_full maps flat indices to (m, n, l) coordinates in the original wide matrix
sm90_partition_for_epilogue handles the complex tiling and thread mapping for us
Putting it all together.
The complete data flow looks like this:
┌─────────────────────────────────────────────────────────────┐
│ GEMM Mainloop: Compute A @ B_reordered │
│ Output shape: (M, N_full) where N_full = 2 * N_out │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ visit() called multiple times per subtile │
│ Each call: store frg_input into tC_rOut[epi_v] │
│ │
│ tC_rOut: [frag0][frag1][frag2]... │
│ ↑ ↑ ↑ │
│ epi_v=0 =1 =2 │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ end_loop() called once per subtile │
│ │
│ Process pairs from this thread's tC_rOut: │
│ [gate₀, up₀, gate₁, up₁, ...] │
│ ↓ ↓ │
│ silu(gate₀) * up₀ → output[m, n>>1, l] │
│ │
│ Output shape: (M, N_out) │
└─────────────────────────────────────────────────────────────┘
完全統合された実装の姿
実際の生産環境向けの実装では、通常以下を行います:
TMA エピローグパイプラインとの統合:end_loop() で直接グローバルメモリへ保存するのではなく、reduce()/postreduce() を通じて計算・保存を行い、tma_store() 経由で TMA ストアを発行します。これにより、標準的なレジスタ→smem→TMA のパスを維持できます。
Blackwell では BF16 ではなく FP8/NVFP4 メインループを使用し、より高い Tensor Core スループットを活用します。
「次に来るもの」(例えば出力量子化や補助テンソルなど)を融合させ、メモリアクセスを最小限に抑えます。gated‑SiLU はすでに論理 N 次元を半分にする効果がありますが、もし出力も量子化(例:NVFP4、4 ビット)すれば、非融合の広幅 BF16 中間結果(2N_out* カラム、要素あたり 2 バイト)を書き込む場合と比較して、書き込みフットプリントは劇的に減少します。融合結果を N_out カラム、要素あたり 4 ビットで保存することで、出力バイト数は最大約 8 倍削減され、さらに追加の読み書きペアも回避できます。
いくつかの実生産環境での数値
これらは実生産モデルからの数値です(呼び出しあたりの時間、マイクロ秒):

これは、GEMM 後の処理をエピローグに折り込むことで、166 マイクロ秒の節約(約 1.28 倍の高速化)を実現した結果です。
エピローグ融合の素晴らしい点は、何らかの近似を行っていないことです。単に同じ計算をより早く(グローバルメモリへの往復の前に行う)実行しているだけなので、速度向上のために品質を犠牲にする必要はありません。
原文を表示
imageIn many ML workloads, a GEMM is followed by small operations like bias, activation, scaling, or type conversion. These ops are cheap in math, but they often cost extra global memory traffic (store GEMM result, read it back, write again).
Epilogue fusion is a way to avoid this, we can apply these extra ops while the GEMM result is still in registers, right before the final store to global memory. On Hopper and Blackwell, there is also more room to overlap Tensor Core work with other instructions, so doing some extra compute in the epilogue can be even more attractive.
image
Epilogue fusion eliminates intermediate global memory reads and writes by applying
the activation function directly within the GEMM epilogue.
CUTLASS takes advantage of this by allowing epilogue fusion, where operations such as bias addition, activation functions (ReLU, Sigmoid, etc.), and type conversion are applied directly to the accumulator fragments before writing the result to global memory. By fusing these operations, we avoid additional memory reads and writes, which are often the dominant cost in modern workloads. These kinds of operations are relatively easy to fuse because they are elementwise. They do not require reductions or communication across rows or columns, and can be applied independently to each output element.
In this blog post we will go over the basics, show some prebuilt ops, and then show the idea of writing a custom visitor for GEMM gated‑SiLU (aka SwiGLU). The goal is not to be super complete, just to show how the visitor flow works and how you "hook" your logic into epilogue.
We are not building a fully optimized version, nor are we going to show custom TMA store paths or pipelining details (those get complicated fast). To keep things simple, we will use BF16 GEMM for these examples. For a production level Blackwell implementation, you should use NVFP4 block-scaled GEMM, a more robust visitor, and potentially additional fusions.
One important detail for warp-specialized epilogues, the epilogue work is split into producer-load and consumer-store roles (warps/warp-groups). The EVT callbacks run on the consumer side, where the accumulator fragments are transformed and then ultimately stored.
CUTLASS GEMM + EVT basics
Let's start with a minimal GEMM definition in CUTLASS.
In CUTLASS, defining a GEMM kernel typically involves specifying three main components:
The mainloop, which defines how input tiles are loaded, the tile and cluster shapes, operand layouts, and how Tensor Core MMA operations are issued.
The epilogue, which describes how accumulator fragments are transformed and written back to global memory.
The kernel wrapper, which combines the mainloop and epilogue into a complete GEMM kernel.
Below is a simple example that uses an identity epilogue.
// EVT: acc -> cast
using EVT = Sm90EVT<
Sm90Compute<cutlass::epilogue::thread::Identity, ElementOutput, ElementAcc, RoundStyle>,
Sm90AccFetch>;
// Collective epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OpClass, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAcc, ElementAcc,
ElementOutput, cutlass::layout::RowMajor, AlignOutput,
ElementOutput, cutlass::layout::RowMajor, AlignOutput,
cutlass::epilogue::collective::EpilogueScheduleAuto, EVT>::CollectiveOp;
// Collective mainloop
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OpClass, ElementInput, cutlass::layout::RowMajor, AlignInput,
ElementInput, cutlass::layout::ColumnMajor, AlignInput, ElementAcc, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int,int,int,int>, CollectiveMainloop, CollectiveEpilogue, void>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
Here is what each piece is doing:
Sm90AccFetch: leaf node that provides the raw accumulator fragment (acc) from the mainloop.
Sm90Compute<Identity, ElementOutput, ElementAcc, ...>: a trivial "compute" node. With Identity, this effectively just converts the accumulator element type (ElementAcc) into the output element type (ElementOutput) using the chosen rounding mode.
Sm90EVT<...>: composes those nodes into a tiny visitor tree: acc -> cast.
Then you plug that EVT into the epilogue "collective":
CollectiveEpilogue: the epilogue implementation for the chosen architecture/schedule. On SM90 warp-specialized kernels this typically means, each consumer warp/group iterates over epilogue subtiles, calls the EVT callbacks (visit, reduce, etc.), and then performs the actual store path (often register -> shared memory -> TMA store to global memory).
And finally into the mainloop “collective”:
CollectiveMainloop: defines how A/B tiles are loaded (e.g. TMA vs LDG), how many pipeline stages are used, and how MMA is issued for your TileShape/ClusterShape.
The GemmUniversal wrapper ties the mainloop and epilogue together.
Building an EVT chain (scale + bias + ReLU).
We can build slightly complex epilogue that applies scaling to accumulator of gemm applies addition with bias and applies ReLU at the end as below
// EVT node ops
using NodeMultiply = Sm90Compute<cutlass::multiplies, ElementAcc, ElementAcc, RoundStyle>;
using NodeAdd = Sm90Compute<cutlass::plus, ElementAcc, ElementAcc, RoundStyle>;
// EVT: (global_scale * acc) + per-row bias -> ReLU -> cast to BF16
using EVT0 = Sm90EVT<NodeMultiply, Sm90ScalarBroadcast<ElementScale>, Sm90AccFetch>;
using EVT1 = Sm90EVT<NodeAdd, Sm90ColBroadcast<0, TileShape, ElementBias, ElementBias, Stride<_1, _0, _0>>, EVT0>;
using EVT2 = Sm90EVT<
Sm90Compute<cutlass::epilogue::thread::ReLU, ElementOutput, ElementAcc, RoundStyle>,
EVT1>;
// Collective epilogue
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OpClass, TileShape, ClusterShape,
cutlass::epilogue::collective::EpilogueTileAuto, ElementAcc, ElementAcc,
ElementOutput, cutlass::layout::RowMajor, AlignOutput,
ElementOutput, cutlass::layout::RowMajor, AlignOutput,
cutlass::epilogue::collective::EpilogueScheduleAuto, EVT2>::CollectiveOp;
We can visualize this epilogue like this:
image
If you want more details, this blog is really nice: https://research.colfax-intl.com/epilogue_visitor_tree
A few prebuilt epilogue ops.
CUTLASS already has some predefined epilogue styles, for example:
linear combinations: LinearCombination (D = alpha*AB + beta*C), ScaledAcc` (D = alpha*acc)
activations: LinCombEltAct with ReLU / GELU / SiLU / Sigmoid / Tanh / HardSwish / LeakyReLU / Clamp
bias: LinCombPerRowBias, LinCombPerColBias (can be with activations)
broadcasts: Sm90ScalarBroadcast, Sm90RowBroadcast, Sm90ColBroadcast
reductions: Sm90ScalarReduction, Sm90RowReduction, Sm90ColReduction
Fusing gated‑SiLU
The gated‑SiLU pattern (used in Flux, Flux2, and LLMs like LLaMA) looks like this:
C = A @ B
M, N = C.shape
output = SiLU(C[:, :N//2]) * C[:, N//2:]
This reduces the output dimension by 2x(we take two halves and multiply them). CUTLASS doesn't have a built-in epilogue for this exact "pair + reduce N/2" pattern, so we need a custom epilogue visitor if we want to fuse it in a single pass.
Custom epilogue visitors in CUTLASS let you intercept the per‑thread fragment, transform it, and control how/where its stored. You define a visitor struct with Arguments/Params, then implement get_consumer_store_callbacks() and a callback like visit() or postreduce()/end_loop() to apply the gated‑SiLU on pairs and write the reduced output. The visitor is then inserted into the EVT tree (e.g., Sm90EVT<CustomVisitor, Sm90AccFetch>), so the GEMM produces the fused output in one pass without an extra kernel.
Before implementing the custom epilogue visitor, we face a key challenge. The gated‑SiLU operation requires multiplying silu(gate) with its corresponding up element, but in the naive output layout those values live in different halves of the N dimension, for output column n in [0, N_out), the pair is at gate = C[:, n] and up = C[:, n + N_out]. This means the paired values are separated by N_out columns, and can easily land in different fragments / subtiles (or even different CTAs), which would require cross-tile communication.
image
Trying to synchronize across tiles would significantly hurt performance and break the GEMM epilogue’s parallel execution model. Instead, we permute the columns of B (the model weight matrix) once during weight packing so the GEMM output C = A @ B is laid out differently, instead of producing [gate(0..N_out-1), up(0..N_out-1)] as two separated halves, it produces interleaved pairs [gate0, up0, gate1, up1, ...]. Concretely, the pair for output column n ends up adjacent in C at gate -> 2n and up -> 2n + 1. With adjacency guaranteed, the visitor can fuse the gated‑SiLU multiply and write the reduced N_out result in the epilogue without cross-tile coordination.
image
After this layout tweak, we can easily write a visitor that applies gated‑SiLU using adjacent elements in the same thread. CUTLASS provides several hooks you can override to intercept and transform data at different stages of the epilogue:
begin() -> called once before the store loop starts
begin_loop(epi_m, epi_n) -> called at the start of each subtile
previsit(...) -> called before visit, used for shared memory broadcasts
visit(...) -> called per-fragment, where you receive computed values
reduce(...) -> reduction step across fragments
postreduce(...) -> called after reduction, before memory fence
tma_store(...) -> issue TMA stores for auxiliary tensors
end_loop(epi_m, epi_n) -> called at the end of each subtile
end() -> called once after the store loop completes
For our basic gated‑SiLU visitor, we only need two hooks, visit() and end_loop(). In a "fully integrated" TMA epilogue you will typically flow the final result through reduce()/postreduce() and let the collective handle the register -> smem -> TMA store path. To keep this post simple, we will directly write the reduced output to global memory in end_loop().
Understanding the Data Flow
The epilogue processes data in a hierarchical manner. At the top level, we have tiles defined by TileShape. Each tile is further divided into subtiles (indexed by epi_m, epi_n), and each subtile is processed as multiple fragments (indexed by epi_v).
Tile (M × N)
└── Subtile (epi_m, epi_n)
└── Fragment 0 (epi_v = 0)
└── Fragment 1 (epi_v = 1)
└── ...
The key insight is that visit() is called multiple times per subtile once for each fragment while end_loop() is called once after all fragments in a subtile have been visited.
The visit() hook: accumulating fragments.
Let's look at our visit() implementation:
template <typename ElementAccumulator, typename ElementInput, int FragmentSize>
CUTLASS_DEVICE auto visit(Array<ElementAccumulator, FragmentSize> const&, int epi_v, int, int,
Array<ElementInput, FragmentSize> const& frg_input) {
Tensor tC_rOut_frg = recast<Array<ElementInput, FragmentSize>>(coalesce(tC_rOut));
tC_rOut_frg(epi_v) = frg_input;
return frg_input;
}
The function signature tells us what data we receive:
Array<ElementAccumulator, FragmentSize> const& -> the raw accumulator values (we ignore this)
epi_v -> the fragment index within the current subtile
Array<ElementInput, FragmentSize> const& frg_input -> the input from previous EVT nodes (already converted)
Our visitor sits at the end of an EVT chain. By the time data reaches us, it has already been fetched from the accumulator and converted to our working precision. The frg_input contains FragmentSize elements that this thread is responsible for.
The critical operation here is storing fragments into tC_rOut:
tC_rOut_frg(epi_v) = frg_input;
We use recast to view our register tensor as an array of fragments, then index by epi_v to store each fragment in its correct position. This accumulates all fragments across multiple visit() calls, building up the complete subtile data in registers.
The end_loop() hook: computing gated‑SiLU.
Once all fragments for a subtile have been collected, end_loop() is called:
CUTLASS_DEVICE void end_loop(int epi_m, int epi_n) {
if constexpr (EnableNullptr) {
if (params_ptr->ptr_out == nullptr) {
return;
}
}
auto [M, N_full, L] = problem_shape_mnl_full;
int N_out = N_full / 2;
// Create output tensor view with halved N dimension
Tensor gOut = make_tensor(
make_gmem_ptr(params_ptr->ptr_out),
make_shape(M, N_out, L),
params_ptr->dOut
);
// Flatten tensors for easy iteration
Tensor tC_rOut_flat = coalesce(tC_rOut);
Tensor tC_cOut_full_flat = coalesce(tC_cOut_full(_,_,_,epi_m,epi_n));
using ConvertOutput = NumericConverter<ElementOut, float, RoundStyle>;
ConvertOutput convert_output{};
int total_elements = size(tC_cOut_full_flat);
// Process pairs: (gate, up) -> silu(gate) * up
CUTLASS_PRAGMA_UNROLL
for (int flat_idx = 0; flat_idx < total_elements; flat_idx += 2) {
float gate = tC_rOut_flat(flat_idx);
float up = tC_rOut_flat(flat_idx + 1);
// SiLU(x) = x * sigmoid(x) = x / (1 + exp(-x))
float silu = gate / (1.0f + expf(-gate));
float out = silu * up;
// Map back to output coordinates (column index halved)
auto [m, n, l] = tC_cOut_full_flat(flat_idx);
n >>= 1; // Divide column by 2
gOut(m, n, l) = convert_output(out);
}
}
Let's break down what happens:
Output tensor setup: We create gOut with shape (M, N_out, L) where N_out = N_full / 2. The output has half the columns because we're fusing pairs.
Coordinate tracking: tC_cOut_full stores the original (m, n, l) coordinates for each element. We slice it by (epi_m, epi_n) to get coordinates for the current subtile.
Pair processing: We iterate through elements two at a time. Due to our column reordering of matrix B, adjacent elements in the flattened tensor are guaranteed to be (gate, up) pairs.
SiLU computation: For each pair, we compute silu(gate) * up. This code shows the simple expf form for clarity, but CUTLASS’s Sigmoid/SiLu can use a fast tanh-based approximation (e.g. sigmoid(x) ≈ 0.5fast_tanh(0.5x) + 0.5**) depending on build/config.
Coordinate mapping: The original column index n corresponds to the wide matrix. We right-shift by 1 (n >>= 1) to get the output column index, since two input columns map to one output column.
Setting up the callbacks.
The get_consumer_store_callbacks() function initializes our callback object with the necessary tensors:
template <bool ReferenceSrc, class... Args>
CUTLASS_DEVICE auto get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto problem_shape_mnl_full = make_shape(M, N, L); // "wide" output shape (before N/2 reduction)
// Allocate a per-thread register tile to accumulate fragments for the current subtile.
// The exact shape comes from CUTLASS's epilogue partitioning (omitted here).
Tensor tC_rOut = make_tensor<float>(/* same (CPY,CPY_M,CPY_N) shape as the thread's output tile */);
// Also build a matching coordinate tensor for the wide output so we can map each register element
// back to (m, n, l) and then apply n >>= 1 when storing the reduced output.
Tensor coordOut_full = make_identity_tensor(make_shape(M, N, L));
Tensor tC_cOut_full = sm90_partition_for_epilogue<ReferenceSrc>(coordOut_full, /* ... */);
return ConsumerStoreCallbacks</* ... */>(
cute::move(tC_rOut),
cute::move(tC_cOut_full),
problem_shape_mnl_full,
params_ptr
);
}
Key points:
tC_rOut is a register tensor that accumulates fragments during visit() calls
tC_cOut_full maps flat indices to (m, n, l) coordinates in the original wide matrix
sm90_partition_for_epilogue handles the complex tiling and thread mapping for us
Putting it all together.
The complete data flow looks like this:
┌─────────────────────────────────────────────────────────────┐
│ GEMM Mainloop: Compute A @ B_reordered │
│ Output shape: (M, N_full) where N_full = 2 * N_out │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ visit() called multiple times per subtile │
│ Each call: store frg_input into tC_rOut[epi_v] │
│ │
│ tC_rOut: [frag0][frag1][frag2]... │
│ ↑ ↑ ↑ │
│ epi_v=0 =1 =2 │
└─────────────────────────────────────────────────────────────┘
│
▼
┌─────────────────────────────────────────────────────────────┐
│ end_loop() called once per subtile │
│ │
│ Process pairs from this thread's tC_rOut: │
│ [gate₀, up₀, gate₁, up₁, ...] │
│ ↓ ↓ │
│ silu(gate₀) * up₀ → output[m, n>>1, l] │
│ │
│ Output shape: (M, N_out) │
└─────────────────────────────────────────────────────────────┘
What a fully integrated implementation looks like
For a real production implementation, you will do typically:
Integrate with the TMA epilogue pipeline: instead of directly storing to global memory in end_loop(), compute/store through reduce()/postreduce() and issue TMA stores via tma_store() so you keep the standard register→smem→TMA path.
Use an FP8/NVFP4 mainloop on Blackwell (rather than BF16) to take advantage of the higher Tensor Core throughput.
Fuse “what comes next” (e.g. output quantization and/or auxiliary tensors) so you minimize memory traffic. gated‑SiLU already halves the logical N dimension; if you also quantize the output (e.g. to NVFP4, 4‑bit), the write footprint can drop dramatically compared to writing the unfused wide BF16 intermediate (2N_out* columns at 2 bytes/elem), storing the fused result at N_out columns and 4 bits/elem is up to ~8× fewer output bytes, plus you avoid an extra read+write pair.
Some production numbers
These are numbers from a production model (time per invocation, us):
imageThat’s 166 us saved (about 1.28× faster) by folding the post-GEMM work into the epilogue.
The cool part of epilogue fusion is that you are not approximating anything. You are just doing the same math earlier (before a round-trip to global memory), so you don't need to sacrifice quality to get the speedup.
関連記事
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み