AIニュース最前線
最新ニュースAI日報Hacker日報週報動画AIツールトレンド企業

AIニュース最前線

世界中のAI最新情報を日本語で毎時更新

最新ニュース日報トレンド企業プレミアムRSS
© 2026 ainew.jp特定商取引法に基づく表記
ニュース一覧元記事を開く
TLDR AI·2026年6月12日 09:00·約24分で読める

PyTorch の Fused MLP を活用した最適化手法(29 分読了)

#LLM#PyTorch#推論最適化#深層学習
TL;DR

TLDR AI は、PyTorch における Fused MLP 技術の導入により、大規模言語モデルの推論速度を大幅に向上させる具体的な最適化手法を解説している。

AI深層分析2026年6月13日 00:05
4
重要/ 5段階
深度40%
4
関連度30%
5
実用性20%
5
革新性10%
3

キーポイント

1

Fused MLP の概念と仕組み

従来の MLP レイヤーを個別に実行するのではなく、活性化関数(例:SiLU)やバイアス加算などを融合させることで、メモリアクセスのオーバーヘッドを削減し計算効率を高める手法。

2

PyTorch での実装最適化

PyTorch のネイティブ機能やカスタムオペレーションを活用して、Fused MLP を効率的に構築・実行する具体的なコード例と設定手順の提示。

3

推論速度とリソース効率の向上

この最適化により、特にバッチサイズが小さい場合や遅延が敏感な推論タスクにおいて、顕著なレイテンシ短縮とスループット向上が期待できる。

影響分析・編集コメントを表示

影響分析

この記事は、大規模言語モデルの運用コスト削減とパフォーマンス向上に向けた具体的なエンジニアリングアプローチを提供しており、開発現場での実装即戦力となる価値が高いです。Fused MLP のような低レベルな最適化技術の普及が、AI モデルのスケーラビリティを決定づける重要な要素であることを再認識させる内容となっています。

編集コメント

理論的な解説だけでなく、実装レベルでの具体的な最適化手法に言及している点が非常に貴重です。大規模モデルを運用するエンジニアにとって即座に適用可能な知見と言えます。

記事一覧に戻る

このシリーズの「PyTorch でのプロファイリング」第 1 部では、torch.add(torch.matmul(x, w), b) を用いて PyTorch プロファイラーのトレースを読み取る方法を学びました。また、CPU ディスパッチチェーン、起動オーバーヘッド、オーバーヘッドバウンドと計算バウンドのレジームの違い、そして torch.compile の内部構造など、いくつかの他のトピックについても議論しました。

2 回目の反復(今回のブログ記事)では、一歩階段を登ります。手書きの matmul-add ペアを nn.Linear (bias=True) に置き換えます。これはあらゆる深層学習モデルが使用するビルディングブロックです。その後、活性化関数を間に挟みながらこれらを 3 つ積み重ね(例に特化した構成)、Multilayer Perceptron (MLP: 多層パーセプトロン) ブロックを形成します。

**

このブログ記事のスクリプトはここにあります:02_linear.py、03_simple_mlp.py、および 03_kernels_mlp.py。前回の投稿と同様に、これらを別タブで開きながらコードを追いかけて読むと理解が深まります。スクリプトの実行には NVIDIA A100-SXM4-80GB GPU を使用しています。Hugging Face のインフラ上で GPU をセットアップし、Dev Mode with Spaces を用いてスクリプトを試しに動かすのは非常に簡単です。また、Hugging Face Jobs pipeline を利用してスクリプトを実行することも可能です。

始める前に、後で繰り返し参照することになる 2 つの概念を簡単に復習しておきましょう:

  • GPU カーネル(GPU kernel)とは、GPU の多数のスレッド上で並列実行されるプログラムのことです。
  • CPU はこれらのカーネルのスケジューリングと起動を担当します。プロファイラトレースで目にする PyTorch のオーバーヘッドの大部分は、このスケジューリング作業に起因するものです。

matmul-add から Linear へ

nn.Linear は、Part 1 で既にプロファイル済みの行列乗算と加算をラップしたモジュールです。唯一の違いは、重み(weight)とバイアス(bias)をパラメータとして保持し、PyTorch ユーザーが慣れ親しんだ forward メソッドを公開している点にあります。

bias=True は、シリーズの第 1 部で確認した乗算と加算演算を真にエミュレートします

linear_layer = nn.Linear(in_dim, out_dim, bias=True)

y = linear_layer(x)

ここで扱っている演算は以下のように記述できます:

y = x @ w.T + b

ここで、x は入力、w は重み、b はバイアスです。02_linear.py を実行してプロファイルを確認してみましょう。

uv run 02_linear.py --batch 1024 --in_dim 32 --out_dim 64

uvx trace-util traces -b traces

trace-util は、トレースを Hugging Face bucket に同期し、ターミナル上で Preffeto URLs を提供するユーティリティです。

Figure 1: Profiler trace of nn.Linear

Figure 1 は、線形層のフォワード呼び出しのプロファイラトレースを示しています。この線形層のフォワード呼び出しは、以前のトレースと同様のスケジュール設定(wait=1, warmup=1, active=3)でトレースされています。そのため、CPU および GPU ラーンに 3 つの Profile Steps が表示されます。

トランスポーズは何をしているのか?

Figure 2: The transpose CPU row

Figure 2 のようにプロファイラトレースを拡大すると、aten::t(トランスポーズ)演算子が aten::addmm(乗算と加算)演算子の前にあることに気づきます。これにより、nn.Linear は重みパラメータをトランスポーズしてから入力と乗算していることがすでに推測できます。これが aten::t 演算子が表示される理由です。

注意すべき重要な点は、aten::t は実際にはデータをコピーしたり再編成したりするものではないということです。これは CPU 上で転置行列を表すためにテンソルのメタデータ(形状とストライド)を書き換えるだけであり、GPU でカーネルを起動することはありません。これを検証するには2つの方法があります。1つ目はトレースの GPU ラインを確認する方法、2つ目はプロファイラーテーブル内の aten::t 行を確認し、CUDA 上で要した時間をチェックする方法です。

なぜ個別の mul および add カーネルがないのか?

