DiffraxとJAXを使用した高度な微分方程式ソルバー、確率的シミュレーション、ニューラル常微分方程式の実装コーディングガイド
MarkTechPostは、DiffraxとJAXを使用して微分方程式ソルバー、確率シミュレーション、ニューラル常微分方程式を実装する詳細なコーディングガイドを提供し、実践的な技術チュートリアルを公開した。
キーポイント
DiffraxとJAXによる微分方程式解法の実装
ロジスティック成長モデルなどの常微分方程式をDiffraxの適応型ソルバーで解き、任意の時間点での解を補間する方法を実装している。
確率微分方程式とバッチシミュレーションの実装
確率微分方程式のシミュレーションとJAXのベクトル化機能を活用したバッチ処理による効率的な計算手法を紹介している。
ニューラル常微分方程式モデルの構築準備
動的システムから生成したデータを用いてニューラル常微分方程式モデルを訓練するための前段階として、データ生成と環境構築を実践的に示している。
PyTreeベースの状態管理と高度な制御
PyTreeベースの状態管理とPIDコントローラーによるステップサイズ制御など、Diffraxの高度な機能を活用した実装例を提供している。
PyTree状態のサポート
DiffraxはPyTree(辞書など)を状態変数として扱うことができ、複雑なシステムを柔軟にモデル化できます。
バッチ処理による効率的なシミュレーション
jax.vmapを使用して複数の初期条件で並列に微分方程式を解くことで、計算効率を大幅に向上させることができます。
適応型ステップサイズ制御
PIDControllerを使用して相対許容誤差(rtol)と絶対許容誤差(atol)を設定し、精度と計算コストのバランスを取ります。
影響分析・編集コメントを表示
影響分析
この記事は、高度な微分方程式解法とニューラル微分方程式の実装に関する実践的な知識を提供することで、研究開発者の技術習得を促進する。特にJAXエコシステムの活用方法を示すことで、高速な数値計算と機械学習の統合を推進する教育的価値が高い。
編集コメント
実践的なコーディングチュートリアルとして完成度が高く、JAXエコシステムを活用した科学計算と機械学習の統合事例として参考になる。ただし、新規性よりも既存技術の応用解説に重点がある。
このチュートリアルでは、Diffrax ライブラリを使用して微分方程式を解き、ニューラル微分方程式モデルを構築する方法を探ります。まず、クリーンな計算環境を設定し、JAX、Diffrax、Equinox、Optax といった必要な科学計算ライブラリをインストールします。次に、適応型ソルバーを用いて常微分方程式を解き、任意の時間点で解を問い合わせるための密補間(dense interpolation)を実行する方法を示します。進んでいくにつれ、Diffrax のより高度な機能、すなわち古典的な力学系の求解、PyTree ベースの状態(state)の扱い、および JAX のベクトル化機能を用いたバッチ処理シミュレーションの実行について調査します。また、確率微分方程式(stochastic differential equations)のシミュレーションも行い、後にニューラル常微分方程式モデルを訓練するために使用されるデータセットを力学系から生成します。
Copy CodeCopiedUse a different Browser
import os, sys, subprocess, importlib, pathlib
SENTINEL = "/tmp/diffrax_colab_ready_v3"
def _run(cmd):
subprocess.check_call(cmd)
def _need_install():
try:
import numpy
import jax
import diffrax
import equinox
import optax
import matplotlib
return False
except Exception:
return True
if not os.path.exists(SENTINEL) or _need_install():
_run([sys.executable, "-m", "pip", "uninstall", "-y", "numpy", "jax", "jaxlib", "diffrax", "equinox", "optax"])
_run([sys.executable, "-m", "pip", "install", "-q", "--upgrade", "pip"])
_run([
sys.executable, "-m", "pip", "install", "-q",
"numpy==1.26.4",
"jax[cpu]==0.4.38",
"jaxlib==0.4.38",
"diffrax",
"equinox",
"optax",
"matplotlib"
])
pathlib.Path(SENTINEL).write_text("ready")
print("Packages installed cleanly. Runtime will restart now. After reconnect, run this same cell again.")
os._exit(0)
import time
import math
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jr
import diffrax
import equinox as eqx
import optax
import matplotlib.pyplot as plt
print("NumPy:", np.__version__)
print("JAX:", jax.__version__)
print("Backend:", jax.default_backend())
def logistic(t, y, args):
r, k = args
return r * y * (1 - y / k)
t0, t1 = 0.0, 10.0
ts = jnp.linspace(t0, t1, 300)
y0 = jnp.array(0.4)
args = (2.0, 5.0)
sol_logistic = diffrax.diffeqsolve(
diffrax.ODETerm(logistic),
diffrax.Tsit5(),
t0=t0,
t1=t1,
dt0=0.05,
y0=y0,
args=args,
saveat=diffrax.SaveAt(ts=ts, dense=True),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
query_ts = jnp.array([0.7, 2.35, 4.8, 9.2])
query_ys = jax.vmap(sol_logistic.evaluate)(query_ts)
print("\n=== Example 1: Logistic growth ===")
print("Saved solution shape:", sol_logistic.ys.shape)
print("Interpolated values:")
for t_, y_ in zip(query_ts, query_ys):
print(f"t={float(t_):.3f} -> y={float(y_):.6f}")
def lotka_volterra(t, y, args):
alpha, beta, delta, gamma = args
prey, predator = y
dprey = alpha * prey - beta * prey * predator
dpred = delta * prey * predator - gamma * predator
return jnp.array([dprey, dpred])
lv_y0 = jnp.array([10.0, 2.0])
lv_args = (1.5, 1.0, 0.75, 1.0)
lv_ts = jnp.linspace(0.0, 15.0, 500)
sol_lv = diffrax.diffeqsolve(
diffrax.ODETerm(lotka_volterra),
diffrax.Dopri5(),
t0=0.0,
t1=15.0,
dt0=0.02,
y0=lv_y0,
args=lv_args,
saveat=diffrax.SaveAt(ts=lv_ts),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
print("\n=== Example 2: Lotka-Volterra ===")
print("Shape:", sol_lv.ys.shape)
環境を設定し、必要な科学計算ライブラリが正しくインストールされていることを確認します。JAX、Diffrax、Equinox、Optax、および可視化ツールをインポートして微分方程式シミュレーションの構築と実行を行います。その後、適応型ソルバーを使用してロジスティック成長常微分方程式を解き、任意の時間点で解をクエリするための密な補間(dense interpolation)を実演します。
Copy CodeCopiedUse a different Browser
def spring_mass_damper(t, state, args):
k, c, m = args["k"], args["c"], args["m"]
x = state["x"]
v = state["v"]
dx = v
dv = -(k / m) * x - (c / m) * v
return {"x": dx, "v": dv}
pytree_state0 = {"x": jnp.array([2.0]), "v": jnp.array([0.0])}
pytree_args = {"k": 6.0, "c": 0.6, "m": 1.5}
pytree_ts = jnp.linspace(0.0, 12.0, 400)
sol_pytree = diffrax.diffeqsolve(
diffrax.ODETerm(spring_mass_damper),
diffrax.Tsit5(),
t0=0.0,
t1=12.0,
dt0=0.02,
y0=pytree_state0,
args=pytree_args,
saveat=diffrax.SaveAt(ts=pytree_ts),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
print("\n=== Example 3: PyTree state ===")
print("x shape:", sol_pytree.ys["x"].shape)
print("v shape:", sol_pytree.ys["v"].shape)
def damped_oscillator(t, y, args):
omega, zeta = args
x, v = y
dx = v
dv = -(omega ** 2) * x - 2.0 * zeta * omega * v
return jnp.array([dx, dv])
batch_y0 = jnp.array([
[1.0, 0.0],
[1.5, 0.0],
[2.0, 0.0],
[2.5, 0.0],
[3.0, 0.0],
])
batch_args = (2.5, 0.15)
batch_ts = jnp.linspace(0.0, 10.0, 300)
def solve_single(y0_single):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(damped_oscillator),
diffrax.Tsit5(),
t0=0.0,
t1=10.0,
dt0=0.02,
y0=y0_single,
args=batch_args,
saveat=diffrax.SaveAt(ts=batch_ts),
stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-7),
max_steps=100000,
)
return sol.ys
batched_ys = jax.vmap(solve_single)(batch_y0)
print("\n=== Example 4: Batched solves ===")
print("Batched shape:", batched_ys.shape)
We model the Lotka–Volterra predator–prey system to study the dynamics of interacting populations over time. We then introduce a PyTree-based state representation to simulate a spring–mass–damper system where the system state is stored as structured data. Finally, we perform batched differential equation solves using JAX’s vmap to efficiently simulate multiple systems in parallel.
Copy CodeCopiedUse a different Browser
sigma = 0.30
theta = 1.20
mu = 1.50
sde_ts = jnp.linspace(0.0, 6.0, 400)
def ou_drift(t, y, args):
theta_, mu_ = args
return theta_ * (mu_ - y)
def ou_diffusion(t, y, args):
return jnp.array([[sigma]])
def solve_ou(key):
bm = diffrax.VirtualBrownianTree(
t0=0.0,
t1=6.0,
tol=1e-3,
shape=(1,),
key=key,
)
terms = diffrax.MultiTerm(
diffrax.ODETerm(ou_drift),
diffrax.ControlTerm(ou_diffusion, bm),
)
sol = diffrax.diffeqsolve(
terms,
diffrax.EulerHeun(),
t0=0.0,
t1=6.0,
dt0=0.01,
y0=jnp.array([0.0]),
args=(theta, mu),
saveat=diffrax.SaveAt(ts=sde_ts),
max_steps=100000,
)
return sol.ys[:, 0]
sde_keys = jr.split(jr.PRNGKey(0), 5)
sde_paths = jax.vmap(solve_ou)(sde_keys)
print("\n=== Example 5: SDE ===")
print("SDE paths shape:", sde_paths.shape)
true_a = 0.25
true_b = 2.20
train_ts = jnp.linspace(0.0, 6.0, 120)
def true_dynamics(t, y, args):
x, v = y
dx = v
dv = -true_b * x - true_a * v + 0.1 * jnp.sin(2.0 * t)
return jnp.array([dx, dv])
true_sol = diffrax.diffeqsolve(
diffrax.ODETerm(true_dynamics),
diffrax.Tsit5(),
t0=0.0,
t1=6.0,
dt0=0.01,
y0=jnp.array([1.0, 0.0]),
saveat=diffrax.SaveAt(ts=train_ts),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
noise_key = jr.PRNGKey(42)
train_y = true_sol.ys + 0.01 * jr.normal(noise_key, true_sol.ys.shape)
We simulate a stochastic differential equation representing an Ornstein–Uhlenbeck process. We construct a Brownian motion process and integrate it with the drift and diffusion terms to generate multiple stochastic trajectories. We then create a synthetic dataset by solving a physical dynamical system that will later be used to train a neural differential equation model.
Copy CodeCopiedUse a different Browser
class ODEFunc(eqx.Module):
mlp: eqx.nn.MLP
def __init__(self, key, width=64, depth=2):
self.mlp = eqx.nn.MLP(
in_size=3,
out_size=2,
width_size=width,
depth=depth,
activation=jax.nn.tanh,
final_activation=lambda x: x,
key=key,
)
def __call__(self, t, y, args):
inp = jnp.concatenate([y, jnp.array([t])], axis=0)
return self.mlp(inp)
class NeuralODE(eqx.Module):
func: ODEFunc
def __init__(self, key):
self.func = ODEFunc(key)
def __call__(self, ts, y0):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(self.func),
diffrax.Tsit5(),
t0=ts[0],
t1=ts[-1],
dt0=0.01,
y0=y0,
saveat=diffrax.SaveAt(ts=ts),
stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-6),
max_steps=100000,
)
return sol.ys
model = NeuralODE(jr.PRNGKey(123))
optim = optax.adam(1e-2)
opt_state = optim.init(eqx.filter(model, eqx.is_array))
@eqx.filter_value_and_grad
def loss_fn(model, ts, y0, target):
pred = model(ts, y0)
return jnp.mean((pred - target) ** 2)
@eqx.filter_jit
def train_step(model, opt_state, ts, y0, target):
loss, grads = loss_fn(model, ts, y0, target)
updates, opt_state = optim.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss
print("\n=== Example 6: Neural ODE training ===")
losses = []
start = time.time()
for step in range(200):
model, opt_state, loss = train_step(model, opt_state, train_ts, jnp.array([1.0, 0.0]), train_y)
losses.append(float(loss))
if step % 40 == 0 or step == 199:
print(f"step={step:03d} loss={float(loss):.8f}")
elapsed = time.time() - start
pred_y = model(train_ts, jnp.array([1.0, 0.0]))
print(f"Training time: {elapsed:.2f}s")
jit_solver = jax.jit(solve_single)
_ = jit_solver(batch_y0[0]).block_until_ready()
bench_start = time.time()
_ = jit_solver(batch_y0[0]).block_until_ready()
bench_end = time.time()
print("\n=== Example 7: JIT benchmark ===")
print(f"Single compiled solve latency: {(bench_end - bench_start) * 1000:.2f} ms")
Equinox を用いてニューラル常微分方程式モデルを構築し、ニューラルネットワークによってシステムダイナミクスを表現します。Optax を用いて損失関数と最適化手順を定義することで、モデルがデータから背後にあるダイナミクスを学習できるようにします。その後、常微分方程式ソルバーを用いてニューラル ODE を訓練し、その性能を評価するとともに、JAX の JIT コンパイル(Just-In-Time Compilation)を用いてソルバーのベンチマークを行います。
plt.figure(figsize=(8, 4))
plt.plot(ts, sol_logistic.ys, label="solution")
plt.scatter(np.array(query_ts), np.array(query_ys), s=30, label="dense interpolation")
plt.title("Adaptive ODE + Dense Interpolation")
plt.xlabel("t")
plt.ylabel("y")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
plt.plot(lv_ts, sol_lv.ys[:, 0], label="prey")
plt.plot(lv_ts, sol_lv.ys[:, 1], label="predator")
plt.title("Lotka-Volterra")
plt.xlabel("t")
plt.ylabel("population")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
plt.plot(pytree_ts, sol_pytree.ys["x"][:, 0], label="position")
plt.plot(pytree_ts, sol_pytree.ys["v"][:, 0], label="velocity")
plt.title("PyTree State Solve")
plt.xlabel("t")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
for i in range(batched_ys.shape[0]):
plt.plot(batch_ts, batched_ys[i, :, 0], label=f"x0={float(batch_y0[i,0]):.1f}")
plt.title("Batched Solves with vmap")
plt.xlabel("t")
plt.ylabel("x(t)")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
for i in range(sde_paths.shape[0]):
plt.plot(sde_ts, sde_paths[i], alpha=0.8)
plt.title("SDE Sample Paths (Ornstein-Uhlenbeck)")
plt.xlabel("t")
plt.ylabel("state")
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
plt.plot(train_ts, train_y[:, 0], label="target x")
plt.plot(train_ts, pred_y[:, 0], "--", label="pred x")
plt.plot(train_ts, train_y[:, 1], label="target v")
plt.plot(train_ts, pred_y[:, 1], "--", label="pred v")
plt.title("Neural ODE Fit")
plt.xlabel("t")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
plt.plot(losses)
plt.yscale("log")
plt.title("Neural ODE Training Loss")
plt.xlabel("step")
plt.ylabel("MSE")
plt.tight_layout()
plt.show()
print("\n=== SUMMARY ===")
print("1. Tsit5 を用いた適応型 ODE ソルバー")
print("2. solution.evaluate による密な補間")
print("3. PyTree 値を持つ状態")
print("4. jax.vmap を用いたバッチ処理ソルブ")
print("5. VirtualBrownianTree を用いた SDE シミュレーション")
print("6. Equinox + Optax による Neural ODE の学習")
print("7. JIT コンパイルされたソルブベンチマーク完了")
モデル化したシステムの挙動を理解するために、シミュレーションおよび学習プロセスの結果を可視化します。ロジスティック成長の解、捕食者 - 被食者のダイナミクス、PyTree システムの状態、バッチ処理された振動子の軌跡、そして確率的な経路を描画します。さらに、Neural ODE の予測値とターゲットデータを比較し、モデル全体の性能を要約するために学習損失を表示します。
結論として、Diffrax と JAX エコシステムを用いて、科学計算および機械学習のための完全なワークフローを実装しました。決定論的および確率的微分方程式の求解、バッチ処理されたシミュレーションの実行、そしてデータからシステムの基礎的なダイナミクスを学習する Neural ODE モデルのトレーニングを行いました。このプロセス全体を通じて、JAX の即時コンパイル(Just-in-Time Compilation)と自動微分を活用し、効率的な計算とスケーラブルな実験を実現しました。Diffrax を Equinox および Optax と組み合わせることで、微分方程式ソルバーが現代の深層学習フレームワークとどのようにシームレスに統合されるかを示すことができました。
詳細はここからノートブックをご覧ください。また、Twitter でフォローしていただくこともお気軽にどうぞ。12 万人以上の ML サブレッドに参加することや、ニュースレターを購読することも忘れないでください。待ってください!Telegram をご利用ですか?今なら Telegram でも私たちに参加できます。
本記事「Diffrax と JAX を用いた高度な微分方程式ソルバー、確率シミュレーション、ニューラル常微分方程式の実装ガイド」は、MarkTechPost で最初に公開されました。
原文を表示
In this tutorial, we explore how to solve differential equations and build neural differential equation models using the Diffrax library. We begin by setting up a clean computational environment and installing the required scientific computing libraries such as JAX, Diffrax, Equinox, and Optax. We then demonstrate how to solve ordinary differential equations using adaptive solvers and perform dense interpolation to query solutions at arbitrary time points. As we progress, we investigate more advanced capabilities of Diffrax, including solving classical dynamical systems, working with PyTree-based states, and running batched simulations using JAX’s vectorization features. We also simulate stochastic differential equations and generate data from a dynamical system that will later be used to train a neural ordinary differential equation model.
Copy CodeCopiedUse a different Browser
import os, sys, subprocess, importlib, pathlib
SENTINEL = "/tmp/diffrax_colab_ready_v3"
def _run(cmd):
subprocess.check_call(cmd)
def _need_install():
try:
import numpy
import jax
import diffrax
import equinox
import optax
import matplotlib
return False
except Exception:
return True
if not os.path.exists(SENTINEL) or _need_install():
_run([sys.executable, "-m", "pip", "uninstall", "-y", "numpy", "jax", "jaxlib", "diffrax", "equinox", "optax"])
_run([sys.executable, "-m", "pip", "install", "-q", "--upgrade", "pip"])
_run([
sys.executable, "-m", "pip", "install", "-q",
"numpy==1.26.4",
"jax[cpu]==0.4.38",
"jaxlib==0.4.38",
"diffrax",
"equinox",
"optax",
"matplotlib"
])
pathlib.Path(SENTINEL).write_text("ready")
print("Packages installed cleanly. Runtime will restart now. After reconnect, run this same cell again.")
os._exit(0)
import time
import math
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jr
import diffrax
import equinox as eqx
import optax
import matplotlib.pyplot as plt
print("NumPy:", np.__version__)
print("JAX:", jax.__version__)
print("Backend:", jax.default_backend())
def logistic(t, y, args):
r, k = args
return r * y * (1 - y / k)
t0, t1 = 0.0, 10.0
ts = jnp.linspace(t0, t1, 300)
y0 = jnp.array(0.4)
args = (2.0, 5.0)
sol_logistic = diffrax.diffeqsolve(
diffrax.ODETerm(logistic),
diffrax.Tsit5(),
t0=t0,
t1=t1,
dt0=0.05,
y0=y0,
args=args,
saveat=diffrax.SaveAt(ts=ts, dense=True),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
query_ts = jnp.array([0.7, 2.35, 4.8, 9.2])
query_ys = jax.vmap(sol_logistic.evaluate)(query_ts)
print("\n=== Example 1: Logistic growth ===")
print("Saved solution shape:", sol_logistic.ys.shape)
print("Interpolated values:")
for t_, y_ in zip(query_ts, query_ys):
print(f"t={float(t_):.3f} -> y={float(y_):.6f}")
def lotka_volterra(t, y, args):
alpha, beta, delta, gamma = args
prey, predator = y
dprey = alpha * prey - beta * prey * predator
dpred = delta * prey * predator - gamma * predator
return jnp.array([dprey, dpred])
lv_y0 = jnp.array([10.0, 2.0])
lv_args = (1.5, 1.0, 0.75, 1.0)
lv_ts = jnp.linspace(0.0, 15.0, 500)
sol_lv = diffrax.diffeqsolve(
diffrax.ODETerm(lotka_volterra),
diffrax.Dopri5(),
t0=0.0,
t1=15.0,
dt0=0.02,
y0=lv_y0,
args=lv_args,
saveat=diffrax.SaveAt(ts=lv_ts),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
print("\n=== Example 2: Lotka-Volterra ===")
print("Shape:", sol_lv.ys.shape)
We set up the environment and ensure that all required scientific computing libraries are installed correctly. We import JAX, Diffrax, Equinox, Optax, and visualization tools to build and run differential equation simulations. We then solve a logistic growth ordinary differential equation using an adaptive solver and demonstrate dense interpolation to query the solution at arbitrary time points.
Copy CodeCopiedUse a different Browser
def spring_mass_damper(t, state, args):
k, c, m = args["k"], args["c"], args["m"]
x = state["x"]
v = state["v"]
dx = v
dv = -(k / m) * x - (c / m) * v
return {"x": dx, "v": dv}
pytree_state0 = {"x": jnp.array([2.0]), "v": jnp.array([0.0])}
pytree_args = {"k": 6.0, "c": 0.6, "m": 1.5}
pytree_ts = jnp.linspace(0.0, 12.0, 400)
sol_pytree = diffrax.diffeqsolve(
diffrax.ODETerm(spring_mass_damper),
diffrax.Tsit5(),
t0=0.0,
t1=12.0,
dt0=0.02,
y0=pytree_state0,
args=pytree_args,
saveat=diffrax.SaveAt(ts=pytree_ts),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
print("\n=== Example 3: PyTree state ===")
print("x shape:", sol_pytree.ys["x"].shape)
print("v shape:", sol_pytree.ys["v"].shape)
def damped_oscillator(t, y, args):
omega, zeta = args
x, v = y
dx = v
dv = -(omega ** 2) * x - 2.0 * zeta * omega * v
return jnp.array([dx, dv])
batch_y0 = jnp.array([
[1.0, 0.0],
[1.5, 0.0],
[2.0, 0.0],
[2.5, 0.0],
[3.0, 0.0],
])
batch_args = (2.5, 0.15)
batch_ts = jnp.linspace(0.0, 10.0, 300)
def solve_single(y0_single):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(damped_oscillator),
diffrax.Tsit5(),
t0=0.0,
t1=10.0,
dt0=0.02,
y0=y0_single,
args=batch_args,
saveat=diffrax.SaveAt(ts=batch_ts),
stepsize_controller=diffrax.PIDController(rtol=1e-5, atol=1e-7),
max_steps=100000,
)
return sol.ys
batched_ys = jax.vmap(solve_single)(batch_y0)
print("\n=== Example 4: Batched solves ===")
print("Batched shape:", batched_ys.shape)
We model the Lotka–Volterra predator–prey system to study the dynamics of interacting populations over time. We then introduce a PyTree-based state representation to simulate a spring–mass–damper system where the system state is stored as structured data. Finally, we perform batched differential equation solves using JAX’s vmap to efficiently simulate multiple systems in parallel.
Copy CodeCopiedUse a different Browser
sigma = 0.30
theta = 1.20
mu = 1.50
sde_ts = jnp.linspace(0.0, 6.0, 400)
def ou_drift(t, y, args):
theta_, mu_ = args
return theta_ * (mu_ - y)
def ou_diffusion(t, y, args):
return jnp.array([[sigma]])
def solve_ou(key):
bm = diffrax.VirtualBrownianTree(
t0=0.0,
t1=6.0,
tol=1e-3,
shape=(1,),
key=key,
)
terms = diffrax.MultiTerm(
diffrax.ODETerm(ou_drift),
diffrax.ControlTerm(ou_diffusion, bm),
)
sol = diffrax.diffeqsolve(
terms,
diffrax.EulerHeun(),
t0=0.0,
t1=6.0,
dt0=0.01,
y0=jnp.array([0.0]),
args=(theta, mu),
saveat=diffrax.SaveAt(ts=sde_ts),
max_steps=100000,
)
return sol.ys[:, 0]
sde_keys = jr.split(jr.PRNGKey(0), 5)
sde_paths = jax.vmap(solve_ou)(sde_keys)
print("\n=== Example 5: SDE ===")
print("SDE paths shape:", sde_paths.shape)
true_a = 0.25
true_b = 2.20
train_ts = jnp.linspace(0.0, 6.0, 120)
def true_dynamics(t, y, args):
x, v = y
dx = v
dv = -true_b * x - true_a * v + 0.1 * jnp.sin(2.0 * t)
return jnp.array([dx, dv])
true_sol = diffrax.diffeqsolve(
diffrax.ODETerm(true_dynamics),
diffrax.Tsit5(),
t0=0.0,
t1=6.0,
dt0=0.01,
y0=jnp.array([1.0, 0.0]),
saveat=diffrax.SaveAt(ts=train_ts),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
noise_key = jr.PRNGKey(42)
train_y = true_sol.ys + 0.01 * jr.normal(noise_key, true_sol.ys.shape)
We simulate a stochastic differential equation representing an Ornstein–Uhlenbeck process. We construct a Brownian motion process and integrate it with the drift and diffusion terms to generate multiple stochastic trajectories. We then create a synthetic dataset by solving a physical dynamical system that will later be used to train a neural differential equation model.
Copy CodeCopiedUse a different Browser
class ODEFunc(eqx.Module):
mlp: eqx.nn.MLP
def __init__(self, key, width=64, depth=2):
self.mlp = eqx.nn.MLP(
in_size=3,
out_size=2,
width_size=width,
depth=depth,
activation=jax.nn.tanh,
final_activation=lambda x: x,
key=key,
)
def __call__(self, t, y, args):
inp = jnp.concatenate([y, jnp.array([t])], axis=0)
return self.mlp(inp)
class NeuralODE(eqx.Module):
func: ODEFunc
def __init__(self, key):
self.func = ODEFunc(key)
def __call__(self, ts, y0):
sol = diffrax.diffeqsolve(
diffrax.ODETerm(self.func),
diffrax.Tsit5(),
t0=ts[0],
t1=ts[-1],
dt0=0.01,
y0=y0,
saveat=diffrax.SaveAt(ts=ts),
stepsize_controller=diffrax.PIDController(rtol=1e-4, atol=1e-6),
max_steps=100000,
)
return sol.ys
model = NeuralODE(jr.PRNGKey(123))
optim = optax.adam(1e-2)
opt_state = optim.init(eqx.filter(model, eqx.is_array))
@eqx.filter_value_and_grad
def loss_fn(model, ts, y0, target):
pred = model(ts, y0)
return jnp.mean((pred - target) ** 2)
@eqx.filter_jit
def train_step(model, opt_state, ts, y0, target):
loss, grads = loss_fn(model, ts, y0, target)
updates, opt_state = optim.update(grads, opt_state, model)
model = eqx.apply_updates(model, updates)
return model, opt_state, loss
print("\n=== Example 6: Neural ODE training ===")
losses = []
start = time.time()
for step in range(200):
model, opt_state, loss = train_step(model, opt_state, train_ts, jnp.array([1.0, 0.0]), train_y)
losses.append(float(loss))
if step % 40 == 0 or step == 199:
print(f"step={step:03d} loss={float(loss):.8f}")
elapsed = time.time() - start
pred_y = model(train_ts, jnp.array([1.0, 0.0]))
print(f"Training time: {elapsed:.2f}s")
jit_solver = jax.jit(solve_single)
_ = jit_solver(batch_y0[0]).block_until_ready()
bench_start = time.time()
_ = jit_solver(batch_y0[0]).block_until_ready()
bench_end = time.time()
print("\n=== Example 7: JIT benchmark ===")
print(f"Single compiled solve latency: {(bench_end - bench_start) * 1000:.2f} ms")
We build a neural ordinary differential equation model using Equinox to represent the system dynamics with a neural network. We define a loss function and optimization procedure using Optax so that the model can learn the underlying dynamics from data. We then train the neural ODE using the differential equation solver and evaluate its performance, benchmarking the solver with JAX’s JIT compilation.
Copy CodeCopiedUse a different Browser
plt.figure(figsize=(8, 4))
plt.plot(ts, sol_logistic.ys, label="solution")
plt.scatter(np.array(query_ts), np.array(query_ys), s=30, label="dense interpolation")
plt.title("Adaptive ODE + Dense Interpolation")
plt.xlabel("t")
plt.ylabel("y")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
plt.plot(lv_ts, sol_lv.ys[:, 0], label="prey")
plt.plot(lv_ts, sol_lv.ys[:, 1], label="predator")
plt.title("Lotka-Volterra")
plt.xlabel("t")
plt.ylabel("population")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
plt.plot(pytree_ts, sol_pytree.ys["x"][:, 0], label="position")
plt.plot(pytree_ts, sol_pytree.ys["v"][:, 0], label="velocity")
plt.title("PyTree State Solve")
plt.xlabel("t")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
for i in range(batched_ys.shape[0]):
plt.plot(batch_ts, batched_ys[i, :, 0], label=f"x0={float(batch_y0[i,0]):.1f}")
plt.title("Batched Solves with vmap")
plt.xlabel("t")
plt.ylabel("x(t)")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
for i in range(sde_paths.shape[0]):
plt.plot(sde_ts, sde_paths[i], alpha=0.8)
plt.title("SDE Sample Paths (Ornstein-Uhlenbeck)")
plt.xlabel("t")
plt.ylabel("state")
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
plt.plot(train_ts, train_y[:, 0], label="target x")
plt.plot(train_ts, pred_y[:, 0], "--", label="pred x")
plt.plot(train_ts, train_y[:, 1], label="target v")
plt.plot(train_ts, pred_y[:, 1], "--", label="pred v")
plt.title("Neural ODE Fit")
plt.xlabel("t")
plt.legend()
plt.tight_layout()
plt.show()
plt.figure(figsize=(8, 4))
plt.plot(losses)
plt.yscale("log")
plt.title("Neural ODE Training Loss")
plt.xlabel("step")
plt.ylabel("MSE")
plt.tight_layout()
plt.show()
print("\n=== SUMMARY ===")
print("1. Adaptive ODE solve with Tsit5")
print("2. Dense interpolation using solution.evaluate")
print("3. PyTree-valued states")
print("4. Batched solves using jax.vmap")
print("5. SDE simulation with VirtualBrownianTree")
print("6. Neural ODE training with Equinox + Optax")
print("7. JIT-compiled solve benchmark complete")
We visualize the results of the simulations and training process to understand the behavior of the systems we modeled. We plot the logistic growth solution, predator–prey dynamics, PyTree system states, batched oscillator trajectories, and stochastic paths. Also, we compare the neural ODE predictions with the target data and display the training loss to summarize the model’s overall performance.
In conclusion, we implemented a complete workflow for scientific computing and machine learning using Diffrax and the JAX ecosystem. We solved deterministic and stochastic differential equations, performed batched simulations, and trained a neural ODE model that learns the underlying dynamics of a system from data. Throughout the process, we leveraged JAX’s just-in-time compilation and automatic differentiation to achieve efficient computation and scalable experimentation. By combining Diffrax with Equinox and Optax, we demonstrated how differential equation solvers can seamlessly integrate with modern deep learning frameworks.
Check out Full Notebook here. Also, feel free to follow us on Twitter and don’t forget to join our 120k+ ML SubReddit and Subscribe to our Newsletter. Wait! are you on telegram? now you can join us on telegram as well.
The post A Coding Guide to Implement Advanced Differential Equation Solvers, Stochastic Simulations, and Neural Ordinary Differential Equations Using Diffrax and JAX appeared first on MarkTechPost.
関連記事
Google Health API に CLI ツール「ghealth」登場:Fitbit データを AI エージェントへ
Anthropic、再現可能なゲノム・プロテオーム・ケミインフォマティクスパイプライン向けマルチエージェント AI ワークベンチ「Claude Science Beta」をリリース
NVIDIA HORIZON:Git ワークツリーを自律的に進化させるハンズフリーエージェントが RTL ベンチマークで完全達成
今日のまとめ
AI日報で今日の重要ニュースをまとめ読み