LLM推論における非決定性の克服
Horace Heらによる、大規模言語モデルの推論における非決定性を克服する研究についての記事。
キーポイント
LLM推論の非決定性は温度パラメータを0に設定しても完全には解決されず、再現性の課題が残る
従来の「並列処理+浮動小数点演算」仮説だけでは説明できない非決定性の根本原因が存在する
GPU上の単純な行列乗算は決定性を示すが、LLM推論では異なる振る舞いが観測される
LLM推論エンジンの実装レベルでの非決定性要因の解明が科学的再現性に重要
影響分析・編集コメントを表示
影響分析
この記事はLLM推論の再現性問題の根本原因に迫る重要な分析を提供しており、科学研究や実運用における信頼性向上に直接影響する。非決定性の真因を解明することで、より安定したLLMシステムの構築と評価基準の確立が期待される。
編集コメント
LLMの実用化が進む中で見過ごされがちな再現性問題に焦点を当てた貴重な分析。ベンチマーク評価や研究検証の信頼性向上に不可欠な視点を提供している。
大規模言語モデル推論における非決定性の克服
科学的進歩の基盤である再現性は、大規模言語モデル(LLM)から一貫した結果を得る上で重大な課題に直面している。例えば、ChatGPTに同じ質問を繰り返しても異なる回答が得られることがある。これは、LLMの出力が確率分布に変換され、トークンが確率的に選択される「サンプリング」プロセスによるもので、温度パラメータを0に設定し(貪欲サンプリング)、理論的には決定論的に最高確率トークンを選ばせても、実際には非決定性が残る。これはAPIサービスだけでなく、vLLMやSGLangのようなオープンソース推論ライブラリを自前のハードウェアで実行する場合でも同様である。
この非決定性の原因について、広く支持されている仮説は「並行処理と浮動小数点演算の組み合わせ」である。GPUにおける浮動小数点演算は非結合性((a + b) + c ≠ a + (b + c))を示し、Transformerアーキテクチャの注意スコアやロジット計算において、並行実行されるスレッドの処理順序の違いが微小な結果の差を生み、それが累積して出力のばらつきを引き起こすと説明される。この見解は多くの技術議論で繰り返し言及されている。
しかし、この仮説だけでは完全な説明とはならない。なぜなら、同じGPU上で同じデータを用いた単純な行列乗算を繰り返し実行すると、ビット単位で完全に同一の結果が得られるからである。つまり、浮動小数点演算と並行処理が存在する環境下でも、特定の条件下では決定性が保たれる。
記事は、LLM推論の非決定性の真の原因を探るため、この矛盾を出発点としている。単なる「並行処理+浮動小数点」仮説を超えて、モデル推論における具体的な計算グラフやアルゴリズムの実装レベルに潜む要因を明らかにする必要性を示唆している。再現性の確保は、LLMの科学的検証、デバッグ、製品への統合において極めて重要であり、その基盤となる決定性を実現するためには、表面的に語られがちな説明ではなく、根本的なメカニズムの解明が不可欠であると論じている。
原文を表示
Defeating Nondeterminism in LLM Inference - Thinking Machines Lab Defeating Nondeterminism in LLM Inference
Reproducibility is a bedrock of scientific progress. However, it’s remarkably difficult to get reproducible results out of large language models.
For example, you might observe that asking ChatGPT the same question multiple times provides different results. This by itself is not surprising, since getting a result from a language model involves “sampling”, a process that converts the language model’s output into a probability distribution and probabilistically selects a token.
What might be more surprising is that even when we adjust the temperature down to 0This means that the LLM always chooses the highest probability token, which is called greedy sampling. (thus making the sampling theoretically deterministic), LLM APIs are still not deterministic in practice (see past discussions here, here, or here). Even when running inference on your own hardware with an OSS inference library like vLLM or SGLang, sampling still isn’t deterministic (see here or here).
But why aren’t LLM inference engines deterministic? One common hypothesis is that some combination of floating-point non-associativity and concurrent execution leads to nondeterminism based on which concurrent core finishes first. We will call this the “concurrency + floating point” hypothesis for LLM inference nondeterminism. For example, a recent arXiv preprint writes:
Floating-point arithmetic in GPUs exhibits non-associativity, meaning $(a + b) + c \neq a + (b + c)$ due to finite precision and rounding errors. This property directly impacts the computation of attention scores and logits in the transformer architecture, where parallel operations across multiple threads can yield different results based on execution order.
You can also find the “concurrency + floating point” hypothesis repeated by others, like here (“There are speed tradeoffs, and in order to make the endpoints fast GPUs are used, which do parallel [nondeterministic] calculations. Any modern GPU neural net calculations will be subject to these."), or here (“Because GPUs are highly parallelized, the ordering of additions or multiplications might be different on each execution, which can cascade into small differences in output.").
While this hypothesis is not entirely wrong, it doesn’t reveal the full picture. For example, even on a GPU, running the same matrix multiplication on the same data repeatedly will always provide bitwise equal results. We’re definitely using floating-point numbers. And our GPU definitely has a lot of concurrency. Why don’t we see nondeterminism in this test?
A = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16) B = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16) ref = torch.mm(A, B) for _ in range(1000): assert (torch.mm(A, B) - ref).abs().max().item() == 0 To understand the true cause of LLM inference nondeterminism, we must look deeper.
Unfortunately, even defining what it means for LLM inference to be deterministic is difficult. Perhaps confusingly, the following statements are all simultaneously true:
Some kernels on GPUs are nondeterministic.
However, all the kernels used in a language model’s forward pass are deterministic.
Moreover, the forward pass of an LLM inference server (like vLLM) can also be claimed to be deterministic.
Nevertheless, from the perspective of anybody using the inference server, the results are nondeterministic.
In this post, we will explain why the “concurrency + floating point” hypothesis misses the mark, unmask the true culprit behind LLM inference nondeterminism, and explain how to defeat nondeterminism and obtain truly reproducible results in LLM inference.
The original sin: floating-point non-associativity
Before talking about nondeterminism, it’s useful to explain why there are numerical differences at all. After all, we typically think of machine learning models as mathematical functions following structural rules such as commutativity or associativity. Shouldn’t there be a “mathematically correct” result that our machine learning libraries should provide us?
The culprit is floating-point non-associativity. That is, with floating-point numbers:
(0.1 + 1e20) - 1e20 >>> 0 0.1 + (1e20 - 1e20) >>> 0.1 Ironically, breaking associativity is what makes floating-point numbers useful.
Floating-point numbers are useful because they allow for a “dynamic” level of precision. For the purposes of explanation, we will use base 10 (instead of binary), where floating-point numbers are in the format $\text{mantissa} * 10^\text{exponent}$. We will also use 3 digits for the mantissa and 1 digit for the exponent.
For example, for the value 3450, we can represent it exactly as $3.45 * 10^3$. We can also represent much smaller values like 0.486 as $4.86 * 10^{-1}$. In this way, floating point allows us to represent both very small as well as very larg
関連記事
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み