Figure 3: リニア層のプロファイルに aten::add は存在しない

図3に示されるように、リニア層のディスパッチチェーンには aten::add(バイアス加算)は存在しません。これは、バイアスの加算がエピローグと呼ばれる手法を用いて行列乗算カーネルに*折り畳まれている*からです。

エピローグとは、GEMM(GEneral Matrix Multiply:汎用行列乗算)カーネルが結果を HBM(High Bandwidth Memory:GPU のメインメモリ)へ書き戻す直前に行う小さな計算処理のことです。バイアスの加算、活性化関数の適用、定数によるスケーリングなどはすべて古典的なエピローグの例です。エピローグの目的は、メモリアクセスが演算コストを高めるため、HBM への読み込みまたは書き出しを2回行うことを回避することにあります。

nn.Linear は torch.nn.functional.linear を呼び出し、それがさらに aten::linear を呼び出します。aten::linear は入力を確認し、バイアスが渡されていることに気づくと、行列乗算と加算を別々に行うのではなく、代わりに aten::addmm(bias, x, weight) をディスパッチします。addmm は以下のように計算を行います。

out = x @ weight.T + bias

GPU で実行される cuBLAS GEMM カーネルには、バイアス加算バリアントが組み込まれており、これが aten::addmm が選択するカーネルです。この加算は独立したカーネルとして現れることはありません。なぜなら、それは行列乗算カーネルの書き戻し(writeback)の一部であり、まさにエピローグ(epilogue)が果たす役割そのものだからです。

ここで注意すべき微妙な点があります。Part 1 の --compile 下 でご覧になったカーネル(addmm)は、すでに eager モードの nn.Linear が使用しているものです。torch.compile にはここで融合させるべきものが残っていません。これが次に検証する内容です。

--compile は単一の Linear に役立つか?

順伝播呼び出しをコンパイルして、プロファイラトレースを確認してみましょう。(プロファイラトレースは次のセクションで可視化されます)

uv run 02_linear.py --batch 1024 --in_dim 32 --out_dim 64 --compile

uvx trace-util traces -b traces

単一の nn.Linear の順伝播における eager モードとコンパイルされたモードのトレースを比較すると、以下のことがわかります。

  • GPU 上では同じ cuBLAS GEMM カーネルが実行される。
  • CPU 上では同じ aten::addmm オペレーション(op)が使用される。
  • CPU ラーンには、compile に固有の追加行がいくつか存在する。

これは深く理解しておく価値があります。モデルが遅いと感じると、誰もが反射的に torch.compile を使おうとする傾向があります。しかし、バイアス付きの単一の GEMM 演算においては、compile の効果は非常に限定的です。これはバグではなく、compile が何らかの融合を行うためには複数の操作が必要であるというだけの話です。これを証明するために、MLP(多層パーセプトロン)を見てみましょう。

トランスポーズはどこへ行ったのか?カーネルレイアウトと事前演算

2 つのトレース(イーガーモードとコンパイルモード)を注意深く読めば、イーガー CPU ディスパッチチェーンにはコンパイル版よりも多くの要素が含まれていることに気づくでしょう。

図 4: aten::linear が aten::t(トランスポーズ)を経て aten::addmm を通るイーガーディスパッチチェーン

図 5: トランスポーズを介さず、直接 aten::addmm が呼び出されるコンパイルディスパッチチェーン

aten::linear 内部のイーガー CPU ディスパッチチェーンは、まず aten::t に続き、次に aten::addmm です(図 4)。aten::t が実際に何を行うのかを理解するには、*ストライド*と*ビュー*について少し立ち寄る必要があります。

テンソルはメモリ上では、1 つの平坦で連続した数値の列としてデータを格納しています。形状(shape)とストライド(stride)は、その列の上に位置するメタデータであり、PyTorch がどのように走査するかを指示します:ストライドが (s0, s1) である場合、「行を 1 つ進むには s0 要素分ステップし、列を 1 つ進むには s1 要素分ステップする」という意味です。メタデータを変更すれば、コピーなしで同じ生データを異なる*ビュー*として取得できます。

>> M = torch.tensor([[0, 1],

... [2, 3],

... [4, 5]])

>> M.shape, M.stride()

(torch.Size([3, 2]), (2, 1)) # 行あたり 2 ステップ、列あたり 1 ステップ

>> T = M.t() # トランスポーズ

>> T.shape, T.stride()

(torch.Size([2, 3]), (1, 2)) # 形状とストライドが入れ替わり、データは不変

>> T

tensor([[0, 2, 4],

[1, 3, 5]])

>> T.flatten() # 強制的に実体化されるため、データ順序が入れ替わる

tensor([0, 2, 4, 1, 3, 5])

M.t() は数値を一つも移動させませんでした。これはストライドを入れ替えた新しいビューを返すものであり、行ごとに読み込むと、元のバッファの 0, 1, 2, 3, 4, 5 が転置された順序で走査されることになります。基盤となるデータは同一であり、異なるのはメタデータだけです。

これはまさに linear レイヤー内部で aten::t が行うことです:新しいテンソルを割り当てたりデータをコピーしたりするのではなく、書き換えられたストライドを持つ重みの *ビュー* を生成します。

Figure 5 で確認できるように、compile は GPU カーネルを削除したわけではありません。そのビューをディスパッチするための *CPU オーバーヘッド* を削除しました。Inductor はコンパイル時にビューチェーンを追跡し、結果のストライドを一度計算して、それらのストライドがハードコードされた直接の aten::addmm 呼び出しを生成しました。GPU が同じ数学演算を行う間、数マイクロ秒分の CPU 作業が消滅します。

当然のことながら、入力データがコンパイラによって事前に計算されたストライドに違反すると、エラーが発生します。

両方のトレースにおける GPU ラーンを見ると、フォワードパスごとにカーネルはちょうど一つであり、その二回とも *同じ* カーネルです:

cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8

もし転置カーネルが実行されなかったなら、GEMM が重み行列を転置された順序で読み込むように誰が教えたのでしょうか?その答えはカーネル名にあります。接尾辞を見てみましょう:

cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8

^^

この tn がレイアウト記述子です。cuBLAS と CUTLASS は、入力レイアウトの各組み合わせに対して *個別のカーネルバイナリ* を事前コンパイルしています。

