PyTorch のプロファイリング(第 2 部):nn.Linear から融合 MLP へ
Hugging Face Blog は、PyTorch のプロファイリング手法を解説し、従来の nn.Linear レベルから計算効率の高い融合 MLP 構造への最適化プロセスと具体的な実装手順を詳述している。
キーポイント
詳細なプロファイリング手法の解説
PyTorch の標準的なツールやカスタムスクリプトを用いて、モデル内のボトルネックとなっている nn.Linear レベルの詳細なパフォーマンス分析を行う方法を段階的に説明している。
融合 MLP 構造への最適化プロセス
複数の線形層を単一の計算グラフに統合する「Fused MLP」への変換手法を示し、メモリ帯域幅の削減と演算効率の向上を実現する具体的なコード例を提供している。
パフォーマンス改善の実証データ
最適化前後のプロファイルデータを比較し、推論速度やトレーニング時間の短縮といった定量的な効果を示すことで、実装の必要性を裏付けている。
影響分析・編集コメントを表示
影響分析
本記事は、大規模言語モデルや深層学習モデルの開発現場において、理論上の最適化が実際のコードにどう反映されるかを具体的に示した点で極めて重要です。開発者がプロファイリングツールを活用してボトルネックを特定し、効率的なアーキテクチャへ移行する際の具体的なロードマップを提供することで、実装コストの削減とパフォーマンス向上に直接寄与します。
編集コメント
PyTorch のパフォーマンスチューニングにおいて、プロファイリングから最適化まで一貫したアプローチを示す貴重な実践ガイドです。特に大規模モデルを扱う開発者にとって、即座に適用可能な知見が得られる内容となっています。
このシリーズの「PyTorch でのプロファイリング」第 1 部では、torch.add(torch.matmul(x, w), b) を用いて PyTorch プロファイラーのトレースを読み取る方法を学びました。また、CPU ディスパッチチェーン、起動オーバーヘッド、オーバーヘッドバウンドと計算バウンドのレジームの違い、そして torch.compile の内部構造など、いくつかの他のトピックについても議論しました。
今回の第 2 部(本ブログ記事)では、さらに一歩踏み込みます。手書きの matmul-add ペアを nn.Linear (bias=True) に置き換えます。これはあらゆる深層学習モデルが使用する基本構成要素です。その後、活性化関数を挟みながらこれらを 3 つ積み重ね(例に特化した構成)、Multilayer Perceptron (MLP) ブロックを形成します。
**
このブログ記事のスクリプトはここにあります:02_linear.py、03_simple_mlp.py、および 03_kernels_mlp.py。前回の投稿と同様に、別タブでこれらを開きながらコードを追いかけて読むと理解が深まります。スクリプトの実行には NVIDIA A100-SXM4-80GB GPU を使用しています。Hugging Face のインフラ上で GPU をセットアップし、Dev Mode with Spaces を用いてスクリプトを実験するのは非常に簡単です。また、Hugging Face Jobs pipeline を使用してスクリプトを実行することも可能です。
始める前に、後で繰り返し参照することになる 2 つの概念を簡単に復習しておきましょう:
- GPU カーネル(GPU kernel)とは、GPU の多数のスレッド上で並列実行されるプログラムのことです。
- CPU はこれらのカーネルのスケジューリングと起動を担当します。プロファイラトレースで目にする PyTorch のオーバーヘッドの大部分は、このスケジューリング作業によるものです。
matmul-add から Linear へ
nn.Linear は、Part 1 で既にプロファイル済みの行列乗算と加算をラップするモジュールです。唯一の違いは、重み(weight)とバイアス(bias)をパラメータとして保持し、PyTorch ユーザーが慣れ親しんだ forward メソッドを公開している点にあります。
bias=True は、シリーズのパート 1 で見た乗算と加算演算を真にエミュレートします
linear_layer = nn.Linear(in_dim, out_dim, bias=True)
y = linear_layer(x)
ここで扱っている操作は以下のように記述できます:
y = x @ w.T + b
ここで、x は入力、w は重み、b はバイアスです。02_linear.py を実行してプロファイルを確認してみましょう。
uv run 02_linear.py --batch 1024 --in_dim 32 --out_dim 64
uvx trace-util traces -b traces
trace-util は、トレースを Hugging Face bucket に同期し、ターミナル上で Preffeto URLs を提供するユーティリティです。
Figure 1: Profiler trace of nn.Linear
Figure 1 は、線形層のフォワード呼び出しのプロファイラトレースを示しています。この線形層のフォワード呼び出しを、以前のトレースと同様のスケジュール設定(wait=1, warmup=1, active=3)でトレースしました。そのため、CPU および GPU ラーンに 3 つの Profile Steps が表示されています。
トランスポーズは何をしているのか?
Figure 2: The transpose CPU row
Figure 2 のようにプロファイラトレースを拡大して見ると、aten::addmm(乗算と加算)演算の前に aten::t(トランスポーズ)演算があることに気づきます。これにより、nn.Linear は重みパラメータをトランスポーズしてから入力と乗算していることがすでに推測できます。これが aten::t 演算が表示される理由です。
注意すべき重要な点は、aten::t は実際にはデータをコピーしたり再編成したりするものではないということです。これは CPU 上で転置行列を表すためにテンソルのメタデータ(形状とストライド)を書き換えるだけであり、GPU でカーネルを起動することはありません。これを検証するには2つの方法があります。1つ目はトレースの GPU ラインを確認する方法、もう1つはプロファイラーテーブル内の aten::t 行を確認し、CUDA 上で要した時間をチェックする方法です。
なぜ個別の mul および add カーネルがないのか?
Figure 3: No aten::add in the profile of a linear layer
図3に示されるように、linear レイヤーのディスパッチチェーンには aten::add(バイアス加算)が存在しません。これは、バイアスの加算が「エピローグ」と呼ばれる仕組みを用いて行列乗算カーネルに *folded* (統合)されているためです。
エピローグとは、GEMM (GEneral Matrix Multiply) カーネルが結果を HBM (High Bandwidth Memory、GPU のメインメモリ) に書き戻す直前に実行する小さな計算処理のことです。バイアスの加算、活性化関数の適用、定数によるスケーリングなどはすべて古典的なエピローグの例です。エピローグの目的は、メモリアクセスが操作のコストを高めるため、HBM への読み込みまたは書き出しを2回行うことを回避することにあります。
nn.Linear は torch.nn.functional.linear を呼び出し、それがさらに aten::linear を呼び出します。aten::linear は入力を確認し、バイアスが渡されていることに気づくと、行列乗算と加算を別々に実行するのではなく、代わりに aten::addmm(bias, x, weight) をディスパッチします。addmm は以下のように計算を行います:
out = x @ weight.T + bias
GPU で実行される cuBLAS GEMM カーネルには、バイアス加算バリアントが組み込まれており、それが aten::addmm が選択するカーネルです。この加算は独立したカーネルとして現れることはありません。なぜなら、それはmatmul カーネルの書き戻し(writeback)の一部であり、まさにエピローグ(epilogue)が果たす役割そのものだからです。
ここで注意すべき微妙な点があります。Part 1 の --compile 下 で見たカーネル(addmm)は、すでに eager モードの nn.Linear が使用しているカーネルです。torch.compile にはここで融合できる余地が何もなく、これが次章で検証する内容となります。
--compile は単一の Linear に役立つか?
フォワード呼び出しをコンパイルしてプロファイラトレースを確認してみましょう。(プロファイラトレースは次のセクションで可視化されます)
uv run 02_linear.py --batch 1024 --in_dim 32 --out_dim 64 --compile
uvx trace-util traces -b traces
単一の nn.Linear のフォワードに対する eager モードとコンパイルモードのトレースを比較すると、以下のことがわかります。
- GPU 上では同じ cuBLAS GEMM カーネルが実行される。
- CPU 上では同じ aten::addmm オペレーション(op)が使用される。
- CPU ラーンには、compile モード固有の追加行がいくつか存在する。
これは深く理解しておく価値があります。モデルが遅いと感じた際に、誰もが反射的に torch.compile を使おうとする傾向があります。しかし、バイアス付きの単一 GEMM においては、compile の効果は非常に限定的です。これはバグではなく、compile が何らかの融合を行うためには複数のオペレーションが必要であるという事実によるものです。これを証明するために、MLP(多層パーセプトロン)を見てみましょう。
トランスポーズはどこへ行ったのか?カーネルレイアウトと事前演算
2 つのトレース(Eager モードとコンパイルモード)を注意深く読めば、Eager CPU のディスパッチチェーンにはコンパイル版よりも多くの要素が含まれていることに気づくでしょう。
図 4: aten::linear が aten::t (トランスポーズ) を経由して aten::addmm に至る Eager モードのディスパッチチェーン
図 5: トランスポーズを介さず、直接 aten::addmm が呼び出されるコンパイル版のディスパッチチェーン
aten::linear 内部の Eager CPU のディスパッチチェーンは、まず aten::t に続き、次に aten::addmm です(図 4)。aten::t が実際に何を行うのかを理解するには、*ストライド (strides)* と *ビュー (views)* について少し立ち寄る必要があります。
テンソルは、メモリ上では数値の連続した 1 つのフラットな配列としてデータを格納しています。形状(shape)とストライドは、その配列の上に位置するメタデータであり、PyTorch がどのように走査するかを指示します:ストライドが (s0, s1) の場合、「行を 1 つ進むには s0 要素分ステップし、列を 1 つ進むには s1 要素分ステップする」という意味です。このメタデータを変更すれば、コピーなしで同じ生データに対する異なる*ビュー*を得ることができます。
>> M = torch.tensor([[0, 1],
... [2, 3],
... [4, 5]])
>> M.shape, M.stride()
(torch.Size([3, 2]), (2, 1)) # 行あたり 2 ステップ、列あたり 1 ステップ
>> T = M.t() # トランスポーズ
>> T.shape, T.stride()
(torch.Size([2, 3]), (1, 2)) # 形状とストライドが入れ替わり、データは不変
>> T
tensor([[0, 2, 4],
[1, 3, 5]])
>> T.flatten() # 強制的に実体化されるため、データ順序が入れ替わる
tensor([0, 2, 4, 1, 3, 5])
M.t() は数値を一つも移動させませんでした。これはストライドを入れ替えた新しいビューを返すものであり、行ごとに読み込むと、転置された順序で元のバッファ 0, 1, 2, 3, 4, 5 を辿ることになります。基盤となるデータは同一であり、異なるのはメタデータだけです。
これはまさに linear レイヤー内部で aten::t が行っていることです:新しいテンソルの割り当てやデータのコピーを行わず、書き換えられたストライドを持つ重みの *ビュー* を生成します。
図 5 に示されている通り、compile は GPU カーネルを削除したのではなく、そのビューのディスパッチにかかる *CPU オーバーヘッド* を削除しました。Inductor はコンパイル時にビューチェーンを追跡し、結果として得られるストライドを一度計算して、それらのストライドがハードコードされた直接の aten::addmm 呼び出しを生成します。GPU が同じ数学演算を行う間、数マイクロ秒分の CPU 作業が消滅します。
当然のことながら、入力データがコンパイラによって事前に計算されたストライドに違反した場合、エラーが発生します。
両方のトレースにおける GPU ラーンを見ると、フォワードパスごとにカーネルはちょうど一つであり、その二回とも *同じ* カーネルです:
cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8
もし転置カーネルが実行されなかったなら、GEMM が重み行列を転置された順序で読み込むよう誰が教えたのでしょうか?その答えはカーネル名にあります。接尾辞を見てみましょう:
cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8
^^
この tn がレイアウト記述子です。cuBLAS と CUTLASS は、入力レイアウトの各組み合わせに対して *個別のカーネルバイナリ* を事前にコンパイルしています。
n(非転置)と t(転置)は、カーネルが内部ループ内で入力をどのように走査するかを示します。ディスパッチャの役割は、入力ストライドを確認し、どの接尾辞の組み合わせが一致するかを判断して、適切な事前コンパイル済みカーネルを選択することです。
**
プロファイラトレースにおけるカーネル名は、カーネルのアイデンティティのハッシュダンプです。2 つの実行で同じカーネル名が表示される場合、GPU は同じ作業を行っています。名前が異なる場合(例:_tn_ と _nn_、bf16 と fp16、または s16816gemm と s161616gemm)、GPU は異なる作業を行っており、ディスパッチャが別の分岐を選択したことを意味します。この名前を読み解く方法を学ぶことは、トレースを比較する際に最も有用な習慣の一つです。
3 つの Linear をスタック:MLP
このセクションでは、Multilayer Perceptron(MLP)をプロファイリングします。より興味深いものにするため、GeGLU アクティベーション変種を持つ順方向ネットワークをプロファイリングします(これは実際には非常に頻繁に使用されています)。また、これは深層学習研究の歴史において書かれた最も素晴らしい行の一つへのオマージュでもあります(図 6)。
Figure 6: GLU Variants Improve Transformer 論文の結論セクション。
class SimpleGeGLUMLP(nn.Module):
def __init__(self, dim, hidden):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden, bias=False)
self.up_proj = nn.Linear(dim, hidden, bias=False)
self.down_proj = nn.Linear(hidden, dim, bias=False)
def forward(self, x):
g = self.gate_proj(x)
u = self.up_proj(x)
h = F.gelu(g, approximate="tanh")
m = h * u
y = self.down_proj(m)
return y
完全なスクリプトはここで見ることができます:03_simple_mlp.py。以下のように実行してください:
uv run 03_simple_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072
uvx trace-util traces -b traces
トレースを開く前に、一緒に何を期待すべきか考えてみましょう。forward 関数はかなりの計算量を行いますが、その大部分はすでに私たちに馴染みがあります。
各 nn.Linear レイヤーに対して 1 つずつ、合計 3 つの aten::linear ディスパッチが発生することを期待できます。また、GeLU と乗算それぞれに対して 1 つずつ、計 2 つのポイントワイズカーネル起動も予想されます。トレースを見る前にこの期待を形成することは、プロファイリングの旅において最も有用な習慣の一つです:あなたはトレースを読んで推測を*確認または反証するため*に読むのであって、ゼロから推測を作るためではありません。
Figure 7: GeGLU MLP のプロファイラートレース
Figure 8: 線形投影 CPU ラーンにおける"occupancy query"のハイライト
Figure 7 から、私たちの直感が正しかったと自分自身を褒め称えることができます。1 つのフォワードパス(1 つの mlp_fwd)あたり、GPU は正確に 5 つのカーネルを実行します。Figure 8 は、線形投影レイヤーの CPU ラーンで見た"occupancy query"をハイライトしています。
Op
CPU op
GPU kernel
launches
gate_proj
aten::linear
ampere_bf16_s16816gemm_bf16_128x128_...
occupancy query + cudaLaunchKernel
up_proj
aten::linear
ampere_bf16_s16816gemm_bf16_128x128_...
occupancy query + cudaLaunchKernel
gelu
aten::gelu
vectorized_elementwise_kernel<4, GeluCUDAKernelImpl...>
cudaLaunchKernel
h * u
aten::mul
vectorized_elementwise_kernel<4, ...MulFunctor...>
cudaLaunchKernel
down_proj
aten::linear
ampere_bf16_s16816gemm_bf16_128x256_...
occupancy query + cudaLaunchKernel
3 つの GEMM(行列積演算)は、それぞれ起動前に cudaOccupancyMaxActiveBlocksPerMultiprocessor 呼び出しを余分に行っています。これについては Part 1 で別のセクションを設けており、こちらで確認できます。これは cuBLAS がグリッドサイズを決定しているためです。一方、ポイントワイズ演算(GeLU と mul)は、occupancy クエリを行わずに直接起動されます。つまり、「線形演算」は実際には「クエリ+起動」ですが、「ポイントワイズ演算」は単なる「起動」です。
Figure 9: The table shows that some ops launch zero kernels
aten::t, aten::transpose, aten::reshape, aten::view, aten::as_strided, および aten::_unsafe_view の各演算は、ゼロ個のカーネルを起動します。これらは表(Figure 9)で CUDA 実行時間が 0.000us と表示されますが、これは CPU 上でテンソルのメタデータ(形状とストライド)を書き換えるだけだからです。表をスキャンする読者は、線形演算ごとに約 6 つの演算名を目にしますが、GPU に到達するのはそのうちの一つ(mm)だけです。
なぜ GEMM カーネルには 2 種類あるのか?
MLP は、行列積のために [batch, seq, dim] を [batch * seq, dim] にフラット化します。コマンドラインでの呼び出しではバッチサイズに 64、シーケンス長に 128 を使用したため、以下の 8192(バッチサイズ×シーケンス長 = 64 × 128)という値がここから来ています。
トレースより:
Linear
aten::mm input dims
M·K·N
cuBLAS kernel
avg CUDA
gate_proj
[8192,768] x [768,3072]
8192·768·3072
…128x128…stages_32x5_tn
0.19ms
up_proj
[8192,768] x [768,3072]
8192·768·3072
…128x128…stages_32x5_tn
0.19ms
down_proj
[8192,3072] x [3072,768]
8192·3072·768
…128x256…stages_64x3_tn
0.17ms
これら 3 つの GEMM(General Matrix Multiplication: 一般行列積)はすべて同じ FLOP 数、すなわち各々 2·8192·768·3072 ≈ 38.7 GFLOP を有していますが、down_proj は約 10% 高速です。同じ計算量でありながら形状が異なる(N が 3072 ではなく 768)ため、cuBLAS はその形状に対してより良い再利用性を得られる別のタイル(128×256、およびより深い stages_64x3 パイプラインを持つもの)を選択します。
タイル処理について深く学びたい場合は、こちら が素晴らしいリソースです。ここから始められます。
まさにこの理由により、表には 2 つの GEMM 行(Figure 9)が存在しました:128x128 の行は gate と up を示し、128x256 の行は down を示しています。
torch.compile は何をするのか?
フォワードメソッドをコンパイルしてトレースを可視化する前に、再び頭の中で「トレースで何を期待するか」を考える演習を行いましょう。これは楽しい実験であり、自分で何かをプロファイリングするたびに繰り返すべき重要な作業です。常に直感を基盤とし、何かが一致しない瞬間には必ず立ち止まってその理由を探ってください。
uv run 03_simple_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072 --compile
uvx trace-util traces -b traces
Figure 10: The profiler trace for the compiled GeGLU MLP
Eager モードでは、各 nn.Linear は dispatcher ops の連鎖(aten::linear → aten::t → aten::transpose → aten::matmul → aten::reshape → aten::mm)に展開されます。これらは、ATen が実際の GEMM に到達する前に通る高レベルなラッパーです。torch.compile はこの連鎖を除去します。
コンパイルされたグラフが実行される時点では、linear も matmul も transpose も reshape も存在せず、それらのメタデータ演算は mm の呼び出し方に折り畳まれています。Figure 10 に示すように、3 つの裸の aten::mm 外部呼び出しが見られます。これが同じ GEMM であることの証拠は、カーネル名が eager モードとバイト単位で完全に一致していることです:ゲートおよびアップ用には ...128x128...stages_32x5_tn、ダウン用には ...128x256...stages_64x3_tn です。
The fused Triton kernel
Figure 11: The fused Triton kernel
これはコンパイルの教訓全体の要です。2 つの eager ポイントワイズカーネル(GeLU と mul)と reshape が、1 つのカーネル triton_poi_fused__unsafe_view_gelu_mul_0 に統合されました(Figure 11)。名前を解読してみましょう:
- triton: Inductor の Triton バックエンドによって生成されたもの(cuBLAS でも ATen でもありません)。
- poi: pointwise(Inductor はポイントワイズカーネルに poi、リダクションに red、永続的リダクションに per というタグを付けます)。
- fused__unsafe_view_gelu_mul: 統合された演算:_unsafe_view(reshape)、GeLU、および mul です。
- 0: グラフ内の一意な ID です。
なぜこれが勝利となるのでしょうか? eager モードでは、中間変数 h = gelu(g) は [8192, 3072] の bf16 テンソル(約 50 MB)であり、GeLU カーネルがこれを HBM に書き込み、mul カーネルが即座に読み戻します。融合により、このデータはレジスタ(チップ内部に存在し、HBM よりも近いメモリ)内に保持されます。Triton カーネルは g と u をそれぞれ 1 回だけ読み取り、gelu(g) * u を計算して結果を 1 回だけ書き込みます。これにより、中間データがグローバルメモリを経由する往復のラウンドトリップが完全に消滅します。
ハンドチューニングされたカーネルを使おう
これまで私たちは PyTorch(eager モード)とコンパイラー(torch.compile)にカーネル選択を任せてきました。今度は、人間のエキスパートが手書きで調整したカーネルを組み込みます。ここでは「kernels」ライブラリから容易に取得できる LigerGEGLUMLP レイヤーを使用します。
from kernels import get_kernel
kernels_layers = get_kernel("kernels-community/liger-kernels", version=1).layers
kernels_geglu_mlp = kernels_layers.LigerGEGLUMLP(Config()).to(device, dtype=torch.bfloat16).eval()
完全なスクリプトはこちらです:03_kernels_mlp.py。
uv run 03_kernels_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072
uvx trace-util traces -b traces
Figure 12: The profiler trace for the LigerGEGLUMLP layer
Figure 12 shows the profile for the LigerGEGLUMLP layer using the Liger kernels from the Hub.
なぜカーネルライブラリを使うのか
Triton や CUDA でカーネルを書くことは一つの課題ですが、それらを*配布する*ことは別の問題です。このカーネルは、お使いの GPU アーキテクチャ、CUDA バージョン、PyTorch バージョンという正確な組み合わせのためにコンパイルされなければなりません。これが通常失敗するステップです(「自分のマシンでは動く」、nvcc の欠如、間違った Triton バージョンなど)。
kernels ライブラリは、このビルドステップをお使いのマシンから外します。get_kernel("kernels-community/liger-kernels", version=1) は、Hugging Face Hub から事前ビルドされバージョンが固定されたカーネルパッケージをダウンロードし、ローカル(ここでは ~/.cache/...kernels-community--liger-kernels 以下)にキャッシュします。その利点は以下の通りです。
- カーネルは CI 内で一度だけコンパイルされ、多くのアーキテクチャとバージョンの組み合わせに対応しています。自分でコンパイルするのではなく、正しいバイナリをダウンロードできます。
version=1は正確なビルドを固定するため、スクリプトを実行する誰もが同じカーネルを使用します。「パッケージを更新したら遅くなった」という現象は発生しません。
- パッケージには
.layers属性が公開されており、そこにはドロップイン型の nn.Modules(LigerGEGLUMLP など)が含まれています。モデルの他の部分は一切変更せず、モジュールをそれらに置き換えるだけです。
なぜチューニングされたカーネルの方が優れているのか
「チューニングされた」と言うとき、私たちは2つの具体的なことを指しており、どちらもトレース(trace)で確認できます。
Figure 13: コンパイル済み実行では、GEMM が実行される前に Dynamo、ガード、プロローグなどの事前演算のコストが発生する
Figure 14: Liger カーネルには事前演算がない — 本来そこにあるべきボックスは空である
- この融合は組み込まれています。LigerGEGLUMLP の forward は down_proj(LigerGELUMulFunction.apply(gate_proj(x), up_proj(x))) です。LigerGELUMulFunction は単一の Triton カーネル _geglu_tanh_forward_kernel を実行し、gelu(gate) * up を 1 パスで計算します。これは torch.compile から見たものと同じで、中間結果が HBM(High Bandwidth Memory)を経由して往復することはありません。図 13 と図 14 に示すように、ここではコンパイラなしでこれを実現しています(Dynamo ガードも不要、コンパイル遅延もなし、再コンパイルのリスクもありません)。
- ローンチパラメータはハードウェアに合わせて選択されています。このカーネルがランダムにブロックサイズを推測することはありません。Liger の calculate_settings は列数からそれらを選択します。
ここでトレードオフについて正直になる価値があります。なぜなら、生データだけを見ると誤解を招く可能性があるからです。Liger カーネルの実行時間は92.8 µsですが、コンパイル実行における Inductor の融合カーネルは89.4 µsです。一見すると手書きのカーネルがわずかに遅いように見えますが、この比較にはその価値を正当化するコストが見落とされています。
torch.compile は静的形状に対して最適化されます。Inductor の 89.4 µs カーネルが高速なのは、まさに*この特定の*[8192, 3072]問題のために生成されたからです。バッチサイズやシーケンス長、隠れ次元を変更すると、Dynamo は再トレースを行い、新しい特殊化カーネルを得るために再びコンパイルコストを支払う必要があります。
つまり、真の選択は「遅い人手によるカーネル vs 高速なコンパイル済みカーネル」ではありません。それは汎用的に高速なカーネル vs 特定の1つの入力形状に特化したカーネルです。Liger カーネルは1セットのパラメータを起動パラメータとして受け取り、再コンパイルを行うことなく*あらゆる*形状に対して実行します。これは、形状ごとの特化がもたらす数マイクロ秒の速度向上を犠牲にする代わりに、変化する形状に対する堅牢性を獲得するものです。
結論
以下の表は、各ステップで GPU に何の変更を加え、何が変更されなかったかをまとめています。
Setup | What changed | What stayed the same
---|---|---
Eager nn.Linear | ベースライン:バイアス加算はすでに GEMM エピローグ(addmm)に折り込まれており、*1 つの* cuBLAS カーネルであり、行列乗算と加算の組み合わせではない | —
Compiled nn.Linear | いくつかの CPU ディスパッチ演算(aten::t ビューのブックキーピング)が消失 | 同じ単一の cuBLAS GEMM カーネルで、バイト単位でも同一。コンパイルでは融合は行われない
Eager MLP | GPU カーネル5個:3 つの GEMM + GeLU + 乗算。[8192, 3072] の中間結果が HBM をフルラウンドトリップする | 各 GEMM は、独立した線形演算と同じバイアスなし cuBLAS カーネルである
Compiled MLP | GeLU + 乗算 + リシェイプが1 つの融合された Triton カーネルに統合され、中間結果はレジスタ内に残る。コンパイル前処理(Dynamo、ガード)のコストを払う | 3 つの GEMM は変更されず、同一の cuBLAS カーネル名を使用
Liger MLP | 同じ融合だが、ハードウェアチューニングされた起動パラメータを持つ手書きの Triton カーネルに組み込まれており、Dynamo、ガード、コンパイル遅延は不要 | 3 つの GEMM は依然として同じ cuBLAS カーネル
継続すべき習慣は一つだけあります。それは、以前に各トレースを行う前に実践していた「まず推測し、その後確認する」というものです。トレースに含まれると予想される内容を明確にし、実際に開いてみて、不一致が見られたらそれを画面で最も興味深いものとして扱ってください。
これはPyTorch におけるプロファイリングシリーズの二回目の投稿でした。次の投稿では、この MLP ブロックからアテンションブロックへと、そして最終的には完全なモデルへと、さらに階段を上って進んでいきます。
記事の初期草案に対するレビューを寄せてくださった Noe Flandre 氏と Pedro Gabriel Gengo Lourenço 氏に感謝いたします!
原文を表示
In the first part of this series "Profiling in PyTorch", we used torch.add(torch.matmul(x, w), b) to learn how to read PyTorch profiler traces. We also discussed several other topics that came our way - the CPU dispatch chain, launch overhead, the difference between an overhead-bound and a compute-bound regime, and some internals of torch.compile.
In the second iteration (this blog post), we climb one rung up the ladder. We replace the hand-written matmul-add pair with an nn.Linear (with bias=True). This is the building block every deep learning model uses. We then stack three of them (specific to our example), with an activation in between, to form a Multilayer Perceptron (MLP) block.
The scripts for this blog post live here: 02_linear.py, 03_simple_mlp.py, and 03_kernels_mlp.py. Like before, it helps to open them in a separate tab and walk through the code as you read. We use an NVIDIA A100-SXM4-80GB GPU to run the scripts. It is really easy to set up a GPU on the Hugging Face infrastructure and experiment with the scripts using Dev Mode with Spaces. One could also run the scripts with the Hugging Face Jobs pipeline.
Before we begin, a quick recap of two ideas we will lean on repeatedly:
- A GPU kernel is a program that runs in parallel on many threads of the GPU.
- The CPU schedules and launches these kernels. Most of the PyTorch overhead you see in a profiler trace is this scheduling work.
From matmul-add to Linear
nn.Linear is a module wrapper around the same matrix multiplication and addition we already profiled in Part 1. The only difference is that it owns its weight and bias as parameters and exposes a forward method that PyTorch users have grown familiar with.
# bias=True would truly emulate the multiplication and addition
# operations we have seen in part 1 of the series
linear_layer = nn.Linear(in_dim, out_dim, bias=True)
y = linear_layer(x)
The operation at hand can be written as:
y = x @ w.T + b
Where x is the input, w is the weight and b is the bias. Let's run 02_linear.py and check the profile.
uv run 02_linear.py --batch 1024 --in_dim 32 --out_dim 64
uvx trace-util traces -b traces
trace-util is a utility that will sync your traces to a Hugging Face bucket and then provide the Preffeto URLs on your terminal.
Figure 1: Profiler trace of nn.Linear
Figure 1 shows the profiler trace of a forward call of the linear layer. We trace the forward call of the linear layer with a similar schedule setup as the previous traces, with wait=1, warmup=1 and active=3. This is why we see three Profile Steps in the CPU and GPU lanes.
What is the transpose doing?
Figure 2: The transpose CPU row
If we zoom into the profiler trace, as we do in Figure 2, we notice an aten::t (transpose) op before the aten::addmm (multiplication and addition) op. We can already figure out that nn.Linear transposes the weight parameter and then multiplies it with the input. This is the reason we see an aten::t op.
An important thing to notice is that aten::t does not really copy or reorganize data: it only rewrites tensor metadata (shape and stride) on the CPU to represent the transposed matrix. It does not launch a kernel on the GPU. One can verify this two ways: by looking at the GPU lane in the trace, or by checking the aten::t row in the profiler table and the time it took on CUDA.
Why are there no separate mul and add kernels?
Figure 3: No aten::add in the profile of a linear layer
There is no aten::add (the bias addition) in the dispatch chain of the linear layer, as seen in Figure 3. This is because the bias addition has been *folded* into the matrix multiplication kernel, using what is called an epilogue.
An epilogue is a small computation that a GEMM (GEneral Matrix Multiply) kernel does at the very end, just before it writes its result back to HBM (High Bandwidth Memory, the GPU's main memory). Adding a bias, applying an activation, or scaling by a constant are all classic epilogues. The point of an epilogue is to avoid loading or writing to HBM a second time, since memory traffic makes an operation expensive.
nn.Linear calls torch.nn.functional.linear, which, in turn, calls aten::linear. aten::linear looks at the inputs, notices that a bias was passed, and dispatches aten::addmm(bias, x, weight) instead of doing a matmul and an add separately. addmm computes:
out = x @ weight.T + bias
The cuBLAS GEMM kernel that runs on the GPU has a bias-add variant built in, and that's the kernel aten::addmm picks. The add never appears as a separate kernel because it is part of the matmul kernel's writeback, which is exactly what an epilogue is.
This is the moment to notice something subtle. The kernel you saw in Part 1 under --compile (addmm) is the kernel that eager nn.Linear already uses. There is nothing left for torch.compile to fuse here, which is the next thing we will verify.
Can --compile help a single Linear?
Let's compile the forward call and look at the profiler trace. (The profiler trace is visualized in the next section)
uv run 02_linear.py --batch 1024 --in_dim 32 --out_dim 64 --compile
uvx trace-util traces -b traces
If you compare the eager and compiled traces for a single nn.Linear's forward, you will find:
- The same cuBLAS GEMM kernel on the GPU.
- The same aten::addmm op on the CPU.
- A few extra rows on the CPU lane unique to compile.
This is worth internalizing. A common reflex is to reach for torch.compile whenever a model feels slow. For a single GEMM-with-bias, compile has very little to do. This is not a bug, this is just that compile needs more than one operation to possibly do any fusing. Let's prove that by looking at an MLP.
Where did the transpose go? Kernel layouts and pre-ops
A careful reader of the two traces (eager vs compile) will notice that the eager CPU dispatch chain has more in it than the compiled one.
Figure 4: Eager dispatch chain where aten::linear walks through aten::t (transpose) and then aten::addmm
Figure 5: Compiled dispatch chain where aten::addmm is called directly, with no transpose
The eager CPU dispatch chain inside aten::linear is aten::t followed by aten::addmm (Figure 4). To understand what aten::t actually does, we need a quick detour into *strides* and *views*.
A tensor stores its data as one flat, contiguous run of numbers in memory. The shape and stride are metadata that sit on top of that run and tell PyTorch how to walk it: a stride of (s0, s1) means "step s0 elements to move one row, step s1 to move one column". Change the metadata and you get a different *view* of the *same* raw data, with no copy:
>>> M = torch.tensor([[0, 1],
... [2, 3],
... [4, 5]])
>>> M.shape, M.stride()
(torch.Size([3, 2]), (2, 1)) # two steps per row, one step per column
>>> T = M.t() # transpose
>>> T.shape, T.stride()
(torch.Size([2, 3]), (1, 2)) # shape and stride swapped, data untouched
>>> T
tensor([[0, 2, 4],
[1, 3, 5]])
>>> T.flatten() # forced to materialize, so the data is reordered
tensor([0, 2, 4, 1, 3, 5])
M.t() did not move a single number. It returned a new view whose strides are swapped, so reading it row-by-row now walks the original buffer 0, 1, 2, 3, 4, 5 in transposed order. The underlying data is identical; only the metadata differs.
This is exactly what aten::t does inside the linear layer: it does not allocate a new tensor or copy any data, it produces a *view* of the weight with rewritten strides.
As we can see in Figure 5, compile did not remove a GPU kernel: it removed the *CPU overhead* of dispatching that view. Inductor traced through the view chain at compile time, computed the resulting strides once, and emitted a direct aten::addmm call with those strides hard-coded. A few microseconds of CPU work disappear while the GPU does identical math.
As one would expect, when the input data violates the strides precomputed by the compiler, it will throw an error.
If you look at the GPU lane in both traces, there is exactly one kernel per forward, and it is the *same* kernel both times:
cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8
If no transpose kernel ran, who taught the GEMM to read the weight matrix in transposed order? The answer is in the kernel's name. Look at the suffix:
cutlass_80_wmma_tensorop_bf16_s161616gemm_bf16_32x32_32x1_tn_align8
^^
That tn is the layout descriptor. cuBLAS and CUTLASS precompile a *separate kernel binary* for each combination of input layouts.
n (non-transposed) and t (transposed) describe how a kernel walks its input during the inner loop. The dispatcher's job is to look at the input strides, decide which suffix combination matches, and pick the right precompiled kernel.
The kernel name in a profiler trace is a hash dump of the kernel's identity. If two runs show the same kernel name, the GPU is doing the same work. If they differ (e.g., _tn_ vs _nn_, bf16 vs fp16, or s16816gemm vs s161616gemm) then the GPU is doing different work, and the dispatcher took a different branch. Learning to read this name is one of the most useful habits when comparing traces.
Stacking three Linears: the MLP
In this section, we will profile a Multilayer Perceptron (MLP). To make this more interesting, we will profile a feed-forward network with the GeGLU activation variant (which is quite heavily used in practice). This is also our way of paying tribute to one of the greatest lines ever written in the history of deep learning research (Figure 6).
Figure 6: The conclusion section of the GLU Variants Improve Transformer paper.
class SimpleGeGLUMLP(nn.Module):
def __init__(self, dim, hidden):
super().__init__()
self.gate_proj = nn.Linear(dim, hidden, bias=False)
self.up_proj = nn.Linear(dim, hidden, bias=False)
self.down_proj = nn.Linear(hidden, dim, bias=False)
def forward(self, x):
g = self.gate_proj(x)
u = self.up_proj(x)
h = F.gelu(g, approximate="tanh")
m = h * u
y = self.down_proj(m)
return y
You will find the entire script here: 03_simple_mlp.py. Execute it like so:
uv run 03_simple_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072
uvx trace-util traces -b traces
Before we open the trace, let's think together about what we should expect to see. The forward function does a fair amount of computation, but most of it is already familiar to us.
We should expect three aten::linear dispatches, one for each nn.Linear layer. We should also expect two pointwise kernel launches, one for the GeLU and one for the multiplication. Forming this expectation before looking is the single most useful habit in the profiling journey: you read the trace to *confirm or break* a guess, not to form one from scratch.
Figure 7: The profiler trace for a GeGLU MLP
Figure 8: The occupancy queries highlighted in the linear projection CPU lane
From Figure 7 we can pat ourselves on the back, as our intuition was correct. Per forward pass (one mlp_fwd), the GPU runs exactly 5 kernels. Figure 8 highlights the "occupancy query" as seen in the CPU lane for the linear projection layers.
Op
CPU op
GPU kernel
launches
gate_proj
aten::linear
ampere_bf16_s16816gemm_bf16_128x128_...
occupancy query + cudaLaunchKernel
up_proj
aten::linear
ampere_bf16_s16816gemm_bf16_128x128_...
occupancy query + cudaLaunchKernel
gelu
aten::gelu
vectorized_elementwise_kernel<4, GeluCUDAKernelImpl...>
cudaLaunchKernel
h * u
aten::mul
vectorized_elementwise_kernel<4, ...MulFunctor...>
cudaLaunchKernel
down_proj
aten::linear
ampere_bf16_s16816gemm_bf16_128x256_...
occupancy query + cudaLaunchKernel
The three GEMMs each do an extra cudaOccupancyMaxActiveBlocksPerMultiprocessor call before the launch. We have a separate section on this in Part 1, you can find it here. That is cuBLAS sizing the grid. The pointwise ops (GeLU and mul) launch directly, with no occupancy query. So "a linear" is actually query + launch, while "a pointwise op" is just launch.
Figure 9: The table shows that some ops launch zero kernels
The aten::t, aten::transpose, aten::reshape, aten::view, aten::as_strided, and aten::_unsafe_view ops launch zero kernels. They show 0.000us of CUDA time in the table (Figure 9) because they only rewrite tensor metadata (shape and stride) on the CPU. A reader scanning the table sees around six op names per linear, but only one of them (mm) ever reaches the GPU.
Why are there two types of GEMM kernels?
The MLP flattens [batch, seq, dim] to [batch * seq, dim] for the matmul. In our command-line invocation we used 64 for batch and 128 for seq, so that's where the 8192 (batch * seq = 64 * 128) below comes from.
From the trace:
Linear
aten::mm input dims
M·K·N
cuBLAS kernel
avg CUDA
gate_proj
[8192,768] x [768,3072]
8192·768·3072
…128x128…stages_32x5_tn
0.19ms
up_proj
[8192,768] x [768,3072]
8192·768·3072
…128x128…stages_32x5_tn
0.19ms
down_proj
[8192,3072] x [3072,768]
8192·3072·768
…128x256…stages_64x3_tn
0.17ms
All three GEMMs have the same FLOP count, 2·8192·768·3072 ≈ 38.7 GFLOP each, yet down_proj is about 10% faster. Same work, different shape (N=768 instead of 3072), so cuBLAS picks a different tile (128×256, with a deeper stages_64x3 pipeline) that gets better reuse for that shape.
If you want to learn more about tiling in depth, here is a great resource to get started with.
This is exactly why the table had two GEMM rows (Figure 9): the 128x128 row is gate+up and the 128x256 row is down.
What does torch.compile do?
Before compiling the forward method and visualizing it, let's do the mental exercise again of asking ourselves what we expect to see in the trace. This is a fun experiment, and an important one to repeat every time you profile something yourself. Always build on your intuition, and the moment something does not match, stop and figure out why.
uv run 03_simple_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072 --compile
uvx trace-util traces -b traces
Figure 10: The profiler trace for the compiled GeGLU MLP
In eager mode, each nn.Linear was expanded into a chain of dispatcher ops (aten::linear → aten::t → aten::transpose → aten::matmul → aten::reshape → aten::mm). Those are the high-level wrappers that ATen walks through before reaching the real GEMM. torch.compile removes that chain.
By the time the compiled graph runs, there is no linear, no matmul, no transpose or reshape and those metadata ops were folded into how mm is called. We can see three bare aten::mm external calls (Figure 10). The proof that it is the same GEMM is that the kernel names are byte-for-byte identical to eager: ...128x128...stages_32x5_tn for gate and up, and ...128x256...stages_64x3_tn for down.
The fused Triton kernel
Figure 11: The fused Triton kernel
This is the headline of the whole compile lesson. The two eager pointwise kernels (GeLU and mul) plus a reshape collapsed into one kernel, triton_poi_fused__unsafe_view_gelu_mul_0 (Figure 11). Let's decode the name:
- triton: generated by Inductor's Triton backend (not cuBLAS, not ATen).
- poi: pointwise (Inductor tags pointwise kernels poi, reductions red, and persistent reductions per).
- fused__unsafe_view_gelu_mul: the ops it merged: the _unsafe_view (reshape), the GeLU, and the mul.
- 0: the unique id within the graph.
Why is this a win? In eager mode, the intermediate h = gelu(g) is a full [8192, 3072] bf16 tensor (around 50 MB) that the GeLU kernel writes to HBM and the mul kernel immediately reads back. Fusion keeps it in registers (memory that resides inside the chip and are closer than the HBM). The Triton kernel reads g and u once, computes gelu(g) * u, and writes the result once. One whole round trip of the intermediate through global memory is gone.
Let's use hand tuned kernels
So far we have let PyTorch (eager) and the compiler (torch.compile) pick our kernels. Now we plug in a kernel that a human expert wrote and tuned by hand. We use the LigerGEGLUMLP layer, that we can easily fetch from the Hugging Face Hub with the kernels library.
from kernels import get_kernel
kernels_layers = get_kernel("kernels-community/liger-kernels", version=1).layers
kernels_geglu_mlp = kernels_layers.LigerGEGLUMLP(Config()).to(device, dtype=torch.bfloat16).eval()
The full script is here: 03_kernels_mlp.py.
uv run 03_kernels_mlp.py --batch 64 --seq 128 --dim 768 --hidden 3072
uvx trace-util traces -b traces
Figure 12: The profiler trace for the LigerGEGLUMLP layer
Figure 12 shows the profile for the LigerGEGLUMLP layer using the Liger kernels from the Hub.
Why use the kernels library
Writing kernels in Triton or CUDA is one problem and *shipping* them is another. The kernel has to be compiled for your exact combination of GPU architecture, CUDA version, and PyTorch version. This is the step that usually breaks ("works on my machine", missing nvcc, wrong Triton version).
The kernels library moves that build step off your machine. get_kernel("kernels-community/liger-kernels", version=1) downloads a pre-built, version-pinned kernel package from the Hugging Face Hub and caches it locally (here under ~/.cache/...kernels-community--liger-kernels). The benefits are:
- The kernels are compiled once, in CI, for many architectures and version combinations. You download the right binary instead of compiling it yourself.
- version=1 pins the exact build, so everyone running your script gets the same kernel. There is no "it got slower after I updated a package".
- The package exposes a .layers attribute with drop-in nn.Modules (like LigerGEGLUMLP). You swap your module for theirs and nothing else in your model changes.
Why tuned kernels are better
When we say "tuned", we mean two concrete things, and both are visible in the trace.
Figure 13: The compiled run pays for pre-ops (Dynamo, guards, prologue) before any GEMM runs
Figure 14: The Liger kernel has no pre-ops — the box where they would be is empty
- The fusion is baked in. The LigerGEGLUMLP forward is down_proj(LigerGELUMulFunction.apply(gate_proj(x), up_proj(x))). The LigerGELUMulFunction runs a single Triton kernel, _geglu_tanh_forward_kernel, that computes gelu(gate) * up in one pass. This is exactly what we saw from torch.compile, where the intermediate never makes a round-trip through HBM. We get it here without the compiler, as shown in Figures 13 and 14 (no Dynamo guards, no compile latency, no recompilation risk).
- The launch parameters were chosen for the hardware. The kernel does not guess its block size at random. Liger's calculate_settings picks them from the column count.
It is worth being honest about the trade-off here, because the raw numbers can be misleading. The Liger kernel runs in 92.8 µs, while Inductor's fused kernel from the compile run was 89.4 µs. At first glance the hand-written kernel looks slightly slower, but that comparison hides the cost that makes it worthwhile.
torch.compile specializes for a static shape. Inductor's 89.4 µs kernel is fast precisely because it was generated for *this exact* [8192, 3072] problem. Change the batch size, the sequence length, or the hidden dimension, Dynamo re-traces, and you pay the compile cost all over again to get a new specialized kernel.
So the real choice is not "slow human kernel vs fast compiled kernel". It is a fast generic kernel vs a kernel specialized for one particular input shape. The Liger kernel takes one set of launch parameters and runs them for *any* shape with no recompilation. It gives up the last few microseconds that per-shape specialization would buy, in exchange for being robust to changing shapes.
Conclusion
The table below collects what each step changed on the GPU and what it left untouched.
Setup
What changed
What stayed the same
Eager nn.Linear
Baseline: bias add is already folded into the GEMM epilogue (addmm), so it is *one* cuBLAS kernel, not a matmul plus an add
—
Compiled nn.Linear
A few CPU dispatch ops (the aten::t view bookkeeping) disappear
Same single cuBLAS GEMM kernel, byte-for-byte. Compile has nothing to fuse
Eager MLP
5 GPU kernels: 3 GEMMs + a GeLU + a mul. The [8192, 3072] intermediate makes a full round-trip through HBM
Each GEMM is still the same bias-free cuBLAS kernel as a standalone linear
Compiled MLP
GeLU + mul + reshape collapse into one fused Triton kernel; the intermediate stays in registers. Pays compile pre-ops (Dynamo, guards)
The 3 GEMMs are untouched with identical cuBLAS kernel names
Liger MLP
Same fusion, but baked into a hand-written Triton kernel with hardware-tuned launch params with no Dynamo, guards, or compile latency
The 3 GEMMs are still the same cuBLAS kernels
If there is one habit to carry forward, it is the one we practiced before every trace: guess first, then look. State what you expect the trace to contain, open it, and treat any mismatch as the most interesting thing on the screen.
This was the second stop in the Profiling in PyTorch series. In the next post we will keep climbing the ladder, moving from this MLP block towards the attention block and, eventually, a full model.
Thanks to Noe Flandre and Pedro Gabriel Gengo Lourenço for their reviews on the early draft of the post!
関連記事
PyTorch のプロファイリング(第 1 部):torch.profiler を始めるための初心者ガイド
Hugging Face Blog が、PyTorch のパフォーマンス解析ツールである torch.profiler の基本的な使い方と導入方法を解説した入門記事を発表しました。
AI エンジニアが知っておくべき Python の必須概念 5 つ
KDnuggets は、AI エンジニアが習得すべき Python の重要な概念を 5 つ紹介する記事を発表しました。
TRL でデルタ重み同期を実装:トリリオンパラメータをハブバケットで管理
TRL は非同期強化学習において、変更されたモデルパラメータのみを送信する「デルタ重み同期」手法を導入し、データ転送量をギガバイトからメガバイトに削減した。また、Hugging Face Hub のバケット機能を活用して学習器と推論エンジンの通信を分離し、帯域幅の大幅な節約を実現した。
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み