RLax、JAX、Haiku、Optaxを使用してスクラッチからDeep Q-Learning(DQN)を実装し、CartPole強化学習エージェントを訓練する
Google DeepMindが開発した研究志向ライブラリRLaxをJAX、Haiku、Optaxと組み合わせ、CartPole環境を解くDeep Q-Learningエージェントをスクラッチ実装するチュートリアル記事である。
キーポイント
RLaxを活用したカスタムRLパイプライン構築
Google DeepMindが開発した研究志向ライブラリRLaxをJAX、Haiku、Optaxと組み合わせ、完全にパッケージ化されたRLフレームワークではなく、カスタム強化学習パイプラインを構築する方法を解説している。
DQNエージェントのスクラッチ実装
ニューラルネットワークの定義、リプレイバッファの構築、RLaxを用いた時間差分誤差の計算、勾配ベース最適化によるエージェントのトレーニングなど、DQNのコアコンポーネントを詳細に実装している。
CartPole環境での実践的学習
強化学習の古典的なベンチマーク環境であるCartPole-v1を対象に、実装したDQNエージェントが実際に学習し問題を解決するまでのプロセスを示している。
JAXエコシステムの統合活用
効率的な数値計算のためのJAX、ニューラルネットワークモデリングのためのHaiku、最適化のためのOptaxを統合的に活用する方法を実践的に示している。
リプレイバッファの実装とサンプリング
リプレイバッファからミニバッチをサンプリングし、各要素を適切なデータ型に変換して返す関数を実装している。これにより、過去の経験を効率的に学習に活用できる。
ε-greedy探索戦略の実装
フレーム数に基づいてε値を減衰させ、確率εでランダム行動、それ以外でQネットワークに基づく最適行動を選択する関数を定義している。
JAXを用いた効率的な学習関数の実装
TD誤差の計算、Huber損失関数による損失計算、勾配更新をJAXのjitコンパイルと自動微分を活用して効率的に実装している。
影響分析・編集コメントを表示
影響分析
この記事は、強化学習の基礎を理解したい研究者や開発者にとって実践的な教育リソースとして価値があり、特にJAXエコシステムを活用した最新の実装方法を示している点で意義がある。ただし、技術的には既存の手法の応用であり、画期的な新規性は限定的である。
編集コメント
強化学習の基礎をJAXエコシステムで実装する実践的チュートリアルとして、教育・研究用途では有用だが、技術的革新性は限定的。
このチュートリアルでは、Google DeepMindによって開発された研究指向のライブラリであるRLaxを使用して、強化学習エージェントを実装します。RLaxをJAX、Haiku、Optaxと組み合わせて、CartPole環境を解くことを学ぶDeep Q-Learning (DQN) エージェントを構築します。完全にパッケージ化されたRLフレームワークを使用する代わりに、強化学習のコアコンポーネントがどのように相互作用するかを明確に理解できるように、トレーニングパイプラインを自分で組み立てます。ニューラルネットワークを定義し、リプレイバッファを構築し、RLaxで時間的差分誤差を計算し、勾配ベースの最適化を使用してエージェントをトレーニングします。また、RLaxがカスタム強化学習パイプラインに統合可能な再利用可能なRLプリミティブをどのように提供するかを理解することに焦点を当てます。効率的な数値計算にはJAXを、ニューラルネットワークモデリングにはHaikuを、最適化にはOptaxを使用します。
Copy CodeCopiedUse a different Browser
!pip -q install "jax[cpu]" dm-haiku optax rlax gymnasium matplotlib numpy
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import random
import time
from dataclasses import dataclass
from collections import deque
import gymnasium as gym
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import rlax
seed = 42
random.seed(seed)
np.random.seed(seed)
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")
obs_dim = env.observation_space.shape[0]
num_actions = env.action_space.n
def q_network(x):
mlp = hk.Sequential([
hk.Linear(128), jax.nn.relu,
hk.Linear(128), jax.nn.relu,
hk.Linear(num_actions),
])
return mlp(x)
q_net = hk.without_apply_rng(hk.transform(q_network))
dummy_obs = jnp.zeros((1, obs_dim), dtype=jnp.float32)
rng = jax.random.PRNGKey(seed)
params = q_net.init(rng, dummy_obs)
target_params = params
optimizer = optax.chain(
optax.clip_by_global_norm(10.0),
optax.adam(3e-4),
)
opt_state = optimizer.init(params)
必要なライブラリをインストールし、強化学習パイプラインに必要なすべてのモジュールをインポートします。環境を初期化し、Haikuを使用してニューラルネットワークアーキテクチャを定義し、行動価値を予測するQネットワークをセットアップします。また、ネットワークとターゲットネットワークのパラメータ、およびトレーニング中に使用されるオプティマイザを初期化します。
Copy CodeCopiedUse a different Browser
@dataclass
class Transition:
obs: np.ndarray
action: int
reward: float
discount: float
next_obs: np.ndarray
done: float
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def add(self, *args):
self.buffer.append(Transition(*args))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
obs = np.stack([t.obs for t in batch]).astype(np.float32)
action = np.array([t.action for t in batch], dtype=np.int32)
reward = np.array([t.reward for t in batch], dtype=np.float32)
discount = np.array([t.discount for t in batch], dtype=np.float32)
next_obs = np.stack([t.next_obs for t in batch]).astype(np.float32)
done = np.array([t.done for t in batch], dtype=np.float32)
return {
"obs": obs,
"action": action,
"reward": reward,
"discount": discount,
"next_obs": next_obs,
"done": done,
}
def __len__(self):
return len(self.buffer)
replay = ReplayBuffer(capacity=50000)
def epsilon_by_frame(frame_idx, eps_start=1.0, eps_end=0.05, decay_frames=20000):
mix = min(frame_idx / decay_frames, 1.0)
return eps_start + mix * (eps_end - eps_start)
def select_action(params, obs, epsilon):
if random.random() < epsilon:
return env.action_space.sample()
q_values = q_net.apply(params, obs[None, :])
return int(jnp.argmax(q_values[0]))
遷移構造を定義し、環境からの過去の経験を保存するためのリプレイバッファを実装します。遷移を追加し、後にエージェントのトレーニングに使用されるミニバッチをサンプリングするための関数を作成します。また、イプシロングリーディ探索戦略も実装します。
Copy CodeCopiedUse a different Browser
@jax.jit
def soft_update(target_params, online_params, tau):
return jax.tree_util.tree_map(lambda t, s: (1.0 - tau) * t + tau * s, target_params, online_params)
def batch_td_errors(params, target_params, batch):
q_tm1 = q_net.apply(params, batch["obs"])
q_t = q_net.apply(target_params, batch["next_obs"])
td_errors = jax.vmap(
lambda q1, a, r, d, q2: rlax.q_learning(q1, a, r, d, q2)
)(q_tm1, batch["action"], batch["reward"], batch["discount"], q_t)
return td_errors
@jax.jit
def train_step(params, target_params, opt_state, batch):
def loss_fn(p):
td_errors = batch_td_errors(p, target_params, batch)
loss = jnp.mean(rlax.huber_loss(td_errors, delta=1.0))
metrics = {
"loss": loss,
"td_abs_mean": jnp.mean(jnp.abs(td_errors)),
"q_mean": jnp.mean(q_net.apply(p, batch["obs"])),
}
return loss, metrics
(loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, metrics
トレーニング中に使用されるコア学習関数を定義します。RLaxのQ学習プリミティブを使用して時間的差分誤差を計算し、Huber損失関数を使用して損失を計算します。次に、勾配を計算し、オプティマイザの更新を適用し、トレーニングメトリクスを返すトレーニングステップを実装します。
Copy CodeCopiedUse a different Browser
def evaluate_agent(params, episodes=5):
returns = []
for ep in range(episodes):
obs, _ = eval_env.reset(seed=seed + 1000 + ep)
done = False
truncated = False
total_reward = 0.0
while not (done or truncated):
q_values = q_net.apply(params, obs[None, :])
action = int(jnp.argmax(q_values[0]))
next_obs, reward, done, truncated, _ = eval_env.step(action)
total_reward += reward
obs = next_obs
returns.append(total_reward)
return float(np.mean(returns))
num_frames = 40000
batch_size = 128
warmup_steps = 1000
train_every = 4
eval_every = 2000
gamma = 0.99
tau = 0.01
max_grad_updates_per_step = 1
obs, _ = env.reset(seed=seed)
episode_return = 0.0
episode_returns = []
eval_returns = []
losses = []
td_means = []
q_means = []
eval_steps = []
start_time = time.time()
エージェントのパフォーマンスを測定する評価関数を定義します。フレーム数、バッチサイズ、割引率、ターゲットネットワーク更新率を含むトレーニングハイパーパラメータを設定します。また、エピソードリターン、損失、評価メトリクスを含むトレーニング統計を追跡する変数を初期化します。
Copy CodeCopiedUse a different Browser
for frame_idx in range(1, num_frames + 1):
epsilon = epsilon_by_frame(frame_idx)
action = select_action(params, obs.astype(np.float32), epsilon)
next_obs, reward, done, truncated, _ = env.step(action)
terminal = done or truncated
discount = 0.0 if terminal else gamma
replay.add(
obs.astype(np.float32),
action,
float(reward),
float(discount),
next_obs.astype(np.float32),
float(terminal),
)
obs = next_obs
episode_return += reward
if terminal:
episode_returns.append(episode_return)
obs, _ = env.reset()
episode_return = 0.0
if len(replay) >= warmup_steps and frame_idx % train_every == 0:
for _ in range(max_grad_updates_per_step):
batch_np = replay.sample(batch_size)
batch = {k: jnp.asarray(v) for k, v in batch_np.items()}
params, opt_state, metrics = train_step(params, target_params, opt_state, batch)
target_params = soft_update(target_params, params, tau)
losses.append(float(metrics["loss"]))
td_means.append(float(metrics["td_abs_mean"]))
q_means.append(float(metrics["q_mean"]))
if frame_idx % eval_every == 0:
avg_eval_return = evaluate_agent(params, episodes=5)
eval_returns.append(avg_eval_return)
eval_steps.append(frame_idx)
recent_train = np.mean(episode_returns[-10:]) if episode_returns else 0.0
recent_loss = np.mean(losses[-100:]) if losses else 0.0
print(
f"step={frame_idx:6d} | epsilon={epsilon:.3f} | "
f"recent_train_return={recent_train:7.2f} | "
f"eval_return={avg_eval_return:7.2f} | "
f"recent_loss={recent_loss:.5f} | buffer={len(replay)}"
)
elapsed = time.time() - start_time
final_eval = evaluate_agent(params, episodes=10)
print("\nTraining complete")
print(f"Elapsed time: {elapsed:.1f} seconds")
print(f"Final 10-episode evaluation return: {final_eval:.2f}")
plt.figure(figsize=(14, 4))
plt.subplot(1, 3, 1)
plt.plot(episode_returns)
plt.title("Training Episode Returns")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.subplot(1, 3, 2)
plt.plot(eval_steps, eval_returns)
plt.title("Evaluation Returns")
plt.xlabel("Environment Steps")
plt.ylabel("Avg Return")
plt.subplot(1, 3, 3)
plt.plot(losses, label="Loss")
plt.plot(td_means, label="|TD Error| Mean")
plt.title("Optimization Metrics")
plt.xlabel("Gradient Updates")
plt.legend()
plt.tight_layout()
plt.show()
obs, _ = eval_env.reset(seed=999)
frames = []
done = False
truncated = False
total_reward = 0.0
render_env = gym.make("CartPole-v1", render_mode="rgb_array")
obs, _ = render_env.reset(seed=999)
while not (done or truncated):
frame = render_env.render()
frames.append(frame)
q_values = q_net.apply(params, obs[None, :])
action = int(jnp.argmax(q_values[0]))
obs, reward, done, truncated, _ = render_env.step(action)
total_reward += reward
render_env.close()
print(f"Demo episode return: {total_reward:.2f}")
try:
import matplotlib.animation as animation
from IPython.display import HTML, display
fig = plt.figure(figsize=(6, 4))
patch = plt.imshow(frames[0])
plt.axis("off")
def animate(i):
patch.set_data(frames[i])
return (patch,)
anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=30, blit=True)
display(HTML(anim.to_jshtml()))
plt.close(fig)
except Exception as e:
print("アニメーション表示をスキップしました:", e)
完全な強化学習のトレーニングループを実行します。定期的にネットワークパラメータを更新し、エージェントのパフォーマンスを評価し、可視化のためのメトリクスを記録します。また、トレーニング結果をプロットし、トレーニング済みエージェントの動作を観察するためにデモンストレーションエピソードをレンダリングします。
結論として、RLaxと現代のJAXベースの機械学習エコシステムを組み合わせて、完全なDeep Q-Learningエージェントを構築しました。行動価値を推定するニューラルネットワークを設計し、学習を安定化するための経験再生を実装し、RLaxのQ学習プリミティブを使用してTD誤差を計算しました。トレーニング中、勾配ベースの最適化を使用してネットワークパラメータを更新し、定期的にエージェントを評価してパフォーマンスの向上を追跡しました。また、RLaxが完全なアルゴリズムではなく再利用可能なアルゴリズムコンポーネントを提供することで、強化学習へのモジュール化されたアプローチを可能にする方法を見ました。この柔軟性により、異なるアーキテクチャ、学習ルール、最適化戦略を簡単に実験できます。この基盤を拡張することで、同じRLaxプリミティブを使用して、Double DQN、分布型強化学習モデル、アクター・クリティック法などのより高度なエージェントを構築できます。
完全なノートブックはこちらでご覧ください。また、Twitterでフォローするか、120k以上のMLサブレディットに参加し、ニュースレターを購読してください。待って!テレグラムを使っていますか?今すぐテレグラムでも参加できます。
この投稿「RLax、JAX、Haiku、Optaxを使用してCartPole強化学習エージェントを訓練するためのDeep Q-Learning(DQN)のスクラッチ実装」は、MarkTechPostで最初に公開されました。
原文を表示
In this tutorial, we implement a reinforcement learning agent using RLax, a research-oriented library developed by Google DeepMind for building reinforcement learning algorithms with JAX. We combine RLax with JAX, Haiku, and Optax to construct a Deep Q-Learning (DQN) agent that learns to solve the CartPole environment. Instead of using a fully packaged RL framework, we assemble the training pipeline ourselves so we can clearly understand how the core components of reinforcement learning interact. We define the neural network, build a replay buffer, compute temporal difference errors with RLax, and train the agent using gradient-based optimization. Also, we focus on understanding how RLax provides reusable RL primitives that can be integrated into custom reinforcement learning pipelines. We use JAX for efficient numerical computation, Haiku for neural network modeling, and Optax for optimization.
Copy CodeCopiedUse a different Browser
!pip -q install "jax[cpu]" dm-haiku optax rlax gymnasium matplotlib numpy
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
import random
import time
from dataclasses import dataclass
from collections import deque
import gymnasium as gym
import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import rlax
seed = 42
random.seed(seed)
np.random.seed(seed)
env = gym.make("CartPole-v1")
eval_env = gym.make("CartPole-v1")
obs_dim = env.observation_space.shape[0]
num_actions = env.action_space.n
def q_network(x):
mlp = hk.Sequential([
hk.Linear(128), jax.nn.relu,
hk.Linear(128), jax.nn.relu,
hk.Linear(num_actions),
])
return mlp(x)
q_net = hk.without_apply_rng(hk.transform(q_network))
dummy_obs = jnp.zeros((1, obs_dim), dtype=jnp.float32)
rng = jax.random.PRNGKey(seed)
params = q_net.init(rng, dummy_obs)
target_params = params
optimizer = optax.chain(
optax.clip_by_global_norm(10.0),
optax.adam(3e-4),
)
opt_state = optimizer.init(params)
We install the required libraries and import all the modules needed for the reinforcement learning pipeline. We initialize the environment, define the neural network architecture using Haiku, and set up the Q-network that predicts action values. We also initialize the network and target network parameters, as well as the optimizer to be used during training.
Copy CodeCopiedUse a different Browser
@dataclass
class Transition:
obs: np.ndarray
action: int
reward: float
discount: float
next_obs: np.ndarray
done: float
class ReplayBuffer:
def __init__(self, capacity):
self.buffer = deque(maxlen=capacity)
def add(self, *args):
self.buffer.append(Transition(*args))
def sample(self, batch_size):
batch = random.sample(self.buffer, batch_size)
obs = np.stack([t.obs for t in batch]).astype(np.float32)
action = np.array([t.action for t in batch], dtype=np.int32)
reward = np.array([t.reward for t in batch], dtype=np.float32)
discount = np.array([t.discount for t in batch], dtype=np.float32)
next_obs = np.stack([t.next_obs for t in batch]).astype(np.float32)
done = np.array([t.done for t in batch], dtype=np.float32)
return {
"obs": obs,
"action": action,
"reward": reward,
"discount": discount,
"next_obs": next_obs,
"done": done,
}
def __len__(self):
return len(self.buffer)
replay = ReplayBuffer(capacity=50000)
def epsilon_by_frame(frame_idx, eps_start=1.0, eps_end=0.05, decay_frames=20000):
mix = min(frame_idx / decay_frames, 1.0)
return eps_start + mix * (eps_end - eps_start)
def select_action(params, obs, epsilon):
if random.random() < epsilon:
return env.action_space.sample()
q_values = q_net.apply(params, obs[None, :])
return int(jnp.argmax(q_values[0]))
We define the transition structure and implement a replay buffer to store past experiences from the environment. We create functions to add transitions and sample mini-batches that will later be used to train the agent. We also implement the epsilon-greedy exploration strategy.
Copy CodeCopiedUse a different Browser
@jax.jit
def soft_update(target_params, online_params, tau):
return jax.tree_util.tree_map(lambda t, s: (1.0 - tau) * t + tau * s, target_params, online_params)
def batch_td_errors(params, target_params, batch):
q_tm1 = q_net.apply(params, batch["obs"])
q_t = q_net.apply(target_params, batch["next_obs"])
td_errors = jax.vmap(
lambda q1, a, r, d, q2: rlax.q_learning(q1, a, r, d, q2)
)(q_tm1, batch["action"], batch["reward"], batch["discount"], q_t)
return td_errors
@jax.jit
def train_step(params, target_params, opt_state, batch):
def loss_fn(p):
td_errors = batch_td_errors(p, target_params, batch)
loss = jnp.mean(rlax.huber_loss(td_errors, delta=1.0))
metrics = {
"loss": loss,
"td_abs_mean": jnp.mean(jnp.abs(td_errors)),
"q_mean": jnp.mean(q_net.apply(p, batch["obs"])),
}
return loss, metrics
(loss, metrics), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params, opt_state, metrics
We define the core learning functions used during training. We compute temporal difference errors using RLax’s Q-learning primitive and calculate the loss using the Huber loss function. We then implement the training step that computes gradients, applies optimizer updates, and returns training metrics.
Copy CodeCopiedUse a different Browser
def evaluate_agent(params, episodes=5):
returns = []
for ep in range(episodes):
obs, _ = eval_env.reset(seed=seed + 1000 + ep)
done = False
truncated = False
total_reward = 0.0
while not (done or truncated):
q_values = q_net.apply(params, obs[None, :])
action = int(jnp.argmax(q_values[0]))
next_obs, reward, done, truncated, _ = eval_env.step(action)
total_reward += reward
obs = next_obs
returns.append(total_reward)
return float(np.mean(returns))
num_frames = 40000
batch_size = 128
warmup_steps = 1000
train_every = 4
eval_every = 2000
gamma = 0.99
tau = 0.01
max_grad_updates_per_step = 1
obs, _ = env.reset(seed=seed)
episode_return = 0.0
episode_returns = []
eval_returns = []
losses = []
td_means = []
q_means = []
eval_steps = []
start_time = time.time()
We define the evaluation function that measures the agent’s performance. We configure the training hyperparameters, including the number of frames, batch size, discount factor, and target network update rate. We also initialize variables that track training statistics, including episode returns, losses, and evaluation metrics.
Copy CodeCopiedUse a different Browser
for frame_idx in range(1, num_frames + 1):
epsilon = epsilon_by_frame(frame_idx)
action = select_action(params, obs.astype(np.float32), epsilon)
next_obs, reward, done, truncated, _ = env.step(action)
terminal = done or truncated
discount = 0.0 if terminal else gamma
replay.add(
obs.astype(np.float32),
action,
float(reward),
float(discount),
next_obs.astype(np.float32),
float(terminal),
)
obs = next_obs
episode_return += reward
if terminal:
episode_returns.append(episode_return)
obs, _ = env.reset()
episode_return = 0.0
if len(replay) >= warmup_steps and frame_idx % train_every == 0:
for _ in range(max_grad_updates_per_step):
batch_np = replay.sample(batch_size)
batch = {k: jnp.asarray(v) for k, v in batch_np.items()}
params, opt_state, metrics = train_step(params, target_params, opt_state, batch)
target_params = soft_update(target_params, params, tau)
losses.append(float(metrics["loss"]))
td_means.append(float(metrics["td_abs_mean"]))
q_means.append(float(metrics["q_mean"]))
if frame_idx % eval_every == 0:
avg_eval_return = evaluate_agent(params, episodes=5)
eval_returns.append(avg_eval_return)
eval_steps.append(frame_idx)
recent_train = np.mean(episode_returns[-10:]) if episode_returns else 0.0
recent_loss = np.mean(losses[-100:]) if losses else 0.0
print(
f"step={frame_idx:6d} | epsilon={epsilon:.3f} | "
f"recent_train_return={recent_train:7.2f} | "
f"eval_return={avg_eval_return:7.2f} | "
f"recent_loss={recent_loss:.5f} | buffer={len(replay)}"
)
elapsed = time.time() - start_time
final_eval = evaluate_agent(params, episodes=10)
print("\nTraining complete")
print(f"Elapsed time: {elapsed:.1f} seconds")
print(f"Final 10-episode evaluation return: {final_eval:.2f}")
plt.figure(figsize=(14, 4))
plt.subplot(1, 3, 1)
plt.plot(episode_returns)
plt.title("Training Episode Returns")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.subplot(1, 3, 2)
plt.plot(eval_steps, eval_returns)
plt.title("Evaluation Returns")
plt.xlabel("Environment Steps")
plt.ylabel("Avg Return")
plt.subplot(1, 3, 3)
plt.plot(losses, label="Loss")
plt.plot(td_means, label="|TD Error| Mean")
plt.title("Optimization Metrics")
plt.xlabel("Gradient Updates")
plt.legend()
plt.tight_layout()
plt.show()
obs, _ = eval_env.reset(seed=999)
frames = []
done = False
truncated = False
total_reward = 0.0
render_env = gym.make("CartPole-v1", render_mode="rgb_array")
obs, _ = render_env.reset(seed=999)
while not (done or truncated):
frame = render_env.render()
frames.append(frame)
q_values = q_net.apply(params, obs[None, :])
action = int(jnp.argmax(q_values[0]))
obs, reward, done, truncated, _ = render_env.step(action)
total_reward += reward
render_env.close()
print(f"Demo episode return: {total_reward:.2f}")
try:
import matplotlib.animation as animation
from IPython.display import HTML, display
fig = plt.figure(figsize=(6, 4))
patch = plt.imshow(frames[0])
plt.axis("off")
def animate(i):
patch.set_data(frames[i])
return (patch,)
anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=30, blit=True)
display(HTML(anim.to_jshtml()))
plt.close(fig)
except Exception as e:
print("Animation display skipped:", e)
We run the full reinforcement learning training loop. We periodically update the network parameters, evaluate the agent’s performance, and record metrics for visualization. Also, we plot the training results and render a demonstration episode to observe how the trained agent behaves.
In conclusion, we built a complete Deep Q-Learning agent by combining RLax with the modern JAX-based machine learning ecosystem. We designed a neural network to estimate action values, implement experience replay to stabilize learning, and compute TD errors using RLax’s Q-learning primitive. During training, we updated the network parameters using gradient-based optimization and periodically evaluated the agent to track performance improvements. Also, we saw how RLax enables a modular approach to reinforcement learning by providing reusable algorithmic components rather than full algorithms. This flexibility allows us to easily experiment with different architectures, learning rules, and optimization strategies. By extending this foundation, we can build more advanced agents, such as Double DQN, distributional reinforcement learning models, and actor–critic methods, using the same RLax primitives.
Check out the Full Noteboo here. Also, feel free to follow us on Twitter and don’t forget to join our 120k+ ML SubReddit and Subscribe to our Newsletter. Wait! are you on telegram? now you can join us on telegram as well.
The post Implementing Deep Q-Learning (DQN) from Scratch Using RLax JAX Haiku and Optax to Train a CartPole Reinforcement Learning Agent appeared first on MarkTechPost.
関連記事
Anthropic、再現可能なゲノム・プロテオーム・ケミインフォマティクスパイプライン向けマルチエージェント AI ワークベンチ「Claude Science Beta」をリリース
NVIDIA HORIZON:Git ワークツリーを自律的に進化させるハンズフリーエージェントが RTL ベンチマークで完全達成
NVIDIA AI が自己改善型ロボットフレームワーク「ASPIRE」を発表、LIBERO-Pro の長期タスクでゼロショット成功率 31% を達成
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み