n(非転置)と t(転置)は、カーネルが内部ループ内で入力をどのように走査するかを示します。ディスパッチャの役割は、入力ストライドを確認し、どの接尾辞の組み合わせが一致するかを判断して、適切な事前コンパイル済みカーネルを選択することです。

**

プロファイラトレースにおけるカーネル名は、カーネルのアイデンティティのハッシュダンプです。2 つの実行で同じカーネル名が表示される場合、GPU は同じ作業を行っています。名前が異なる場合(例:_tn_ と _nn_、bf16 と fp16、または s16816gemm と s161616gemm)、GPU は異なる作業を行っており、ディスパッチャが別の分岐を選択したことを意味します。この名前を読み解く方法を学ぶことは、トレースを比較する際に最も有用な習慣の一つです。

3 つの Linear をスタック:MLP

このセクションでは、Multilayer Perceptron(MLP)をプロファイルします。より興味深いものにするため、GeGLU アクティベーション変種を持つフィードフォワードネットワークをプロファイルします(これは実際には非常に頻繁に使用されています)。また、これは深層学習研究の歴史において書かれた最も優れた行の一つへのオマージュでもあります(図 6)。

図 6: GLU Variants Improve Transformer ペーパーの結論セクション。

class SimpleGeGLUMLP(nn.Module):

def __init__(self, dim, hidden):

super().__init__()

self.gate_proj = nn.Linear(dim, hidden, bias=False)

self.up_proj = nn.Linear(dim, hidden, bias=False)

self.down_proj = nn.Linear(hidden, dim, bias=False)

def forward(self, x):

g = self.gate_proj(x)

u = self.up_proj(x)

h = F.gelu(g, approximate="tanh")

m = h * u

y = self.down_proj(m)

return y

完全なスクリプトはここで見つけることができます:03_simple_mlp.py。以下のように実行してください:

uv run 03_simple_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072

uvx trace-util traces -b traces

トレースを開く前に、一緒に何を期待すべきか考えてみましょう。forward 関数はかなりの計算量を行いますが、その大部分はすでに私たちに馴染みのあるものです。

各 nn.Linear レイヤーに対して 1 つずつ、合計 3 つの aten::linear ディスパッチ(dispatch)が発生することを期待できます。また、GeLU と乗算演算それぞれに対して 1 つずつ、計 2 つのポイントワイズカーネル起動も予想されます。トレースを見る前にこの期待を形成することは、プロファイリングの旅において最も有用な習慣の一つです:トレースを読むのは、ゼロから推測を作るためではなく、仮説を*確認または反証するため*に行うものです。

Figure 7: GeGLU MLP のプロファイラトレース

Figure 8: 線形投影 CPU ラーンにおける「occupancy query」のハイライト

Figure 7 から、私たちの直感が正しかったと自分自身を褒め称えることができます。1 つのフォワードパス(1 つの mlp_fwd)あたり、GPU は正確に 5 つのカーネルを実行します。Figure 8 では、線形投影レイヤーの CPU ラーンで確認できる「occupancy query」がハイライトされています。

Op

CPU op

GPU kernel

launches

gate_proj

aten::linear

ampere_bf16_s16816gemm_bf16_128x128_...

occupancy query + cudaLaunchKernel

up_proj

aten::linear

ampere_bf16_s16816gemm_bf16_128x128_...

occupancy query + cudaLaunchKernel

gelu

aten::gelu

vectorized_elementwise_kernel<4, GeluCUDAKernelImpl...>

cudaLaunchKernel

h * u

aten::mul

vectorized_elementwise_kernel<4, ...MulFunctor...>

cudaLaunchKernel

down_proj

aten::linear

ampere_bf16_s16816gemm_bf16_128x256_...

occupancy query + cudaLaunchKernel

3 つの GEMM(行列積演算)は、それぞれ起動前に cudaOccupancyMaxActiveBlocksPerMultiprocessor 呼び出しを余分に行っています。これについては Part 1 で別のセクションを設けており、こちらで確認できます。これは cuBLAS がグリッドサイズを決定しているためです。一方、ポイントワイズ演算(GeLU と mul)は、occupancy クエリを行わずに直接起動されます。つまり、「線形演算」は実際にはクエリと起動の両方を含みますが、「ポイントワイズ演算」は単なる起動のみです。

Figure 9: The table shows that some ops launch zero kernels

aten::t, aten::transpose, aten::reshape, aten::view, aten::as_strided, および aten::_unsafe_view の各演算は、ゼロ個のカーネルを起動します。これらは表(Figure 9)で CUDA 時間が 0.000us と表示されていますが、これは CPU 上でテンソルのメタデータ(形状とストライド)を書き換えるだけだからです。表をスキャンする読者は、線形演算ごとに約 6 つの演算名を目にしますが、GPU に到達するのはそのうちの一つ(mm)だけです。

なぜ GEMM カーネルには 2 種類あるのか?

MLP は、行列積のために [batch, seq, dim] を [batch * seq, dim] にフラット化します。コマンドラインでの呼び出しではバッチサイズに 64、シーケンス長に 128 を使用したため、以下の 8192(batch * seq = 64 * 128)という値がここから来ています。

トレースより:

Linear

aten::mm input dims

M·K·N

cuBLAS kernel

avg CUDA

gate_proj

[8192,768] x [768,3072]

8192·768·3072

…128x128…stages_32x5_tn

0.19ms

up_proj

[8192,768] x [768,3072]

8192·768·3072

…128x128…stages_32x5_tn

0.19ms

down_proj

[8192,3072] x [3072,768]

8192·3072·768

…128x256…stages_64x3_tn

0.17ms

これら 3 つの GEMM(General Matrix Multiplication、一般行列積)はすべて同じ FLOP 数を持ち、それぞれ約 38.7 GFLOP(2·8192·768·3072)ですが、down_proj は約 10% 高速です。同じ計算量でありながら形状が異なる(N が 3072 ではなく 768)ため、cuBLAS はその形状に最適なタイル(128×256 で、より深い stages_64x3 パイプラインを持つもの)を選択し、リユース性を高めています。

