エージェント型テスト時スケーリング(GitHub リポジトリ)
UMD や Google などの研究チームが、LLM の呼び出しを一切行わずに環境シミュレーションのみで推論戦略を自動発見する「AutoTTS」を発表し、コストと計算リソースを劇的に削減した。
キーポイント
エージェント駆動の自動検索アプローチ
手動でのヒューリスティック設計から脱却し、コーディングエージェントがオフライン環境でコード定義コントローラーを反復提案・改善する「テストタイムスケーリング」手法。
LLM 呼び出しゼロの低コスト実行
探索プロセス中に LLM の呼び出しを 0 回とし、キャッシュされたセグメントのリプレイのみで動作するため、1 回の完全発見ランニングに約 40 ドルと 160 分という極めて低いコストを実現。
高効率な推論戦略の発見
発見された「Confidence Momentum Controller (CMC)」により、SC@64 と同等の精度を維持しつつ約 69.5% のトークンを削減する成果を達成。
ベータ値に依存する動的スケジューリング
パラメータ(初期数、最大ブランチ数など)はベータ値[0,1]の滑らかな解析関数として定義され、ベータが増加すると予算使用量が増加するように設計されています。
探索と停止のバランス制御
信頼性閾値(conf_thresh)はベータ増加とともに厳しくなり、傾向閾値(trend_thresh)は緩くなることで、高ベータ時に広範な探索を促進しつつ予算を最適化します。
指数移動平均の慣性調整
ベータ値が高いほどEMAのアルファ値が低下し、より遅い更新(高い慣性)を実現することで、高負荷時の安定した予算配分を確保します。
動的スケジューリングパラメータの初期化
コンフィグからベータ値を取得し、スケジュール関数を通じて初期探索数や最大ブランチ使用数などの複数の制御パラメータを動的に設定する。
影響分析・編集コメントを表示
影響分析
この研究は、LLM の推論効率化において「探索コスト」自体を排除するパラダイムシフトをもたらす可能性があります。従来の試行錯誤型アプローチでは避けられなかった LLM 呼び出しの重荷を取り除くことで、大規模な自動最適化が実用的かつ低コストで実現可能となり、リソース制約下での高性能推論システムの構築に新たな道筋を開きます。
編集コメント
LLM の利用コストがボトルネックとなる中、探索プロセス自体を LLM に依存しない設計は非常に革新的です。実運用における推論効率化の新たな基準となる可能性があります。
LLMs が LLMs を改善する:テストタイムスケーリングのためのエージェント型発見
Tong Zheng, Haolin Liu, Chengsong Huang, Huiwen Bao, Sheng Zhang, Rui Liu, Runpeng Dai, Ruibo Chen, Chenxi Liu, Tianyi Xiong, Xidong Wu, Hongming Zhang, Heng Huang
*UMD · UVA · WUSTL · UNC · Google · Meta*
AutoTTS は、テストタイムスケーリング戦略の設計を手作業によるヒューリスティックの作成から環境駆動型の自動探索へと再定義します。人間はオフラインリプレイ環境(状態、行動、フィードバック、目的)のみを構築し、コーディングエージェントがその中でコードで定義されたコントローラを反復的に提案・改良します。コード編集のみで、勾配更新は不要。低コスト:LLM 呼び出し 0 回、完全なリプレイ実行。
注目すべき結果
- β ≈ 0.5 の条件下で SC@64 と比較して約 69.5% のトークンを節約;保持された平均精度は、4 つのバックボーンスケール全体で SC@64 に匹敵。
- 完全な発見実行にかかる推定費用は 39.9 ドル。
- 同じ実行に要する実時間(壁時計)は 160 分。
- 発見評価中は LLM の呼び出しが 0 回(キャッシュされたセグメントのみをリプレイ)。
発見されたコントローラは信頼度モメンタムコントローラ (Confidence Momentum Controller: CMC)であり、トレンドベースの停止、結合した幅・深さ制御、アライメント意識的な深さ割り当て、保守的な分岐放棄という特徴を持つ。
問題設定
我々は、適応型テストタイム推論を、固定長の区間における枝(ブランチ)に対して有限の予算を配分する問題として扱う。
t ステップにおける状態 (State):
s_t = (q, m_t, I_t, ℓ_t, Ω_t)
q: 質問; m_t: インスタンス化されたブランチ数; I_t: アクティブなブランチセット; ℓ_t: 深さベクトル; Ω_t: 公開されたプローブトリプル。
許容される行動 A(s_t):
- BRANCH — 最初の区間を通じて新しいブランチを開く。
- CONTINUE(i) — ブランチ i を1区間進める。
- PROBE(i) — 深さを進めずに ω_{i,ℓ} を公開する(プローブ:探索的検証)。
- PRUNE(i) — ブランチ i を非アクティブ化; 深さと過去のプローブは記録されたまま残る。
- ANSWER — 終了し、コントローラのターミナルアグリゲータを適用する。
コスト(区間単位):
Cost(s_t) = Σ_i ℓ_{t,i} + κ_probe · |Ω_t| (通常 κ_probe = 0)
目的関数。 コードで定義された方策 π(· | s, β) は、すべての内部ハイパーパラメータを決定論的にスケジューリングするスカラーのメタパラメータβによってパラメータ化される。タスク (q, y) ~ 𝒟 に対して:
max_{π, β} E_{q,y}[ 1{ŷ_{π,β}(q) = y} − γ · C_{π,β}(q) ]
外側ループは、πの実装について探索を行う。各候補はオフラインキャッシュ上でリプレイ評価され、トレースとスケーリング曲線が次のラウンドの履歴に組み込まれる。
環境構築((モデル,ベンチマーク)ごとに1回実行)
上記の MDP は、発見ループ開始前に具体的なリプレイ環境としてインスタンス化される:
- インターフェースを指定する。s_t, A(s_t), Cost(s_t)、および精度–コスト目的関数を固定する。
(MDP: マルコフ決定過程)
- オフライン軌跡収集。各クエリについて、バックボーンから N 本の並列独立した推論トレース(完全文字列として)をサンプリングし、各トレースを Δトークン長の固定長セグメントに分割して、ブランチ接頭辞 z_{i,k} とプローブ応答 ω_{i,k} を列挙します。
- リプレイストアの具体化。環境遷移はすべてアーカイブされたテーブルを参照します。例えば、PROBE(i) は新しいデコーディングを行わずにキャッシュされた ω_{i,k} を取得します。
- 発見プロセスへの引き渡し。候補コントローラーは、観測/ステップ操作を通じてのみシミュレーションされます。漸近的な評価コストの大部分はテーブルリプレイによって支配されます。
ステップ 1~3 は一度だけ実行され、反復的なコーディングエージェントによる発見は、リプレイストアが凍結された後に開始されます。
このリポジトリ内では:
- efficient_reasoning_controller/workspace/code_base/environment/ — 検索セット用リプレイストアです。
- efficient_reasoning_controller/test_environment/ — ホールドアウト用リプレイストアであり、提案者には決して公開されません。
発見:β パラメータ化とトレースフィードバック
- β パラメータ化。各候補コントローラーは、単一のスカラー値 β と、β から内部のすべてのノブ(調整項目)への決定論的かつ単調なマッピングをエクスポートします。外部探索は β の掃引に集約され、検索セットのみに対して微調整された脆い閾値が排除されます。
- 実行トレースによる履歴拡張。各ラウンドのβ掃引に伴い、経験的なスケーリング曲線と、リプレイ中に再構築された行動ごとの完全な軌跡をアーカイブします。トレースは、コードを書き換える前に欠陥を局在化するための微細な振る舞いの証拠を探索者に提供します。
メイン結果
AutoTTS は AIME24 のリプレイ構成に対して最適化され、4 つの Qwen3 バックボーンスケールにわたって保持された AIME25 / HMMT25 ベンチマークで評価されました。プロジェクトページでは以下の傾向が報告されています:
- 精度とトークンのトレードオフにおける改善。発見されたコントローラーは、SC@64、ASC、ESC、Parallel-Probe といった手作業ベースラインの経験的パレートフロンティアを通常上回ります。
- 保持されたベンチマークへの一般化能力。AIME24 で発見されたポリシーは保持されたベンチマークへ転移し、4 つのバックボンスケールのうち 3 つで平均精度においてすべての手作業ベースラインを上回り、Qwen3-8B においても競争力のある結果を示しました。
- β = 0.5 の動作点。これは SC@64 と比較して集約トークン使用量を約 69.5% 削減しつつ、モデル全体で平均の保持された精度を維持します。
- β = 1.0 の動作点。プロジェクトページの表計算された比較セルのうち 8 つ中 5 つにおいて、すべての手作業ベースラインを超える最高精度を実現します。
β を掃引することで精度とトークンのスケーリング曲線が追跡され、より大きな β は一般的に高予算・精度優先の振る舞いへと移行し、より小さな β は低コストな推論を好むことが示されます。
発見プロセスの進化
ラウンドレベルの軌跡(例えば上記図の t1 -> t5)は、探索プロセスを通じて一貫してより良い目的関数値へと向かう動きを示しています:
- 探索ベンチマーク上では、後続のラウンドで精度が向上しつつトークンの増加が制御されており、これはランダムな変動ではなく、漸進的に改善されたポリシー構造を意味します。
- ホールドアウトベンチマークにおいて、同じ軌道は競争力を維持し、しばしば改善を示すことから、発見された制御ロジックが最適化分割を超えて転移可能であることが示唆されます。
- この軌道は、勾配更新を伴わない目的指向のコード進化を反映しています:エージェントは明示的なコントローラープログラムを編集し、リプレイベースの精度・コストフィードバックを受け取り、より良い実証的トレードオフへと行動を反復的にシフトさせます。
これは AutoTTS の重要なポイントです。最適化は、バックプロパゲーションやバックボーンモデルのパラメータ微調整ではなく、固定されたリプレイ環境内での反復的なプログラム探索によって達成されます。
発見されたコントローラー:CMC
発見されたコントローラーはConfidence Momentum Controller (CMC) と名付けられています。その主な機構は以下の通りです。
- トレンドベースの停止。CMC はプール自信度の指数移動平均(Exponential Moving Average: EMA)を維持し、自信度レベルが高くかつトレンドが非負である場合にのみ停止します。これにより、一時的な自信度のスパイクで停止することが回避されます。
- 結合された幅・深さ制御。幅広げと深さ増は EMA のデルタを通じてリンクされています:強い自信度の向上は新しいブランチの生成を抑制し、停滞または後退は幅広げを引き起こします。
- アライメントを考慮した深さ配分。最新の回答がプールの勝者と一致するブランチには追加のプローブステップが付与され、計算リソースが出現するコンセンサスに集中される一方で、アクティブなブランチも前進し続けます。
- 保守的なブランチの放棄。ブランチは、複数ラウンドにわたって持続的に逸脱した場合のみ放棄され、少なくとも 2 つのアクティブなブランチが保持されます。
これらのメカニズムは、コード定義されたコントローラーロジックとして実装され、ハンドクラフトされたベースラインと同じリプレイ環境を通じて評価されます。
完全な OptimalController ソースを表示 (CMC、クリックして展開)
class OptimalController(LLMDesignedMethod):
"""
Confidence Momentum Controller (CMC)。
コアのアイデア
過去のすべての提案(IBC, SCR, DGCC)は、同じ根本的な停止信号を共有しています:現在のステップで完了した回答プールから計算された「瞬間的」な Beta 多数派の信頼度です。これは単一ステップでの信頼度のスパイクに脆弱です:運良く初期に同一の回答がクラスター化すると、分布が安定する前にゲートが早期に発動してしまう可能性があります。
CMC は、瞬間的な信頼性ゲートをモメンタム認識型のゲートに置き換えます:
- 直近
T_emaラウンドにおけるプール信頼性の指数移動平均(EMA)を追跡する:ema_conf = alpha * conf + (1 - alpha) * ema_conf - 直近の改善度合いであるデルタを追跡する:delta = ema_conf - ema_conf_prev
- ゲートが作動するのは、以下の両方が満たされた場合のみ:
(a) ema_conf >= conf_thresh(レベル要件)
(b) delta >= -slack(非悪化するモメンタム;slack は減少信号で停止しないようにする小さな許容値)
これは、コントローラーが 1 ラウンドのスパイクだけで停止できないことを意味します。EMA が高水準にあり、かつ積極的に低下していない必要があります。
プローブ年齢優先度による適応的深さ割り当て
各アクティブな未完了ブランチは probe_count(受け取ったプローブステップの数)を追跡しています。各ラウンドにおいて、コントローラーは probe_budget ステップの 1 ラウンドあたりのプローブ予算を、プローブ数降順でソートされた優先度キューを用いてアクティブなブランチ間に分配します。最も多くの投資がなされているブランチが最優先で処理され(各ブランチ最大 burst_senior 追加ステップまで)、残りの予算は投資の少ないブランチに回されます。これにより、完了に近いブランチに深さを集中させつつ、若手ブランチも前進させることが可能となり、均一な割り当てや純粋に整列バイアス付き(SCR)の割り当て、あるいは怠惰な待機(DGCC)とは異なります。
3 つの階層によるブランチ分類
ウォームアップ後:
- "aligned"(整列済み): 最新の回答 == プール勝者
- "deviant"(逸脱): 最新の回答 != プール勝者、かつ >= 1 ラウンドで意見が分かれた
- "neutral"(中立): まだプール勝者がいない、または意見が分かれ始めた最初のラウンド
階層は各ブランチのプローブ乗数に影響します:
aligned -> 乗数 = burst_aligned (例:高いベータ値では 2)
neutral -> 乗数 = 1
deviant -> 乗数 = 1、ただし >= abandon_patience ラウンドにわたって逸脱している場合、そのブランチは放棄される
信頼性傾向の拡大
拡大(新しいブランチの生成)は、信頼性の*傾向*(デルタ)が正で大きいのか、それとも弱いか負であるかによって駆動されます:
- もし delta > trend_thresh: 信頼性が加速している -> 拡大しない
(まもなく停止する見込み)
- もし delta <= trend_thresh: 横ばいまたは後退 ->
widen_burst個の新しいブランチを拡大させる。ただし max_branch の上限まで。
これは、深層化が証拠品質の向上をもたらしているかどうかに基づいて幅の決定と直接結合されたフィードバックループであり、以前の提案には存在しなかったものです。
ベータスケジュール
すべてのハイパーパラメータは、[0,1] 範囲内の単一のベータ値に対する決定論的関数です。
beta=0 -> 保守的(ブランチ数が少なく、EMA の慣性も低く、停止しやすい)
beta=1 -> ほぼフル予算(ブランチ数が多い、慣性が強く、停止しにくい)
新規性 versus 先行研究
ASC / ESC:完全読了;逐次的プローブなし。
Parallel_Probe:固定コホート;瞬時の多数決;プール/完了の区別なし;EMA なし。
IBC (r0001):瞬時のプール信頼度ゲート;均一な 1 ステッププローブ;ラウンドごとの 1 ブランチ拡大;EMA やトレンドなし。
SCR (r0002):非対称バースト(整合した方がより多くのステップを獲得);プラトートリガーによる拡大;瞬時ゲート;EMA なし。
DGCC (r0003):二重瞬時ゲート(プライマリ+ソフト裏付け);ロックされたブランチに対する遅延スリープ;投票差比例拡大;EMA モメンタムなし。
CMC:すべての瞬時ゲートを単一の EMA モメンタムゲートに置換;プローブ年齢優先スケジューリングを導入(均一でもバースト整合のみでもない);信頼度トレンドによる拡大(プラトーでも投票差でもない);3 段階分類は、追加のハイパーパラメータなしで DGCC の二重ゲートに対する自然な簡略化である。
"""
NAME = "optimal_controller"
_MAX_BRANCH = 64
_MAX_OUTER = 500
def _schedule(self, beta: float) -> dict:
"""
すべてのスケジュールは、[0,1] 内の beta に対する滑らかな解析関数です。
単調性:
- バジェット使用を制御するパラメータ(n_init, max_branch_use,
burst_aligned, widen_burst, warm_up, abandon_patience, T_ema)は、
beta に対して非減少(NON-DECREASING)です。
- conf_thresh は beta に対して非減少です(停止しにくくなる → バジェット増加)。
- trend_thresh は beta に対して非増加(NON-INCREASING)です(beta が高い場合に広げるトリガーが容易になる → より広い探索によるバジェット増加)。
- ema_alpha は beta に対して非増加です(アルファ値が低い = EMA が遅い = 慣性が大きい = beta が高い場合のバジェット増加)。
"""
b = max(0.0, min(1.0, float(beta)))
n_init = max(2, round(2 + 6 * b))
max_branch_use = min(self._MAX_BRANCH, round(4 + 60 * b))
warm_up = max(2, round(2 + 8 * b))
abandon_patience = max(3, round(3 + 9 * b))
T_ema = max(2, round(2 + 6 * b))
ema_alpha = 0.70 - 0.40 * b
conf_thresh = 0.85 + 0.12 * b
delta_slack = 0.04 - 0.03 * b
burst_aligned = max(1, round(1 + 2 * b))
widen_burst = max(1, round(1 + 3 * b))
trend_thresh = 0.04 - 0.03 * b
min_complete = max(2, round(2 + 3 * b))
return {
"n_init": n_init,
"max_branch_use": max_branch_use,
"warm_up": warm_up,
"abandon_patience": abandon_patience,
"T_ema": T_ema,
"ema_alpha": round(ema_alpha, 4),
"conf_thresh": round(conf_thresh, 4),
"delta_slack": round(delta_slack, 4),
"burst_aligned": burst_aligned,
"widen_burst": widen_burst,
"trend_thresh": round(trend_thresh, 4),
"min_complete": min_complete,
}
def __init__(self, config: Optional[Dict[str, Any]] = None):
super().__init__(config)
self._beta = float((config or {}).get("beta", 0.5))
sched = self._schedule(self._beta)
self.n_init = sched["n_init"]
self.max_branch_use = sched["max_branch_use"]
self.warm_up = sched["warm_up"]
self.abandon_patience = sched["abandon_patience"]
self.T_ema = sched["T_ema"]
self.ema_alpha = sched["ema_alpha"]
self.conf_thresh = sched["conf_thresh"]
self.delta_slack = sched["delta_slack"]
self.burst_aligned = sched["burst_aligned"]
self.widen_burst = sched["widen_burst"]
self.trend_thresh = sched["trend_thresh"]
self.min_complete = sched["min_complete"]
self.trace_recorder = MethodTraceRecorder()
def _reset_trace(self) -> None:
self.trace_recorder = MethodTraceRecorder()
def _trace_step(
self,
*,
event: str,
goal: str,
step_input: Dict[str, Any],
step_output: Any,
state: Dict[str, Any],
decision: str,
) -> None:
self.trace_recorder.add_step(
event=event,
goal=goal,
input=step_input,
output=step_output,
state=state,
decision=decision,
)
def get_last_trace(self) -> List[Dict[str, Any]]:
return self.trace_recorder.to_list()
def solve_with_trace(self, question) -> Dict[str, Any]:
answer = self.solve(question)
return {"answer": answer, "trace": self.get_last_trace()}
def _pool_stats(self, completed: List[str]):
"""(winner, top1, top2, conf) over completed-answer pool."""
if not completed:
return None, 0, 0, 0.0
winner, top1, top2, _ = _vote_stats(completed)
conf = _beta_majority_confidence(top1, top2)
return winner, top1, top2, conf
def _update_ema(self, ema_prev: float, new_val: float) -> float:
"""EMA update: ema = (1 - alpha) * ema_prev + alpha * new_val."""
return (1.0 - self.ema_alpha) * ema_prev + self.ema_alpha * new_val
def _classify_branch(
self,
br: Dict[str, Any],
pool_winner,
warm_enough: bool,
) -> str:
if not warm_enough or pool_winner is None:
return "neutral"
if br["latest_ans"] == pool_winner:
return "aligned"
return "deviant"
def _probe_branch(
self,
question,
br: Dict[str, Any],
completed_answers: List[str],
n_steps: int,
) -> None:
"""ブランチ br を最大 n_steps 回プローブし、完了した回答を記録する。"""
for _ in range(n_steps):
if br["finished"]:
break
out = _safe_probe_more(question, br["index"])
if out is None:
br["finished"] = True
if br["latest_ans"] is not None:
completed_answers.append(br["latest_ans"])
break
new_ans, is_finish = out
br["probe_count"] += 1
br["latest_ans"] = new_ans
br["finished"] = is_finish
if is_finish:
completed_answers.append(new_ans)
break
def solve(self, question) -> Optional[str]:
self._reset_trace()
self._trace_step(
event="start",
goal="initialize CMC run",
step_input={"beta": self._beta},
step_output="initialized",
state={
"n_init": self.n_init,
"max_branch_use": self.max_branch_use,
"warm_up": self.warm_up,
"abandon_patience": self.abandon_patience,
"T_ema": self.T_ema,
"ema_alpha": self.ema_alpha,
"conf_thresh": self.conf_thresh,
"delta_slack": self.delta_slack,
"burst_aligned": self.burst_aligned,
"widen_burst": self.widen_burst,
"trend_thresh": self.trend_thresh,
"min_complete": self.min_complete,
},
decision="start confidence momentum controller",
)
ブランチの状態:
# index : probe_new から取得した安定ブランチインデックス
# latest_ans : 現在の回答(中間または最終)
# finished : bool — ブランチが予算を完全に使い果たしたか
# abandoned : bool — 持続的な逸脱により破棄されたか
# probe_count : probe_more ステップの受信回数
# disagree_rounds: 回答がプール勝者と一致しなかった連続ラウンド数
branches: List[Dict[str, Any]] = []
completed_answers: List[str] = []
total_spawned = 0
# ---- フェーズ 0: n_init ブランチを開く ----
for _ in range(self.n_init):
out = _safe_probe_new(question)
if out is None:
break
ans, idx, is_finish = out
total_spawned += 1
br: Dict[str, Any] = {
"index": idx,
"latest_ans": ans,
"finished": is_finish,
"abandoned": False,
"probe_count": 0,
"disagree_rounds": 0,
}
branches.append(br)
if is_finish:
completed_answers.append(ans)
self._trace_step(
event="init_branches",
goal="open initial branch batch",
step_input={"n_init": self.n_init},
step_output={
"n_spawned": total_spawned,
"n_completed": len(completed_answers),
},
state={"total_spawned": total_spawned},
decision="proceed to main loop",
)
if not branches:
self._trace_step(
event="finish",
goal="return final answer",
step_input={},
step_output={"answer": None, "stop_reason": "no_branches"},
state={"total_spawned": 0},
decision="no branches available",
)
return None
# EMA state — initialised to 0 (no evidence yet)
ema_conf = 0.0
ema_conf_prev = 0.0
ema_history: List[float] = []
outer_step = 0
while outer_step < self._MAX_OUTER:
# ---- Compute current pool stats ----
pool_winner, top1, top2, pool_conf = self._pool_stats(completed_answers)
n_complete = len(completed_answers)
warm_enough = (outer_step >= self.warm_up)
# ---- Update EMA ----
ema_conf_prev = ema_conf
ema_conf = self._update_ema(ema_conf, pool_conf) # Exponential Moving Average (EMA)
ema_history.append(ema_conf)
if len(ema_history) > self.T_ema:
ema_history.pop(0)
if len(ema_history) >= 2:
ema_delta = ema_history[-1] - ema_history[0]
else:
ema_delta = 0.0
# ---- Classify branches and update disagree_rounds ----
if warm_enough and pool_winner is not None:
for br in branches:
if br["abandoned"] or br["finished"]:
continue
tier = self._classify_branch(br, pool_winner, warm_enough)
if tier == "deviant":
br["disagree_rounds"] += 1
else:
br["disagree_rounds"] = 0
---- 一貫して逸脱した枝を放棄する(生存数を 2 以上維持) ----
abandoned_this: List[int] = []
if warm_enough and pool_winner is not None:
n_alive = sum(
1 for br in branches
if not br["abandoned"] and not br["finished"]
)
cands = sorted(
[
br for br in branches
if not br["abandoned"]
and not br["finished"]
and br["disagree_rounds"] >= self.abandon_patience
],
key=lambda b: -b["disagree_rounds"],
)
max_abandon = max(0, n_alive - 2)
for br in cands[:max_abandon]:
br["abandoned"] = True
abandoned_this.append(br["index"])
# ---- 優先度に基づく深さ配分 ----
active_brs = [
br for br in branches
if not br["abandoned"] and not br["finished"]
]
active_brs_sorted = sorted(active_brs, key=lambda b: -b["probe_count"])
probed_this: int = 0
for br in active_brs_sorted:
tier = self._classify_branch(br, pool_winner, warm_enough)
n_steps = self.burst_aligned if tier == "aligned" else 1
self._probe_branch(question, br, completed_answers, n_steps)
probed_this += n_steps
pool_winner, top1, top2, pool_conf = self._pool_stats(completed_answers)
n_complete = len(completed_answers)
ema_conf = self._update_ema(ema_conf, pool_conf)
if ema_history:
ema_history[-1] = ema_conf
if len(ema_history) >= 2:
ema_delta = ema_history[-1] - ema_history[0]
else:
ema_delta = 0.0
n_active = sum(
1 for br in branches if not br["abandoned"] and not br["finished"]
)
self._trace_step(
event="forward",
goal="優先スケジューリング付きプローブと EMA の更新",
step_input={
"outer_step": outer_step,
"pool_winner": pool_winner,
"pool_conf": round(pool_conf, 4),
},
step_output={
"n_complete": n_complete,
"n_active": n_active,
"probed_this": probed_this,
"ema_conf": round(ema_conf, 4),
"ema_delta": round(ema_delta, 4),
"abandoned_now": abandoned_this,
},
state={"total_spawned": total_spawned},
decision="モーメンタムゲートとウィドニングの評価",
)
# ---- EMA モーメンタム停止ゲート ----
gate_eligible = (
warm_enough
and n_complete >= self.min_complete
)
gate_fires = (
gate_eligible
and ema_conf >= self.conf_thresh
and ema_delta >= -self.delta_slack
)
self._trace_step(
event="terminate_check",
goal="EMA momentum gate evaluation",
step_input={
"outer_step": outer_step,
"conf_thresh": self.conf_thresh,
"delta_slack": self.delta_slack,
"min_complete": self.min_complete,
"warm_up": self.warm_up,
},
step_output={
"ema_conf": round(ema_conf, 4),
"ema_delta": round(ema_delta, 4),
"pool_conf": round(pool_conf, 4),
"n_complete": n_complete,
"gate_eligible": gate_eligible,
"gate_fires": gate_fires,
},
state={"total_spawned": total_spawned},
decision="stop if EMA gate fires",
)
if gate_fires:
self._trace_step(
event="finish",
goal="return final answer",
step_input={"outer_step": outer_step},
step_output={
"answer": pool_winner,
"stop_reason": "ema_momentum_gate",
"ema_conf": round(ema_conf, 4),
"ema_delta": round(ema_delta, 4),
"n_complete": n_complete,
},
state={"total_spawned": total_spawned},
decision="EMA level high + momentum non-negative",
)
return pool_winner
# ---- All branches resolved? ----
all_resolved = all(br["finished"] or br["abandoned"] for br in branches)
if all_resolved:
break
# ---- Confidence-trend widening ----
can_widen = (
total_spawned < self.max_branch_use
and total_spawned < self._MAX_BRANCH
)
trend_weak = ema_delta <= self.trend_thresh
want_widen = (
can_widen
and trend_weak
and outer_step >= max(1, self.warm_up // 2)
and ema_conf < self.conf_thresh
)
spawned_now = 0
if want_widen:
for _ in range(self.widen_burst):
if total_spawned >= self.max_branch_use:
break
if total_spawned >= self._MAX_BRANCH:
break
out = _safe_probe_new(question)
if out is None:
break
ans, idx, is_finish = out
total_spawned +=
原文を表示
LLMs Improving LLMs: Agentic Discovery for Test-Time Scaling
Tong Zheng, Haolin Liu, Chengsong Huang, Huiwen Bao, Sheng Zhang, Rui Liu, Runpeng Dai, Ruibo Chen, Chenxi Liu, Tianyi Xiong, Xidong Wu, Hongming Zhang, Heng Huang
*UMD · UVA · WUSTL · UNC · Google · Meta*
AutoTTS reframes TTS strategy design from hand-crafting heuristics to environment-driven automatic search: humans only construct an offline replay environment (states, actions, feedback, objectives), and a coding agent iteratively proposes and refines code-defined controllers within it — code edits, no gradient updates. Cheap: 0 LLM calls, fully replay.
Quick links: Install · Reproduction · Citation
Highlighted results
- ~69.5% tokens saved vs SC@64 at β ≈ 0.5; held-out average accuracy matches SC@64 across four backbone scales.
- $39.9 estimated monetary cost for one full discovery run.
- 160 minutes wall-clock for the same run.
- 0 LLM calls during discovery evaluation (replays cached segments only).
The discovered controller is the Confidence Momentum Controller (CMC), characterized by trend-based stopping, coupled width–depth control, alignment-aware depth allocation, and conservative branch abandonment.
Problem setup
We treat adaptive test-time inference as allocating a finite budget over branches in fixed-length intervals.
State at step t:
s_t = (q, m_t, I_t, ℓ_t, Ω_t)
q: question; m_t: number of instantiated branches; I_t: active branch set; ℓ_t: depth vector; Ω_t: revealed probe triples.
Admissible actions A(s_t):
- BRANCH — open a new branch through the first interval.
- CONTINUE(i) — advance branch i by one interval.
- PROBE(i) — reveal ω_{i,ℓ} without advancing depth.
- PRUNE(i) — deactivate branch i; depths and past probes stay recorded.
- ANSWER — terminate and apply the controller's terminal aggregator.
Cost in interval units:
Cost(s_t) = Σ_i ℓ_{t,i} + κ_probe · |Ω_t| (often κ_probe = 0)
Objective. A code-defined policy π(· | s, β) is parameterized by a scalar meta-parameter β that deterministically schedules every internal hyper-parameter. Over tasks (q, y) ~ 𝒟:
max_{π, β} E_{q,y}[ 1{ŷ_{π,β}(q) = y} − γ · C_{π,β}(q) ]
The outer loop searches over implementations of π. Each candidate is replay-evaluated on offline caches; traces and scaling curves enter the next round's history.
Environment construction (run once per (model, benchmark))
The MDP above is instantiated as a concrete replay environment before the discovery loop starts:
- Specify the interface. Fix s_t, A(s_t), Cost(s_t), and the accuracy–cost objective.
- Offline trajectory collection. For each query, draw N parallel independent reasoning traces from the backbone (full strings first), then partition each trace into fixed-length segments of Δ tokens and enumerate branch prefixes z_{i,k} with probe responses ω_{i,k}.
- Materialize the replay store. Every environment transition consults the archived table; e.g. PROBE(i) retrieves the cached ω_{i,k} without any new decoding.
- Hand off to discovery. Candidate controllers are simulated exclusively through observe/step. Asymptotic evaluation cost is dominated by table replay.
Steps 1–3 run once. Iterative coding-agent discovery starts only after the replay store is frozen.
In this repository:
- efficient_reasoning_controller/workspace/code_base/environment/ — search-set replay store.
- efficient_reasoning_controller/test_environment/ — held-out replay store; never exposed to the proposer.
Discovery: β parameterization & trace feedback
- β parameterization. Each candidate controller exports a single scalar β plus a deterministic, monotonic map from β to every internal knob. Outer search collapses to sweeping β, eliminating brittle thresholds tuned only to the search set.
- History augmentation with execution traces. Alongside each round's β-sweep we archive both empirical scaling curves and the full action-by-action trajectories reconstructed during replay. Traces give the explorer fine-grained behavioral evidence to localize defects before rewriting code.
Main results
AutoTTS is optimized on AIME24 replay constructions and evaluated on held-out AIME25 / HMMT25 benchmarks across four Qwen3 backbone scales. The project page reports the following trends:
- Better accuracy–token trade-offs. Discovered controllers typically shift the empirical Pareto frontier beyond handcrafted baselines such as SC@64, ASC, ESC, and Parallel-Probe.
- Held-out generalization. Policies discovered on AIME24 transfer to held-out benchmarks, outperforming every handcrafted baseline on average accuracy for three of four backbone scales and remaining competitive on Qwen3-8B.
- β = 0.5 operating point. Cuts aggregate token usage by roughly 69.5% compared with SC@64 while matching mean held-out accuracy across models.
- β = 1.0 operating point. Pushes peak accuracy beyond all handcrafted baselines in five of the eight tabulated comparison cells on the project page.
Sweeping β traces accuracy–token scaling curves: larger β generally moves toward higher-budget, accuracy-first behavior, while smaller β favors cheaper inference.
Evolution of the discovery process
The round-level trajectory (e.g., t1 -> t5 in the figure above) shows a consistent move toward better objective values over the search process:
- On the search benchmark, later rounds improve accuracy while keeping token growth controlled, indicating progressively better policy structure rather than random fluctuation.
- On held-out benchmarks, the same trajectory remains competitive and often improves, suggesting that the discovered control logic transfers beyond the optimization split.
- The trajectory reflects objective-seeking code evolution without gradient updates: the agent edits explicit controller programs, receives replay-based accuracy/cost feedback, and iteratively shifts behavior toward better empirical trade-offs.
This is a key point of AutoTTS: optimization is achieved through iterative program search in a fixed replay environment, not through backpropagation or parameter fine-tuning of the backbone model.
Discovered controller: CMC
The discovered controller is named the Confidence Momentum Controller (CMC). Its main mechanisms are:
- Trend-based stopping. CMC maintains an exponential moving average of pool confidence and stops only when the confidence level is high and the trend is non-negative. This avoids stopping on transient confidence spikes.
- Coupled width–depth control. Widening and deepening are linked through the EMA delta: strong confidence gains suppress new branch spawning, while stagnation or regression triggers widening.
- Alignment-aware depth allocation. Branches whose latest answer matches the pool winner receive extra probe steps, concentrating compute on the emerging consensus while still advancing active branches.
- Conservative branch abandonment. A branch is abandoned only after persistently deviating for multiple rounds, and at least two active branches are preserved.
These mechanisms are implemented as code-defined controller logic and evaluated through the same replay environment as the handcrafted baselines.
Show full OptimalController source (CMC, click to expand)
class OptimalController(LLMDesignedMethod):
"""
Confidence Momentum Controller (CMC).
Core idea
All prior proposals (IBC, SCR, DGCC) share the same fundamental stopping
signal: "instantaneous" Beta-majority confidence computed from the
completed-answer pool at the current step. This is susceptible to
single-step confidence spikes: a lucky early cluster of identical answers
can fire the gate prematurely before the distribution has stabilised.
CMC replaces the instantaneous confidence gate with a momentum-aware
gate:
- Track an exponential moving average (EMA) of pool confidence over
the last T_ema rounds: ema_conf = alpha * conf + (1 - alpha) * ema_conf
- Track the recent improvement delta: delta = ema_conf - ema_conf_prev
- Gate fires when BOTH of the following hold:
(a) ema_conf >= conf_thresh (level requirement)
(b) delta >= -slack (non-deteriorating momentum; slack is
a small tolerance that prevents stopping on a declining signal)
This means the controller cannot stop on a one-round spike; the EMA
must be high and not actively falling.
Adaptive depth allocation via probe-age priority
Each active unfinished branch tracks probe_count (how many probe steps
it has received). In each round the controller allocates a per-round
probe budget of probe_budget steps distributed across active branches
using a priority queue sorted by probe_count descending. The most-
invested branches get served first (up to burst_senior extra steps
each), then remaining budget goes to less-invested branches.
This concentrates depth on branches that are closest to completion while
still advancing younger branches, rather than uniform or purely aligned-
biased allocation (SCR) or lazy sleeping (DGCC).
Three-tier branch classification
After warm_up:
- "aligned": latest answer == pool_winner
- "deviant": latest answer != pool_winner, disagreed for >= 1 round
- "neutral": no pool winner yet, or first round of disagreement
Tier affects the per-branch probe multiplier:
aligned -> multiplier = burst_aligned (e.g. 2 at high beta)
neutral -> multiplier = 1
deviant -> multiplier = 1, but if deviant for >= abandon_patience
rounds the branch is abandoned
Confidence-trend widening
Widening (spawning new branches) is driven by whether the confidence
*trend* (delta) is positive and large, or weak/negative:
- if delta > trend_thresh: confidence is accelerating -> no widening
(we're on track to stop soon)
- if delta <= trend_thresh: plateau or regression -> widen by
widen_burst new branches, up to max_branch ceiling
This directly couples width decision to whether deepening is yielding
evidence-quality gains, a feedback loop not present in prior proposals.
Beta schedule
All hyperparameters are deterministic functions of a single beta in [0,1].
beta=0 -> conservative (few branches, low EMA inertia, easier to stop)
beta=1 -> near-full budget (many branches, high inertia, harder to stop)
Novelty vs prior work
ASC / ESC: full reads; no incremental probing.
Parallel_Probe: fixed cohort; instantaneous majority; no pool/completion
distinction; no EMA.
IBC (r0001): instantaneous pool confidence gate; uniform 1-step probing;
1-branch-per-round widening; no EMA or trend.
SCR (r0002): asymmetric burst (aligned gets more steps); plateau-triggered
widening; instantaneous gate; no EMA.
DGCC (r0003): dual instantaneous gate (primary + soft corroboration);
lazy sleeping for locked branches; vote-gap proportional widening;
no EMA momentum.
CMC: replaces ALL instantaneous gates with a single EMA momentum gate;
introduces probe-age priority scheduling (neither uniform nor burst-
aligned-only); confidence-trend widening (neither plateau nor vote-gap);
three-tier classification is a natural simplification vs DGCC's dual
gate without adding extra hyperparameters.
"""
NAME = "optimal_controller"
_MAX_BRANCH = 64
_MAX_OUTER = 500
def _schedule(self, beta: float) -> dict:
"""
All schedules are smooth analytic functions of beta in [0,1].
Monotonicity:
- Parameters controlling budget use (n_init, max_branch_use,
burst_aligned, widen_burst, warm_up, abandon_patience, T_ema)
are NON-DECREASING in beta.
- conf_thresh is NON-DECREASING in beta (harder to stop -> more budget).
- trend_thresh is NON-INCREASING in beta (easier to trigger widening
at high beta -> more budget via wider exploration).
- ema_alpha is NON-INCREASING in beta (lower alpha = slower EMA =
more inertia = more budget at high beta).
"""
b = max(0.0, min(1.0, float(beta)))
n_init = max(2, round(2 + 6 * b))
max_branch_use = min(self._MAX_BRANCH, round(4 + 60 * b))
warm_up = max(2, round(2 + 8 * b))
abandon_patience = max(3, round(3 + 9 * b))
T_ema = max(2, round(2 + 6 * b))
ema_alpha = 0.70 - 0.40 * b
conf_thresh = 0.85 + 0.12 * b
delta_slack = 0.04 - 0.03 * b
burst_aligned = max(1, round(1 + 2 * b))
widen_burst = max(1, round(1 + 3 * b))
trend_thresh = 0.04 - 0.03 * b
min_complete = max(2, round(2 + 3 * b))
return {
"n_init": n_init,
"max_branch_use": max_branch_use,
"warm_up": warm_up,
"abandon_patience": abandon_patience,
"T_ema": T_ema,
"ema_alpha": round(ema_alpha, 4),
"conf_thresh": round(conf_thresh, 4),
"delta_slack": round(delta_slack, 4),
"burst_aligned": burst_aligned,
"widen_burst": widen_burst,
"trend_thresh": round(trend_thresh, 4),
"min_complete": min_complete,
}
def __init__(self, config: Optional[Dict[str, Any]] = None):
super().__init__(config)
self._beta = float((config or {}).get("beta", 0.5))
sched = self._schedule(self._beta)
self.n_init = sched["n_init"]
self.max_branch_use = sched["max_branch_use"]
self.warm_up = sched["warm_up"]
self.abandon_patience = sched["abandon_patience"]
self.T_ema = sched["T_ema"]
self.ema_alpha = sched["ema_alpha"]
self.conf_thresh = sched["conf_thresh"]
self.delta_slack = sched["delta_slack"]
self.burst_aligned = sched["burst_aligned"]
self.widen_burst = sched["widen_burst"]
self.trend_thresh = sched["trend_thresh"]
self.min_complete = sched["min_complete"]
self.trace_recorder = MethodTraceRecorder()
def _reset_trace(self) -> None:
self.trace_recorder = MethodTraceRecorder()
def _trace_step(
self,
*,
event: str,
goal: str,
step_input: Dict[str, Any],
step_output: Any,
state: Dict[str, Any],
decision: str,
) -> None:
self.trace_recorder.add_step(
event=event,
goal=goal,
input=step_input,
output=step_output,
state=state,
decision=decision,
)
def get_last_trace(self) -> List[Dict[str, Any]]:
return self.trace_recorder.to_list()
def solve_with_trace(self, question) -> Dict[str, Any]:
answer = self.solve(question)
return {"answer": answer, "trace": self.get_last_trace()}
def _pool_stats(self, completed: List[str]):
"""(winner, top1, top2, conf) over completed-answer pool."""
if not completed:
return None, 0, 0, 0.0
winner, top1, top2, _ = _vote_stats(completed)
conf = _beta_majority_confidence(top1, top2)
return winner, top1, top2, conf
def _update_ema(self, ema_prev: float, new_val: float) -> float:
"""EMA update: ema = (1 - alpha) * ema_prev + alpha * new_val."""
return (1.0 - self.ema_alpha) * ema_prev + self.ema_alpha * new_val
def _classify_branch(
self,
br: Dict[str, Any],
pool_winner,
warm_enough: bool,
) -> str:
if not warm_enough or pool_winner is None:
return "neutral"
if br["latest_ans"] == pool_winner:
return "aligned"
return "deviant"
def _probe_branch(
self,
question,
br: Dict[str, Any],
completed_answers: List[str],
n_steps: int,
) -> None:
"""Probe branch br for up to n_steps steps; record completions."""
for _ in range(n_steps):
if br["finished"]:
break
out = _safe_probe_more(question, br["index"])
if out is None:
br["finished"] = True
if br["latest_ans"] is not None:
completed_answers.append(br["latest_ans"])
break
new_ans, is_finish = out
br["probe_count"] += 1
br["latest_ans"] = new_ans
br["finished"] = is_finish
if is_finish:
completed_answers.append(new_ans)
break
def solve(self, question) -> Optional[str]:
self._reset_trace()
self._trace_step(
event="start",
goal="initialize CMC run",
step_input={"beta": self._beta},
step_output="initialized",
state={
"n_init": self.n_init,
"max_branch_use": self.max_branch_use,
"warm_up": self.warm_up,
"abandon_patience": self.abandon_patience,
"T_ema": self.T_ema,
"ema_alpha": self.ema_alpha,
"conf_thresh": self.conf_thresh,
"delta_slack": self.delta_slack,
"burst_aligned": self.burst_aligned,
"widen_burst": self.widen_burst,
"trend_thresh": self.trend_thresh,
"min_complete": self.min_complete,
},
decision="start confidence momentum controller",
)
Branch state:
index : stable branch_index from probe_new
latest_ans : current answer (intermediate or final)
finished : bool — branch exhausted its full budget
abandoned : bool — dropped due to persistent deviance
probe_count : number of probe_more steps received
disagree_rounds: consecutive rounds where answer != pool_winner
branches: List[Dict[str, Any]] = []
completed_answers: List[str] = []
total_spawned = 0
---- Phase 0: open n_init branches ----
for _ in range(self.n_init):
out = _safe_probe_new(question)
if out is None:
break
ans, idx, is_finish = out
total_spawned += 1
br: Dict[str, Any] = {
"index": idx,
"latest_ans": ans,
"finished": is_finish,
"abandoned": False,
"probe_count": 0,
"disagree_rounds": 0,
}
branches.append(br)
if is_finish:
completed_answers.append(ans)
self._trace_step(
event="init_branches",
goal="open initial branch batch",
step_input={"n_init": self.n_init},
step_output={
"n_spawned": total_spawned,
"n_completed": len(completed_answers),
},
state={"total_spawned": total_spawned},
decision="proceed to main loop",
)
if not branches:
self._trace_step(
event="finish",
goal="return final answer",
step_input={},
step_output={"answer": None, "stop_reason": "no_branches"},
state={"total_spawned": 0},
decision="no branches available",
)
return None
EMA state — initialised to 0 (no evidence yet)
ema_conf = 0.0
ema_conf_prev = 0.0
ema_history: List[float] = []
outer_step = 0
while outer_step < self._MAX_OUTER:
---- Compute current pool stats ----
pool_winner, top1, top2, pool_conf = self._pool_stats(completed_answers)
n_complete = len(completed_answers)
warm_enough = (outer_step >= self.warm_up)
---- Update EMA ----
ema_conf_prev = ema_conf
ema_conf = self._update_ema(ema_conf, pool_conf)
ema_history.append(ema_conf)
if len(ema_history) > self.T_ema:
ema_history.pop(0)
if len(ema_history) >= 2:
ema_delta = ema_history[-1] - ema_history[0]
else:
ema_delta = 0.0
---- Classify branches and update disagree_rounds ----
if warm_enough and pool_winner is not None:
for br in branches:
if br["abandoned"] or br["finished"]:
continue
tier = self._classify_branch(br, pool_winner, warm_enough)
if tier == "deviant":
br["disagree_rounds"] += 1
else:
br["disagree_rounds"] = 0
---- Abandon persistently deviant branches (keep >= 2 alive) ----
abandoned_this: List[int] = []
if warm_enough and pool_winner is not None:
n_alive = sum(
1 for br in branches
if not br["abandoned"] and not br["finished"]
)
cands = sorted(
[
br for br in branches
if not br["abandoned"]
and not br["finished"]
and br["disagree_rounds"] >= self.abandon_patience
],
key=lambda b: -b["disagree_rounds"],
)
max_abandon = max(0, n_alive - 2)
for br in cands[:max_abandon]:
br["abandoned"] = True
abandoned_this.append(br["index"])
---- Prioritised depth allocation ----
active_brs = [
br for br in branches
if not br["abandoned"] and not br["finished"]
]
active_brs_sorted = sorted(active_brs, key=lambda b: -b["probe_count"])
probed_this: int = 0
for br in active_brs_sorted:
tier = self._classify_branch(br, pool_winner, warm_enough)
n_steps = self.burst_aligned if tier == "aligned" else 1
self._probe_branch(question, br, completed_answers, n_steps)
probed_this += n_steps
pool_winner, top1, top2, pool_conf = self._pool_stats(completed_answers)
n_complete = len(completed_answers)
ema_conf = self._update_ema(ema_conf, pool_conf)
if ema_history:
ema_history[-1] = ema_conf
if len(ema_history) >= 2:
ema_delta = ema_history[-1] - ema_history[0]
else:
ema_delta = 0.0
n_active = sum(
1 for br in branches if not br["abandoned"] and not br["finished"]
)
self._trace_step(
event="forward",
goal="probe with priority scheduling + update EMA",
step_input={
"outer_step": outer_step,
"pool_winner": pool_winner,
"pool_conf": round(pool_conf, 4),
},
step_output={
"n_complete": n_complete,
"n_active": n_active,
"probed_this": probed_this,
"ema_conf": round(ema_conf, 4),
"ema_delta": round(ema_delta, 4),
"abandoned_now": abandoned_this,
},
state={"total_spawned": total_spawned},
decision="evaluate momentum gate and widening",
)
---- EMA momentum stopping gate ----
gate_eligible = (
warm_enough
and n_complete >= self.min_complete
)
gate_fires = (
gate_eligible
and ema_conf >= self.conf_thresh
and ema_delta >= -self.delta_slack
)
self._trace_step(
event="terminate_check",
goal="EMA momentum gate evaluation",
step_input={
"outer_step": outer_step,
"conf_thresh": self.conf_thresh,
"delta_slack": self.delta_slack,
"min_complete": self.min_complete,
"warm_up": self.warm_up,
},
step_output={
"ema_conf": round(ema_conf, 4),
"ema_delta": round(ema_delta, 4),
"pool_conf": round(pool_conf, 4),
"n_complete": n_complete,
"gate_eligible": gate_eligible,
"gate_fires": gate_fires,
},
state={"total_spawned": total_spawned},
decision="stop if EMA gate fires",
)
if gate_fires:
self._trace_step(
event="finish",
goal="return final answer",
step_input={"outer_step": outer_step},
step_output={
"answer": pool_winner,
"stop_reason": "ema_momentum_gate",
"ema_conf": round(ema_conf, 4),
"ema_delta": round(ema_delta, 4),
"n_complete": n_complete,
},
state={"total_spawned": total_spawned},
decision="EMA level high + momentum non-negative",
)
return pool_winner
---- All branches resolved? ----
all_resolved = all(br["finished"] or br["abandoned"] for br in branches)
if all_resolved:
break
---- Confidence-trend widening ----
can_widen = (
total_spawned < self.max_branch_use
and total_spawned < self._MAX_BRANCH
)
trend_weak = ema_delta <= self.trend_thresh
want_widen = (
can_widen
and trend_weak
and outer_step >= max(1, self.warm_up // 2)
and ema_conf < self.conf_thresh
)
spawned_now = 0
if want_widen:
for _ in range(self.widen_burst):
if total_spawned >= self.max_branch_use:
break
if total_spawned >= self._MAX_BRANCH:
break
out = _safe_probe_new(question)
if out is None:
break
ans, idx, is_finish = out
total_spawned +=
関連記事
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み