エンドツーエンドFP8精度による高スループット強化学習トレーニングの実行
NVIDIAはLLMの複雑な推論能力向上を目指す強化学習(RL)トレーニングにおいて、FP8精度をエンドツーエンドで活用することで、高スループットな学習環境を実現する技術ガイドを発表した。
キーポイント
RLの核心化とFP8適用
LLMが単純生成から複雑推論へ移行する中、強化学習(RL)が学習の中心となり、FP8精度の適用が不可欠となる。
エンドツーエンドのFP8最適化
計算とメモリ帯域を最適化するFP8形式をトレーニング全体に適用し、スループットと学習効率を大幅に向上させる。
最新アルゴリズムとの連携
Group Relative Policy Optimization(GRPO)などの先進的なRLアルゴリズムとFP8ハードウェアサポートを組み合わせ、実用的な学習パイプラインを構築する。
影響分析・編集コメントを表示
影響分析
FP8エンドツーエンドのRLトレーニングは、大規模モデルの推論能力向上に必要な計算リソースを現実的な範囲に収める突破口となる。これにより、AIラボの学習コストが劇的に削減され、競争激化する推論型LLM開発のアクセラレーションが期待される。
編集コメント
FP8をRLトレーニング全体に適用するアプローチは、メモリ帯域のボトルネックを解消し、次世代推論モデルの開発コストを劇的に下げる可能性を秘めている。実装ガイドの公開は、業界全体の学習効率標準化を加速させる一因となるだろう。
LLMが単純なテキスト生成から複雑な推論へと移行するにつれ、強化学習(Reinforcement Learning: RL)が中心的な役割を果たしています。グループ相対方策最適化(Group Relative Policy Optimization: GRPO)などのアルゴリズムがこの移行を推進し、推論グレードのモデルが反復フィードバックを通じて継続的に改善できるようになります。標準的な教師ありファインチューニング(Supervised Fine-Tuning)とは異なり、RLのトレーニングループは、厳格なレイテンシ要件を持つ生成フェーズと、高いスループットを必要とするトレーニングフェーズという2つの明確で高負荷なフェーズに分かれます。
これらのワークロードを実用化するため、研究者やエンジニアはトレーニング性能とスループット指向の生成を向上させるために、FP8などの低精度データ型に注目しています。さらに、生成がGPUメモリ帯域幅によって制約される一部のシナリオでは、パラメータあたりのバイト数が少なくなるため、低精度パラメータを使用することでパフォーマンスが向上することもあります。
この記事では、低精度RLのシステム的な課題に深く迫り、NVIDIA NeMoフレームワーク内のオープンソースライブラリであるNVIDIA NeMo RLが、精度を維持しながらRLワークロードをどのように高速化するかを検証します。
RLにおける線形層(Linear Layers)のFP8
当社のレシピは、DeepSeek-V3 Technical Reportで導入されたブロック単位量子化FP8(Block-wise Quantized FP8)を使用しています。表1は、線形射影層(Linear Projection Layers)におけるテンソル形式(Tensor Formats)の詳細を示しています。データ型 / 量子化粒度 / スケーリングファクター / スケーリングタイプ:重み / FP8 (E4M3) / [128, 128] / FP32 / ブロック単位、入力活性化値 / FP8 (E4M3) / [1, 128] / FP32 / ブロック単位、出力勾配 / FP8 (E4M3) / [1, 128] / FP32 / ブロック単位。表1. 線形射影層におけるテンソル形式
本レシピにより、線形層はBF16演算と比較してピークスループットが2倍のFP8演算で計算できます。アテンション(Attention)、正規化(Normalization)、非線形関数(Non-Linear Functions)、出力射影(Output Projections)を含む他のモジュールは、BF16演算で計算されます。
RLにおける数値不一致(Numerical Disagreement)の課題
RLパイプラインは通常、ロールアウト(Rollouts)にvLLMを、トレーニングにNVIDIA Megatron Coreを使用する別々のエンジンを利用します。それぞれはパフォーマンス最大化のために独自の専用NVIDIA CUDAカーネル(CUDA Kernels)を採用しています。これにより本質的に数値の差異が生じ、追加の量子化・逆量子化ロジック(Quantization and Dequantization Logic)により低精度において累積的に増幅されます。この数値差異をトークン乗算確率誤差(Token Multiplicative Probability Error)として定量化します:
完全な一致にはスコア1が与えられ、追加の手法を使用しない場合、『許容範囲』の値は通常<1.03〜1.05となります。
線形層におけるエンドツーエンドのFP8は数値不一致を低減する
FP8レシピの開発過程中、以下の3つのレシピを実験しました:
ベースラインレシピ:生成とトレーニングの両方にBF16を使用。
レシピ候補1:生成時のみFP8を適用し、ポリシーモデル(Policy Model)のトレーニングはBF16で実施。
最終レシピ:エンドツーエンドFP8:生成エンジンとトレーニングエンジンの両方でFP8を使用。
生成(generation)にのみFP8を使用するレシピ候補1と比較して、最終的なレシピは生成と学習の間の数値不一致(numerical disagreement)が常に低くなっていることが確認できます。なお、ベースラインのレシピは常に最も低い数値不一致を示します。図1は、3つのレシピにおけるトークン乗算確率誤差(token multiplicative probability error)指標を示しています。
図1. 3つのレシピにおけるトークン乗算確率誤差
インポータンスサンプリングによる数値不一致の軽減
インポータンスサンプリング(importance sampling)は、データを生成するモデル(つまり分布)と学習対象となるモデル(つまり分布)の間の分布ミスマッチを補正するために使用されます。これは、損失に掛けられるトークンごとの重みです。インポータンスサンプリングの詳細な理論的背景については、当社のGRPOドキュメントを参照してください。
実験結果は以下の通りです:
レシピ候補1(FP8生成とBF16学習)の場合、インポータンスサンプリングはBF16強化学習(RL)との精度差を縮小できますが、完全に埋めることはできません。
最終的なレシピ(エンドツーエンドFP8)の場合、インポータンスサンプリングはBF16学習との精度差を完全に埋めます。図2は、異なるレシピにおける学習中の検証精度(validation accuracy)を示しています。
図2. Llama 3.1 8B Instructモデルおよび数学データセットにおけるGRPO学習の検証精度
FP8線形層(Linear Layer)エンドツーエンド(E2E)の結果
我々は、デンスモデル(dense models)およびMixture-of-Experts(MoE)モデルの両方で、エンドツーエンドFP8レシピを評価し、BF16ベースラインに対する検証精度と学習スループット(training throughput)を測定しました。
デンスモデルにおけるFP8エンドツーエンド:Llama 3.1 8B Instruct
表2は、Llama 3.1 8B instructモデルおよび数学データセットに対して4000ステップまで学習したGRPOトレーニングにおける、FP8エンドツーエンドレシピとBF16レシピの精度を示しています。
精度 BF16 FP8生成のみ FP8エンドツーエンド 検証精度 0.616 0.586 0.613 表2:異なる精度設定におけるLlama3 8Bの検証精度結果
スピードアップの観点では、FP8レシピはBF16と比較して一貫して>15%のスループット向上を達成しています。図3は、2つのレシピにおける1000ステップのGRPO学習(GPUあたりのトークン数/秒)を示しています。
図3. 2つのレシピのスループット(GPUあたりのトークン数/秒)(青:BF16、ピンク:FP8エンドツーエンド)
FP8のBF16に対する理論的な速度向上は2倍ですが、実際には線形層のみが高速な数学演算スループットの恩恵を受け、アテンション(attention)や要素ごとの演算層(elementwise layers)は同じままのため、これよりも低くなります。線形層の前に追加される量子化カーネル(quantization kernels)により、いくつかのオーバーヘッドが発生します。15%〜25%の速度向上は、vLLMのスタンドアロンテスト結果と一致します。vLLMにおける量子化カーネルの融合(fusing)などのさらなる最適化により、速度向上はさらに1.25倍まで改善できると見込んでいます。
MoEモデルにおけるFP8エンドツーエンド:Qwen3-30B
同様の実験をMixture-of-Experts(MoE)モデルで実行し、Qwen3-30Bの結果は一致する精度曲線を示しました。FP8はBF16と同様の精度を達成しています。速度向上については現在調査中です。
Figure 4. Qwen3-30B GRPOのOpenMathInstruct-2データセットを用いた8ノードH100上での精度曲線。青はBF16、ピンクはFP8エンドツーエンド
KVキャッシュとアテンションへのFP8の拡張
Transformerモデルにおいて、線形層が唯一のボトルネックとなるわけではありません。KVキャッシュの成長とアテンション計算は、長い出力シーケンス長(OSL: Output Sequence Length)を伴う強化学習(RL: Reinforcement Learning)ワークフローにおいて、エンドツーエンドのロールアウト時間を支配することが多く、メモリ帯域幅を飽和させつつトークン生成を遅らせる要因ともなります。このことが、RLのループ内でKVキャッシュとアテンションにFP8を適用する探索を促しました。ここではPer-tensor scaling FP8(テンソル単位スケーリングFP8)が使用されます。
強化学習環境においてKVキャッシュにFP8を適用するのは、ポリシー重みがステップごとに更新されるため、特有の課題となります。一度だけキャリブレーション(較正)が行われる静的推論とは異なり、RLでは量子化スケールを動的に処理する必要があります。
NeMo RLは、この課題を解決するために以下のアプローチを採用しています:
再較正(Recalibration):各トレーニングステップの終了時、トレーナーは更新されたポリシー重みを用いて、Query, Key, Value(QKV)のスケールを再較正します。
データ選択(Data selection):この較正は、現在の分布を反映したスケールを保証するため、トレーニングデータ(プロンプトと生成された応答)を用いて実行されます。
同期(Synchronization):新たに計算されたスケールは、次のロールアウトフェーズのために推論エンジン(vLLM)に同期されます。
Figure 5. FP8 KVキャッシュを用いたRLワークフロー
この設計により、ロールアウトエンジンは常に最新のポリシー状態から導き出された最適な量子化スケールを使用し、精度の低下を最小限に抑えます。較正によるオーバーヘッドは極めて小さく、総ステップ時間の約2〜3%を消費するだけです。
データのTensorType、スケーリング係数、スケールの種類:QKVアテンション活性化(FP8 (E4M3)、FP32、Tensor-wise)、保存済みKVキャッシュ(FP8 (E4M3)、FP32、Tensor-wise)
表3:アテンション活性化と保存済みKVキャッシュのテンソル形式
KVキャッシュとアテンションにおけるFP8の結果概要
Qwen3-8B-Baseモデルを用いてGRPOアルゴリズムで実験を実施し、ロールアウトにFP8を適用してトレーニングにはBF16を使用しました。KVキャッシュとアテンションの両方を量子化すると、誤差が累積するためミスマッチKLダイバージェンス(KL divergence)がやや高くなる傾向がありますが、当社のレシピにより不安定性は緩和されます。トークンレベルの切り捨て重要度サンプリング(truncated importance sampling)を有効にすることで、線形層+KVキャッシュ+アテンションの両方にFP8を適用した場合でも、BF16ベースラインおよび線形層のみFP8(W8A8)と検証精度が一致します。
Figure 6. Qwen3-8B-Baseのトレーニング精度曲線
KVキャッシュとアテンション操作の両方にFP8を有効にすると、線形W8A8構成と比較してロールアウト段階でさらに約30%の高速化が得られ、BF16ベースライン全体では約48%の高速化を実現します。これらの向上は、アテンション計算が全体のワークロードを占める割合が大きくなる長い応答長において特に顕著です。QKVスケールの再較正プロセスは総ステップ時間の約2〜3%を消費しますが、達成された大幅な加速に比べれば微々たるコストです。
Figure 7. Qwen3-8B-Baseモデルのロールアウト性能曲線
Try End-to-End FP8 with NVIDIA NeMo RL
生成(generation)および学習(training)バックエンドの両方で線形層(linear layers)に FP8 を有効化するため、以下の設定マップは各調整パラメータが学習および生成バックエンドにどのように渡されるかを示しています。
図8. NVIDIA NeMo RL における線形層への FP8 有効化
KV キャッシュ(KV cache)とアテンション(attention)に FP8 を有効化するには、ポリシーの vllm_cfg 内で kv_cache_dtype パラメータを設定する必要があります。これにより、トレーナー側での QKV スケール再較正と vLLM バックエンドとの同期が自動的に行われます。
policy: generation: vllm_cfg: precision: fp8 # 線形層に FP8 を有効化 kv_cache_dtype: fp8 # KV キャッシュに FP8 を有効化
生成および学習のための高度な FP8 設定オプション
これまでに、線形層および KV キャッシュ+アテンション層への FP8 実装について紹介しました。上級ユーザーは、レシピのバリエーションを試すことができます。以下に一部の特徴の例を示します:
生成時に最初の N 層および/または最後の M 層のトランスフォーマー(transformer)レイヤーを BF16 で保持する(N, M は整数)
policy: generation: vllm_cfg: num_first_layers_in_bf16: N # 整数のNに置き換え num_last_layers_in_bf16: M # 整数のMに置き換え
生成および/または学習で FP32 の代わりに 2 のべき乗スケーリングファクタ(power-of-2 scaling factor)タイプを使用するよう設定
policy: generation: vllm_cfg: pow2_weight_scaling_factors: true pow2_activation_scaling_factors: true megatron_cfg: env_vars: NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "0"
開発者は、表1に示すように、デフォルトのブロック単位量子化(block-wise quantized)FP8 レシピに代わって、Megatron Core バックエンド用に事前に定義された FP8 レシピのバリエーションを使用できます。詳細は引数のドキュメント文字列(docstring)を参照してください。
policy: megatron_cfg: fp8_cfg: fp8: "e4m3" fp8_recipe: "blockwise"
はじめに
ユーザーは、NeMo RL の GitHub にある llama-3.1-8b および moonlight-16b のレシピを参照することから始められます。
謝辞
この作業はチーム間の共同作業でした。FP8 レシピの開発、実験、および NeMo RL への統合に取り組んだ Jimmy Zhang、Victor Cui、Zhiyu Li、Lark Zhang に感謝します。
原文を表示
As LLMs transition from simple text generation to complex reasoning, reinforcement learning (RL) plays a central role. Algorithms like Group Relative Policy Optimization (GRPO) power this transition, enabling reasoning-grade models to continuously improve through iterative feedback. Unlike standard supervised fine-tuning, RL training loops are bifurcated into two distinct, high-intensity phases: a generation phase with a stringent latency requirement and a training phase requiring high throughput.
To make these workloads viable, researchers and engineers are turning to low-precision datatypes like FP8 to boost performance in training and throughput-oriented generation. Moreover, in some scenarios where generation is bound by GPU memory bandwidth, using low-precision parameters can improve performance due to fewer bytes per parameter.
This post dives deep into the systemic challenges of low-precision RL and how NVIDIA NeMo RL—an open source library within the NVIDIA NeMo framework—speeds up RL workloads while maintaining accuracy.
FP8 for linear layers in RL
Our recipe uses the block-wise quantized FP8 introduced by the DeepSeek-V3 Technical Report. Table 1 gives the details of tensor formats in linear projection layers.
TensorType of dataQuantization granularityScaling factor Type of ScalingWeightsFP8 (E4M3)[128, 128]FP32Block-wiseInput activationsFP8 (E4M3)[1, 128]FP32Block-wiseOutput gradientsFP8 (E4M3)[1, 128]FP32Block-wiseTable 1. Tensor formats in linear projection layers
With this recipe, linear layers can be computed with FP8 math, which has 2x peak throughput versus BF16 math. Other modules, including attention, normalization, non-linear functions, and output projections, are computed with BF16 math.
The challenge of numerical disagreement in RL
RL pipelines typically use separate engines: vLLM for rollouts and NVIDIA Megatron Core for training. Each uses unique custom NVIDIA CUDA kernels to maximize performance. This inherently introduces numerical differences that cumulatively magnify in lower precision due to additional quantization and dequantization logic. We quantify this numeric difference as a token multiplicative probability error:
Perfect alignment earns a score of 1, and we typically find ‘acceptable’ values to be <1.03-1.05 when not using any additional techniques.
End-to-end FP8 in linear layers reduce numerical disagreement
During the development of the FP8 recipe, we experimented with three recipes:
Baseline recipe: BF16 for both generation and training.
Recipe candidate 1: FP8 is applied exclusively during generation, while policy model training is conducted in BF16.
Final recipe: End-to-end FP8: we use FP8 in both generation and training engines
We observe that compared to recipe candidate 1 with FP8 only for generation, the final recipe consistently shows a lower numerical disagreement between generation and training. Note that the baseline recipe always gives the lowest numerical disagreement. Figure 1 shows the token multiplicative probability error metric of the three recipes.
Figure 1. Token multiplicative probability error in three recipes
Mitigating numerical disagreement with importance sampling
Importance sampling is used to correct the distribution mismatch between the model (i.e., distribution) that generates the data, and the model (i.e., distribution) that is being trained. It is a per-token weight multiplied by loss. You can refer to our GRPO documentation for the detailed theoretical background of importance sampling.
Experiments show that:
For recipe candidate 1 (FP8 generation and BF16 training), importance sampling can narrow the accuracy gap from BF16 RL, but can’t close the gap.
For the final recipe (end-to-end FP8), importance sampling completely closes the gap from BF16 training. Figure 2 shows the validation accuracy during training for different recipes.
Figure 2. Validation accuracy of GRPO training on Llama 3.1 8B Instruct model and math dataset
Results for FP8 Linear Layer E2E
We evaluate the end-to-end FP8 recipe on both dense and mixture-of-experts models, measuring validation accuracy and training throughput against the BF16 baseline.
FP8 end-to-end on dense models: Llama 3.1 8B Instruct
Table 2 shows the accuracy of the FP8 end-to-end recipe and BF16 recipe in GRPO training of Llama 3.1 8B instruct model and math dataset trained to 4000 steps.
PrecisionBF16FP8 generation onlyFP8 End-to-EndValidation accuracy0.6160.5860.613Table 2: Accuracy results for Llama3 8B validation accuracy across different precision configs
In terms of speed up, the FP8 recipe achieves a consistent >15% throughput improvement compared to BF16. Figure 3 is the GRPO training (tokens per second per GPU) of two recipes over 1000 steps.
Figure 3. Throughput (tokens per second per GPU) of the two recipes (blue: BF16 and pink: FP8 end-to-end)
Although the theoretical speedup of FP8 over BF16 is 2x, in practice, it is lower because only linear layers benefit from faster math throughput, whereas the attention and elementwise layers stay the same. The extra quantization kernels added before linear layers introduce some overhead. The 15%-25% speedup matches our standalone test of vLLM. With further optimizations such as fusing quantization kernels in vLLM, we project that the speedup can be further improved to 1.25x.
FP8 end-to-end on MoE models: Qwen3-30B
Similar experiments were run on mixture-of-experts (MoE) models, with results for Qwen3-30B showing matching accuracy curves. FP8 achieves similar accuracy to BF16. Speed gain is being investigated.
Figure 4. Accuracy curves for Qwen3-30B GRPO with OpenMathInstruct-2 dataset, on 8 nodes of H100. Blue is BF16, and pink is FP8 end-to-end
Extending FP8 for KV cache and attention
With a transformer model, linear layers are not the only bottleneck. KV cache growth and attention computation often dominate the end-to-end rollout time in RL workflows with long output sequence lengths (OSL) while also saturating memory bandwidth and slowing down token generation. This motivated us to explore FP8 for KV cache and attention in the loop of RL. Per-tensor scaling FP8 is used.
Implementing FP8 for KV-cache in an RL setting is uniquely challenging because policy weights change at every step. Unlike static inference, where calibration happens once, RL requires dynamic handling of quantization scales.
NeMo RL adopts the following approach to solve this:
Recalibration: At the end of each training step, the trainer recalibrates the Query, Key, Value (QKV) scales using the updated policy weights.
Data selection: This calibration is performed using the training data (prompts and generated responses) to ensure the scales reflect the current distribution.
Synchronization: The newly calculated scales are then synchronized to the inference engine (vLLM) for the subsequent rollout phase.
Figure 5. The RL workflow with FP8 KV cache
This design ensures that the rollout engine always uses optimal quantization scales derived from the latest policy state, minimizing accuracy degradation. The calibration overhead is minimal, consuming approximately 2-3% of the total step time.
TensorType of dataScaling factor Type of scalingQKV attention activationsFP8 (E4M3)FP32Tensor-wiseStored KV cacheFP8 (E4M3)FP32Tensor-wiseTable 3: Tensor formats for attention activations and stored KV cache
Summary of results for FP8 on KV cache and attention
We ran results on the Qwen3-8B-Base model using the GRPO algorithm, with FP8 applied in rollout and BF16 for training. While the mismatch KL divergence is slightly higher when quantizing both KV cache and attention due to compounded errors, our recipe mitigates instability. By enabling token-level truncated importance sampling, the FP8 for both linear + KV cache + attention achieves validation accuracy alignment with the BF16 baseline and the FP8 for the linear layer (W8A8).
Figure 6. Training Accuracy curves for Qwen3-8B-Base
Enabling FP8 for both KV-cache and attention operations yields an additional ~30% speedup on the rollout stage over the linear W8A8 configuration, resulting in an overall ~48% speedup compared to the BF16 baseline. These gains are particularly pronounced at longer response lengths, where attention computations constitute a larger fraction of the overall workload. The QKV scale recalibration process consumes approximately 2-3% of the total step time, representing a minor cost relative to the substantial acceleration achieved.
Figure 7. Rollout performance curves for Qwen3-8B-Base model
Try End-to-End FP8 with NVIDIA NeMo RL
To enable FP8 for linear layers in both generation and training backends, the following config map shows how each tuning parameter gets passed to the training and generation backends.
Figure 8. Enabling FP8 for linear layers in NVIDIA NeMo RL
To enable FP8 for KV cache and attention, one needs to configure the kv_cache_dtype parameter in vllm_cfg for the policy, which automatically handles the QKV scale recalibration on the trainer side and synchronization with the vLLM backend.
policy: generation: vllm_cfg: precision: fp8 # Enable FP8 for linear layers kv_cache_dtype: fp8 # Enable FP8 for KV-cache
Advanced FP8 configuration options for generation and training
So far, we have introduced the implementation of FP8 for linear layers and KV cache + attention layers. Advanced users can experiment with variants of the recipe. The following are examples of some of the features:
Keeping first N and/or last M transformer layers in BF16 during generation (N, M are integers)
policy: generation: vllm_cfg: num_first_layers_in_bf16: N # replace N with an integer num_last_layers_in_bf16: M # replace M with an integer
Configure generation and/or training to use power-of-2 scaling factor type instead of FP32
policy: generation: vllm_cfg: pow2_weight_scaling_factors: true pow2_activation_scaling_factors: true megatron_cfg: env_vars: NVTE_FP8_BLOCK_SCALING_FP32_SCALES: "0"
Developers can use variants of FP8 recipes predefined for the Megatron Core backend, instead of the default block-wise quantized FP8 recipe, as Table 1 shows. Refer to the argument docstring for details.
policy: megatron_cfg: fp8_cfg: fp8: "e4m3" fp8_recipe: "blockwise"
Get started
Users can start by referring to the llama-3.1-8b and moonlight-16b recipes in the NeMo RL GitHub.
Acknowledgements
This work was a collaborative effort across teams. We’d like to thank Jimmy Zhang, Victor Cui, Zhiyu Li, and Lark Zhang for their work on the FP8 recipe development, experimentation, and integration into NeMo RL.
関連記事
継続学習のための「睡眠」アプローチ(24 分読)
Google の研究者らは、モデルが短期間の文脈内知識を長期パラメータに統合する新手法「Sleep」を提案した。この手法は蒸留と再生成を用い、さらに強化学習による「夢見」段階で合成カリキュラムを生成して自己改善を図る。
ヒルクライミング機械の構築:7 つの新規 MAI モデルを発表(5 分読了)
マイクロソフトは、開発者がモデル重みを調整し日常製品に統合できる 7 つの新規 MAI モデル「MAI」を発表した。これらは強化学習環境を用いたフロンティア・チューニング技術を採用しており、またメイヨー・クリニックとの医療 AI 共同開発も発表した。
3D プリンタ対応の人間型ロボット脚がロボティクス実験を加速
Hugging Face が公開した約 2,500 ドルの安価な 3D プリント製人間型ロボット脚により、研究者は実世界での AI ロボットソフトウェアテストと訓練を容易に行えるようになった。