タイル処理について深く学びたい場合は、こちら が素晴らしい入門リソースです。

まさにこの理由から、表には 2 つの GEMM の行(Figure 9)が記載されています:128x128 の行は gate と up を示し、128x256 の行は down を示しています。

torch.compile は何をするのか?

フォワードメソッドをコンパイルしてトレースを可視化する前に、再び頭の中で「トレースで何を期待するか」を考える演習を行いましょう。これは楽しい実験であり、自分で何かをプロファイリングするたびに繰り返すべき重要な作業です。常に直感を頼りにし、何かが予想と一致しない瞬間には立ち止まってその理由を探る癖をつけましょう。

uv run 03_simple_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072 --compile

uvx trace-util traces -b traces

Figure 10: The profiler trace for the compiled GeGLU MLP

Eager モードでは、各 nn.Linear が dispatcher ops の連鎖(aten::linear → aten::t → aten::transpose → aten::matmul → aten::reshape → aten::mm)に展開されます。これらは、ATen が実際の GEMM に到達する前に通る高レベルなラッパーです。torch.compile はこの連鎖を除去します。

コンパイルされたグラフが実行される頃には、linear も matmul も transpose も reshape も存在せず、それらのメタデータ演算は mm の呼び出し方に折り畳まれています。Figure 10 に示すように、3 つの裸の aten::mm 外部呼び出しが見られます。これが同じ GEMM であることの証拠は、カーネル名が eager モードとバイト単位で完全に一致していることです:ゲートおよびアップ用には ...128x128...stages_32x5_tn、ダウン用には ...128x256...stages_64x3_tn です。

The fused Triton kernel

Figure 11: The fused Triton kernel

これはコンパイルの教訓全体の要です。2 つの eager ポイントワイズカーネル(GeLU と mul)と reshape が、1 つのカーネル triton_poi_fused__unsafe_view_gelu_mul_0 に統合されました(Figure 11)。この名前を解読してみましょう:

  • triton: Inductor の Triton バックエンドによって生成されたもの(cuBLAS でも ATen でもありません)。
  • poi: pointwise(Inductor はポイントワイズカーネルに poi、リダクションに red、永続的リダクションに per というタグを付けます)。
  • fused__unsafe_view_gelu_mul: 統合された演算:_unsafe_view(reshape)、GeLU、および mul です。
  • 0: グラフ内の一意な ID です。

なぜこれが勝利となるのか?イーグルモードでは、中間変数 h = gelu(g) は [8192, 3072] の bf16 テンソル(約 50 MB)であり、GeLU カーネルがこれを HBM に書き込み、mul カーネルが即座に読み戻します。融合により、このデータはレジスタ(チップ内部に存在し、HBM よりも近いメモリ)内に保持されます。Triton カーネルは g と u をそれぞれ 1 回だけ読み取り、gelu(g) * u を計算して結果を 1 回書き込みます。これにより、中間変数がグローバルメモリを経由する往復のラウンドトリップが完全に消滅します。

手動調整されたカーネルを使ってみましょう

これまで PyTorch(イーグルモード)とコンパイラー(torch.compile)にカーネル選択を任せてきました。今度は人間のエキスパートが手書きで調整したカーネルを組み込みます。ここでは、kernels ライブラリを使用して Hugging Face Hub から簡単に取得できる LigerGEGLUMLP レイヤーを使用します。

from kernels import get_kernel

kernels_layers = get_kernel("kernels-community/liger-kernels", version=1).layers

kernels_geglu_mlp = kernels_layers.LigerGEGLUMLP(Config()).to(device, dtype=torch.bfloat16).eval()

完全なスクリプトはこちらです:03_kernels_mlp.py。

uv run 03_kernels_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072

uvx trace-util traces -b traces

Figure 12: The profiler trace for the LigerGEGLUMLP layer

Figure 12 shows the profile for the LigerGEGLUMLP layer using the Liger kernels from the Hub.

なぜカーネルライブラリを使用するのか

Triton や CUDA でカーネルを書くことは一つの課題ですが、それらを配布することは別の課題です。カーネルは、GPU アーキテクチャ、CUDA バージョン、PyTorch バージョンの組み合わせが完全に一致するようにコンパイルされなければなりません。これは通常失敗するステップです(「自分のマシンでは動く」、nvcc の欠如、間違った Triton バージョンなど)。

kernels ライブラリは、このビルドステップをあなたのマシンから外します。get_kernel("kernels-community/liger-kernels", version=1) は、Hugging Face Hub から事前ビルドされバージョンが固定されたカーネルパッケージをダウンロードし、ローカル(ここでは ~/.cache/...kernels-community--liger-kernels 以下)にキャッシュします。その利点は以下の通りです。

  • カーネルは CI 内で一度だけコンパイルされ、多くのアーキテクチャとバージョンの組み合わせに対応しています。自分でコンパイルするのではなく、正しいバイナリをダウンロードできます。
  • version=1 は正確なビルドを固定するため、スクリプトを実行する誰もが同じカーネルを使用します。「パッケージを更新したら遅くなった」という現象は発生しません。
  • パッケージには .layers 属性が公開されており、ドロップインで使える nn.Modules(LigerGEGLUMLP など)が含まれています。モデルのモジュールをそれらに置き換えるだけで、モデルの他の部分は一切変更する必要がありません。

なぜチューニングされたカーネルが優れているのか

「チューニングされた」と言うとき、私たちは2つの具体的なことを指しており、どちらもトレースで確認できます。

Figure 13: コンパイル実行では、GEMM が実行される前に Dynamo、ガード、プロローグなどの事前演算のコストが発生します

Figure 14: Liger カーネルには事前演算がありません — 通常そこにあるべきボックスは空です

  • この融合は組み込まれています。LigerGEGLUMLP の順伝播は、down_proj(LigerGELUMulFunction.apply(gate_proj(x), up_proj(x))) です。LigerGELUMulFunction は単一のTritonカーネル_geglu_tanh_forward_kernel を実行し、gelu(gate) * up を1回のパスで計算します。これは、中間データがHBM(High Bandwidth Memory)を経由して往復しないという、torch.compile で見たのと同じ結果です。図 13 と図 14 に示されているように、ここではコンパイラなしでこの恩恵を得られます(Dynamoガードなし、コンパイル遅延なし、再コンパイルリスクなし)。
  • 起動パラメータはハードウェアに合わせて選択されています。カーネルがランダムにブロックサイズを推測することはありません。Liger の calculate_settings は、列数からこれらの値を選択します。

ここでトレードオフについて正直になる価値があります。なぜなら、生データだけを見ると誤解を招く可能性があるからです。Liger カーネルの実行時間は92.8 µsですが、コンパイル実行におけるInductorの融合カーネルは89.4 µsです。一見すると手書きのカーネルがわずかに遅いように見えますが、この比較にはその価値を正当化するコストが見落とされています。

torch.compile は静的形状に対して特別化されます。Inductor の 89.4 µs のカーネルが高速なのは、まさに*この特定の*[8192, 3072]の問題のために生成されたからです。バッチサイズやシーケンス長、隠れ次元を変更すると、Dynamo は再トレースを行い、新しい特別化カーネルを得るために再びコンパイルコストを支払う必要があります。

つまり、真の選択は「遅い人手によるカーネル vs 高速なコンパイル済みカーネル」ではありません。それは汎用的に高速なカーネル vs 特定の1つの入力形状に特化したカーネルです。Liger カーネルは1セットのパラメータを起動パラメータとして受け取り、再コンパイルを行うことなく*あらゆる*形状に対して実行します。これは、形状ごとの特化がもたらす数マイクロ秒の速度向上を犠牲にする代わりに、変化する形状に対する堅牢性を獲得するものです。

結論

以下の表は、各ステップで GPU に何の変更を加え、何が変更されなかったかをまとめています。

Setup | What changed | What stayed the same

---|---|---

Eager nn.Linear | ベースライン:バイアス加算はすでに GEMM エピローグ(addmm)に折り込まれており、*1 つの* cuBLAS カーネルであり、行列乗算と加算の組み合わせではない | —

Compiled nn.Linear | いくつかの CPU ディスパッチ演算(aten::t ビューのブックキーピング)が消失 | 同じ単一の cuBLAS GEMM カーネルで、バイト単位でも同一。コンパイルでは融合は行われない

Eager MLP | GPU カーネル5個:3 つの GEMM + GeLU + 乗算。[8192, 3072] の中間結果が HBM(High Bandwidth Memory)をフルラウンドトリップする | 各 GEMM は、独立した線形演算と同じバイアスなし cuBLAS カーネルである

Compiled MLP | GeLU + 乗算 + リシェイプが1 つの融合された Triton カーネルに統合され、中間結果はレジスタ内に残る。コンパイル前処理(Dynamo、ガード)のコストを払う | 3 つの GEMM は変更されず、cuBLAS カーネル名も同一

Liger MLP | 同じ融合だが、ハードウェアチューニングされた起動パラメータを持つ手書きの Triton カーネルに組み込まれており、Dynamo やガード、コンパイル遅延は不要 | 3 つの GEMM は依然として同じ cuBLAS カーネルである

一つだけ継続すべき習慣があるとすれば、それはトレースの前に行ってきたもの、すなわちまず推測し、その後確認するというものです。トレースに何が含まれると予想するかを述べ、それを開き、不一致が生じた場合は画面で最も興味深い出来事として扱ってください。

これは「PyTorch におけるプロファイリング」というシリーズの二番目のステップでした。次の投稿では、この MLP ブロックからアテンションブロックへと、そして最終的には完全なモデルへと、さらに階段を上って進んでいきます。

記事の初期草案に対するレビューを寄せてくださった Noe Flandre 氏と Pedro Gabriel Gengo Lourenço 氏に感謝いたします!

原文を表示

Back to Articles

In the first part of this series "Profiling in PyTorch", we used torch.add(torch.matmul(x, w), b) to learn how to read PyTorch profiler traces. We also discussed several other topics that came our way - the CPU dispatch chain, launch overhead, the difference between an overhead-bound and a compute-bound regime, and some internals of torch.compile.

In the second iteration (this blog post), we climb one rung up the ladder. We replace the hand-written matmul-add pair with an nn.Linear (with bias=True). This is the building block every deep learning model uses. We then stack three of them (specific to our example), with an activation in between, to form a Multilayer Perceptron (MLP) block.

The scripts for this blog post live here: 02_linear.py, 03_simple_mlp.py, and 03_kernels_mlp.py. Like before, it helps to open them in a separate tab and walk through the code as you read. We use an NVIDIA A100-SXM4-80GB GPU to run the scripts. It is really easy to set up a GPU on the Hugging Face infrastructure and experiment with the scripts using Dev Mode with Spaces. One could also run the scripts with the Hugging Face Jobs pipeline.

Before we begin, a quick recap of two ideas we will lean on repeatedly:

  • A GPU kernel is a program that runs in parallel on many threads of the GPU.
  • The CPU schedules and launches these kernels. Most of the PyTorch overhead you see in a profiler trace is this scheduling work.

From matmul-add to Linear

nn.Linear is a module wrapper around the same matrix multiplication and addition we already profiled in Part 1. The only difference is that it owns its weight and bias as parameters and exposes a forward method that PyTorch users have grown familiar with.

code
# bias=True would truly emulate the multiplication and addition
# operations we have seen in part 1 of the series
linear_layer = nn.Linear(in_dim, out_dim, bias=True)
y = linear_layer(x)

The operation at hand can be written as:

code
y = x @ w.T + b

Where x is the input, w is the weight and b is the bias. Let's run 02_linear.py and check the profile.

code
uv run 02_linear.py --batch 1024 --in_dim 32 --out_dim 64
uvx trace-util traces -b traces

trace-util is a utility that will sync your traces to a Hugging Face bucket and then provide the Preffeto URLs on your terminal.

Figure 1: Profiler trace of nn.Linear

Figure 1 shows the profiler trace of a forward call of the linear layer. We trace the forward call of the linear layer with a similar schedule setup as the previous traces, with wait=1, warmup=1 and active=3. This is why we see three Profile Steps in the CPU and GPU lanes.

What is the transpose doing?

Figure 2: The transpose CPU row

If we zoom into the profiler trace, as we do in Figure 2, we notice an aten::t (transpose) op before the aten::addmm (multiplication and addition) op. We can already figure out that nn.Linear transposes the weight parameter and then multiplies it with the input. This is the reason we see an aten::t op.

An important thing to notice is that aten::t does not really copy or reorganize data: it only rewrites tensor metadata (shape and stride) on the CPU to represent the transposed matrix. It does not launch a kernel on the GPU. One can verify this two ways: by looking at the GPU lane in the trace, or by checking the aten::t row in the profiler table and the time it took on CUDA.

Why are there no separate mul and add kernels?

Figure 3: No aten::add in the profile of a linear layer

There is no aten::add (the bias addition) in the dispatch chain of the linear layer, as seen in Figure 3. This is because the bias addition has been *folded* into the matrix multiplication kernel, using what is called an epilogue.

An epilogue is a small computation that a GEMM (GEneral Matrix Multiply) kernel does at the very end, just before it writes its result back to HBM (High Bandwidth Memory, the GPU's main memory). Adding a bias, applying an activation, or scaling by a constant are all classic epilogues. The point of an epilogue is to avoid loading or writing to HBM a second time, since memory traffic makes an operation expensive.

nn.Linear calls torch.nn.functional.linear, which, in turn, calls aten::linear. aten::linear looks at the inputs, notices that a bias was passed, and dispatches aten::addmm(bias, x, weight) instead of doing a matmul and an add separately. addmm computes:

code
out = x @ weight.T + bias

The cuBLAS GEMM kernel that runs on the GPU has a bias-add variant built in, and that's the kernel aten::addmm picks. The add never appears as a separate kernel because it is part of the matmul kernel's writeback, which is exactly what an epilogue is.

This is the moment to notice something subtle. The kernel you saw in Part 1 under --compile (addmm) is the kernel that eager nn.Linear already uses. There is nothing left for torch.compile to fuse here, which is the next thing we will verify.

Can --compile help a single Linear?

Let's compile the forward call and look at the profiler trace. (The profiler trace is visualized in the next section)

code
uv run 02_linear.py --batch 1024 --in_dim 32 --out_dim 64 --compile
uvx trace-util traces -b traces

If you compare the eager and compiled traces for a single nn.Linear's forward, you will find:

  • The same cuBLAS GEMM kernel on the GPU.
  • The same aten::addmm op on the CPU.
  • A few extra rows on the CPU lane unique to compile.

This is worth internalizing. A common reflex is to reach for torch.compile whenever a model feels slow. For a single GEMM-with-bias, compile has very little to do. This is not a bug, this is just that compile needs more than one operation to possibly do any fusing. Let's prove that by looking at an MLP.

Where did the transpose go? Kernel layouts and pre-ops

A careful reader of the two traces (eager vs compile) will notice that the eager CPU dispatch chain has more in it than the compiled one.

Figure 4: Eager dispatch chain where aten::linear walks through aten::t (transpose) and then aten::addmm

Figure 5: Compiled dispatch chain where aten::addmm is called directly, with no transpose

The eager CPU dispatch chain inside aten::linear is aten::t followed by aten::addmm (Figure 4). To understand what aten::t actually does, we need a quick detour into *strides* and *views*.

A tensor stores its data as one flat, contiguous run of numbers in memory. The shape and stride are metadata that sit on top of that run and tell PyTorch how to walk it: a stride of (s0, s1) means "step s0 elements to move one row, step s1 to move one column". Change the metadata and you get a different *view* of the *same* raw data, with no copy:

code
>>> M = torch.tensor([[0, 1],
...                   [2, 3],
...                   [4, 5]])
>>> M.shape, M.stride()
(torch.Size([3, 2]), (2, 1))   # two steps per row, one step per column

>>> T = M.t()                  # transpose
>>> T.shape, T.stride()
(torch.Size([2, 3]), (1, 2))   # shape and stride swapped, data untouched
>>> T
tensor([[0, 2, 4],
        [1, 3, 5]])
>>> T.flatten()                # forced to materialize, so the data is reordered
tensor([0, 2, 4, 1, 3, 5])

M.t() did not move a single number. It returned a new view whose strides are swapped, so reading it row-by-row now walks the original buffer 0, 1, 2, 3, 4, 5 in transposed order. The underlying data is identical; only the metadata differs.

This is exactly what aten::t does inside the linear layer: it does not allocate a new tensor or copy any data, it produces a *view* of the weight with rewritten strides.

As we can see in Figure 5, compile did not remove a GPU kernel: it removed the *CPU overhead* of dispatching that view. Inductor traced through the view chain at compile time, computed the resulting strides once, and emitted a direct aten::addmm call with those strides hard-coded. A few microseconds of CPU work disappear while the GPU does identical math.

As one would expect, when the input data violates the strides precomputed by the compiler, it will throw an error.

If you look at the GPU lane in both traces, there is exactly one kernel per forward, and it is the *same* kernel both times:

code
cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8

If no transpose kernel ran, who taught the GEMM to read the weight matrix in transposed order? The answer is in the kernel's name. Look at the suffix:

code
cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8
                                                          ^^

That tn is the layout descriptor. cuBLAS and CUTLASS precompile a *separate kernel binary* for each combination of input layouts.

n (non-transposed) and t (transposed) describe how a kernel walks its input during the inner loop. The dispatcher's job is to look at the input strides, decide which suffix combination matches, and pick the right precompiled kernel.

The kernel name in a profiler trace is a hash dump of the kernel's identity. If two runs show the same kernel name, the GPU is doing the same work. If they differ (e.g., _tn_ vs _nn_, bf16 vs fp16, or s16816gemm vs s161616gemm) then the GPU is doing different work, and the dispatcher took a different branch. Learning to read this name is one of the most useful habits when comparing traces.

Stacking three Linears: the MLP

In this section, we will profile a Multilayer Perceptron (MLP). To make this more interesting, we will profile a feed-forward network with the GeGLU activation variant (which is quite heavily used in practice). This is also our way of paying tribute to one of the greatest lines ever written in the history of deep learning research (Figure 6).

Figure 6: The conclusion section of the GLU Variants Improve Transformer paper.

code
class SimpleGeGLUMLP(nn.Module):
    def __init__(self, dim, hidden):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        g = self.gate_proj(x)
        u = self.up_proj(x)
        h = F.gelu(g, approximate="tanh")
        m = h * u
        y = self.down_proj(m)
        return y

You will find the entire script here: 03_simple_mlp.py. Execute it like so:

code
uv run 03_simple_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072
uvx trace-util traces -b traces

Before we open the trace, let's think together about what we should expect to see. The forward function does a fair amount of computation, but most of it is already familiar to us.

We should expect three aten::linear dispatches, one for each nn.Linear layer. We should also expect two pointwise kernel launches, one for the GeLU and one for the multiplication. Forming this expectation before looking is the single most useful habit in the profiling journey: you read the trace to *confirm or break* a guess, not to form one from scratch.

Figure 7: The profiler trace for a GeGLU MLP

Figure 8: The occupancy queries highlighted in the linear projection CPU lane

From Figure 7 we can pat ourselves on the back, as our intuition was correct. Per forward pass (one mlp_fwd), the GPU runs exactly 5 kernels. Figure 8 highlights the "occupancy query" as seen in the CPU lane for the linear projection layers.

Op

CPU op

GPU kernel

launches

gate_proj

aten::linear

ampere_bf16_s16816gemm_bf16_128x128_...

occupancy query + cudaLaunchKernel

up_proj

aten::linear

ampere_bf16_s16816gemm_bf16_128x128_...

occupancy query + cudaLaunchKernel

gelu

aten::gelu

vectorized_elementwise_kernel<4, GeluCUDAKernelImpl...>

cudaLaunchKernel

h * u

aten::mul

vectorized_elementwise_kernel<4, ...MulFunctor...>

cudaLaunchKernel

down_proj

aten::linear

ampere_bf16_s16816gemm_bf16_128x256_...

occupancy query + cudaLaunchKernel

The three GEMMs each do an extra cudaOccupancyMaxActiveBlocksPerMultiprocessor call before the launch. We have a separate section on this in Part 1, you can find it here. That is cuBLAS sizing the grid. The pointwise ops (GeLU and mul) launch directly, with no occupancy query. So "a linear" is actually query + launch, while "a pointwise op" is just launch.

Figure 9: The table shows that some ops launch zero kernels

The aten::t, aten::transpose, aten::reshape, aten::view, aten::as_strided, and aten::_unsafe_view ops launch zero kernels. They show 0.000us of CUDA time in the table (Figure 9) because they only rewrite tensor metadata (shape and stride) on the CPU. A reader scanning the table sees around six op names per linear, but only one of them (mm) ever reaches the GPU.

Why are there two types of GEMM kernels?

The MLP flattens [batch, seq, dim] to [batch * seq, dim] for the matmul. In our command-line invocation we used 64 for batch and 128 for seq, so that's where the 8192 (batch * seq = 64 * 128) below comes from.

From the trace:

Linear

aten::mm input dims

M·K·N

cuBLAS kernel

avg CUDA

gate_proj

[8192,768] x [768,3072]

8192·768·3072

…128x128…stages_32x5_tn

0.19ms

up_proj

[8192,768] x [768,3072]

8192·768·3072

…128x128…stages_32x5_tn

0.19ms

down_proj

[8192,3072] x [3072,768]

8192·3072·768

…128x256…stages_64x3_tn

0.17ms

All three GEMMs have the same FLOP count, 2·8192·768·3072 ≈ 38.7 GFLOP each, yet down_proj is about 10% faster. Same work, different shape (N=768 instead of 3072), so cuBLAS picks a different tile (128×256, with a deeper stages_64x3 pipeline) that gets better reuse for that shape.

If you want to learn more about tiling in depth, here is a great resource to get started with.

This is exactly why the table had two GEMM rows (Figure 9): the 128x128 row is gate+up and the 128x256 row is down.

What does torch.compile do?

Before compiling the forward method and visualizing it, let's do the mental exercise again of asking ourselves what we expect to see in the trace. This is a fun experiment, and an important one to repeat every time you profile something yourself. Always build on your intuition, and the moment something does not match, stop and figure out why.

code
uv run 03_simple_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072 --compile
uvx trace-util traces -b traces

Figure 10: The profiler trace for the compiled GeGLU MLP

In eager mode, each nn.Linear was expanded into a chain of dispatcher ops (aten::linear → aten::t → aten::transpose → aten::matmul → aten::reshape → aten::mm). Those are the high-level wrappers that ATen walks through before reaching the real GEMM. torch.compile removes that chain.

By the time the compiled graph runs, there is no linear, no matmul, no transpose or reshape and those metadata ops were folded into how mm is called. We can see three bare aten::mm external calls (Figure 10). The proof that it is the same GEMM is that the kernel names are byte-for-byte identical to eager: ...128x128...stages_32x5_tn for gate and up, and ...128x256...stages_64x3_tn for down.

The fused Triton kernel

Figure 11: The fused Triton kernel

This is the headline of the whole compile lesson. The two eager pointwise kernels (GeLU and mul) plus a reshape collapsed into one kernel, triton_poi_fused__unsafe_view_gelu_mul_0 (Figure 11). Let's decode the name:

  • triton: generated by Inductor's Triton backend (not cuBLAS, not ATen).
  • poi: pointwise (Inductor tags pointwise kernels poi, reductions red, and persistent reductions per).
  • fused__unsafe_view_gelu_mul: the ops it merged: the _unsafe_view (reshape), the GeLU, and the mul.
  • 0: the unique id within the graph.

Why is this a win? In eager mode, the intermediate h = gelu(g) is a full [8192, 3072] bf16 tensor (around 50 MB) that the GeLU kernel writes to HBM and the mul kernel immediately reads back. Fusion keeps it in registers (memory that resides inside the chip and are closer than the HBM). The Triton kernel reads g and u once, computes gelu(g) * u, and writes the result once. One whole round trip of the intermediate through global memory is gone.

Let's use hand tuned kernels

So far we have let PyTorch (eager) and the compiler (torch.compile) pick our kernels. Now we plug in a kernel that a human expert wrote and tuned by hand. We use the LigerGEGLUMLP layer, that we can easily fetch from the Hugging Face Hub with the kernels library.

code
from kernels import get_kernel

kernels_layers = get_kernel("kernels-community/liger-kernels", version=1).layers
kernels_geglu_mlp = kernels_layers.LigerGEGLUMLP(Config()).to(device, dtype=torch.bfloat16).eval()

The full script is here: 03_kernels_mlp.py.

code
uv run 03_kernels_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072
uvx trace-util traces -b traces

Figure 12: The profiler trace for the LigerGEGLUMLP layer

Figure 12 shows the profile for the LigerGEGLUMLP layer using the Liger kernels from the Hub.

Why use the kernels library

Writing kernels in Triton or CUDA is one problem and *shipping* them is another. The kernel has to be compiled for your exact combination of GPU architecture, CUDA version, and PyTorch version. This is the step that usually breaks ("works on my machine", missing nvcc, wrong Triton version).

The kernels library moves that build step off your machine. get_kernel("kernels-community/liger-kernels", version=1) downloads a pre-built, version-pinned kernel package from the Hugging Face Hub and caches it locally (here under ~/.cache/...kernels-community--liger-kernels). The benefits are:

  • The kernels are compiled once, in CI, for many architectures and version combinations. You download the right binary instead of compiling it yourself.
  • version=1 pins the exact build, so everyone running your script gets the same kernel. There is no "it got slower after I updated a package".
  • The package exposes a .layers attribute with drop-in nn.Modules (like LigerGEGLUMLP). You swap your module for theirs and nothing else in your model changes.

Why tuned kernels are better

When we say "tuned", we mean two concrete things, and both are visible in the trace.

Figure 13: The compiled run pays for pre-ops (Dynamo, guards, prologue) before any GEMM runs

Figure 14: The Liger kernel has no pre-ops — the box where they would be is empty

  • The fusion is baked in. The LigerGEGLUMLP forward is down_proj(LigerGELUMulFunction.apply(gate_proj(x), up_proj(x))). The LigerGELUMulFunction runs a single Triton kernel, _geglu_tanh_forward_kernel, that computes gelu(gate) * up in one pass. This is exactly what we saw from torch.compile, where the intermediate never makes a round-trip through HBM. We get it here without the compiler, as shown in Figures 13 and 14 (no Dynamo guards, no compile latency, no recompilation risk).
  • The launch parameters were chosen for the hardware. The kernel does not guess its block size at random. Liger's calculate_settings picks them from the column count.

It is worth being honest about the trade-off here, because the raw numbers can be misleading. The Liger kernel runs in 92.8 µs, while Inductor's fused kernel from the compile run was 89.4 µs. At first glance the hand-written kernel looks slightly slower, but that comparison hides the cost that makes it worthwhile.

torch.compile specializes for a static shape. Inductor's 89.4 µs kernel is fast precisely because it was generated for *this exact* [8192, 3072] problem. Change the batch size, the sequence length, or the hidden dimension, Dynamo re-traces, and you pay the compile cost all over again to get a new specialized kernel.

So the real choice is not "slow human kernel vs fast compiled kernel". It is a fast generic kernel vs a kernel specialized for one particular input shape. The Liger kernel takes one set of launch parameters and runs them for *any* shape with no recompilation. It gives up the last few microseconds that per-shape specialization would buy, in exchange for being robust to changing shapes.

Conclusion

The table below collects what each step changed on the GPU and what it left untouched.

Setup

What changed

What stayed the same

Eager nn.Linear

Baseline: bias add is already folded into the GEMM epilogue (addmm), so it is *one* cuBLAS kernel, not a matmul plus an add

—

Compiled nn.Linear

A few CPU dispatch ops (the aten::t view bookkeeping) disappear

Same single cuBLAS GEMM kernel, byte-for-byte. Compile has nothing to fuse

Eager MLP

5 GPU kernels: 3 GEMMs + a GeLU + a mul. The [8192, 3072] intermediate makes a full round-trip through HBM

Each GEMM is still the same bias-free cuBLAS kernel as a standalone linear

Compiled MLP

GeLU + mul + reshape collapse into one fused Triton kernel; the intermediate stays in registers. Pays compile pre-ops (Dynamo, guards)

The 3 GEMMs are untouched with identical cuBLAS kernel names

Liger MLP

Same fusion, but baked into a hand-written Triton kernel with hardware-tuned launch params with no Dynamo, guards, or compile latency

The 3 GEMMs are still the same cuBLAS kernels

If there is one habit to carry forward, it is the one we practiced before every trace: guess first, then look. State what you expect the trace to contain, open it, and treat any mismatch as the most interesting thing on the screen.

This was the second stop in the Profiling in PyTorch series. In the next post we will keep climbing the ladder, moving from this MLP block towards the attention block and, eventually, a full model.

Thanks to Noe Flandre and Pedro Gabriel Gengo Lourenço for their reviews on the early draft of the post!

この記事をシェア

関連記事

Latent Space2026年6月20日 17:06

[AINews] 今日特に大きな出来事はありませんでした

Latent Space は、GLM 5.2 が依然として注目されていると指摘しつつ、AIE WF 2026 の通常チケットが月曜日に完売すると発表しました。同サイト購読者向けに限定割引を提供し、参加者には Warp や Datadog などからのスポンサークレジットも付与されます。

TechCrunch AI★42026年6月20日 01:01

米国がアンソロピックの「Fable 5」発売を禁止、しかし市場は動じず

米国政府は国家安全保障上の懸念から、アマゾンの研究者らがガードレール回避手法を発見したとして、アンソロピックに対し最新モデル「Fable 5」と「Mythos 5」の販売差し止めを命じた。サイバーセキュリティ研究者らはこの措置が危険だとする公開書簡に署名し、同社も他モデルでも同様の抜け道が存在すると指摘している。

GitHub Blog★42026年6月20日 01:00

社内データ分析エージェントの構築方法について

GitHub は、大規模なデータ組織が直面する自己完結型のデータアクセスと洞察提供の課題に対し、AI を活用した信頼性の高い解決策として、社内でデータ分析エージェントを構築したことを発表した。

今日のまとめ

AI日報で今日の重要ニュースをまとめ読み

ニュース一覧に戻る元記事を読む