cleaning up jax bs

This commit is contained in:
2026-03-08 19:15:58 +01:00
parent 73246d7dd8
commit 4c658a93a7
27 changed files with 173 additions and 3146 deletions

View File

@@ -106,11 +106,6 @@ class OptimizerSpec:
eps_decay: float = 0.9995
arch: str = "small"
activation: str = "relu"
jax_num_envs: int = 16
jax_num_steps: int = 128
jax_num_minibatches: int = 4
jax_update_epochs: int = 4
jax_anneal_lr: bool = True
vf_coef: float = 0.5
max_grad_norm: float = 0.5
@@ -125,7 +120,6 @@ class RuntimeSpec:
checkpoint_interval: int = 200_000
model_dir: str = "engine/models"
log_freq: int = 100
use_jax: bool = False
@dataclass(frozen=True)
@@ -156,7 +150,6 @@ class TrainSpec:
"model_dir": self.runtime.model_dir,
"backend": self.runtime.backend,
"device": self.runtime.device,
"use_jax": self.runtime.use_jax,
"checkpoint_interval": self.runtime.checkpoint_interval,
"n_products": self.env.n_products,
"N": self.env.n_sessions,
@@ -197,11 +190,6 @@ class TrainSpec:
"eps_decay": self.optimizer.eps_decay,
"arch": self.optimizer.arch,
"activation": self.optimizer.activation,
"jax_num_envs": self.optimizer.jax_num_envs,
"jax_num_steps": self.optimizer.jax_num_steps,
"jax_num_minibatches": self.optimizer.jax_num_minibatches,
"jax_update_epochs": self.optimizer.jax_update_epochs,
"jax_anneal_lr": self.optimizer.jax_anneal_lr,
"vf_coef": self.optimizer.vf_coef,
"max_grad_norm": self.optimizer.max_grad_norm,
"robust_eval_enabled": self.eval.robust_eval_enabled,
@@ -223,14 +211,11 @@ class TrainSpec:
base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto"))
)
requested_jax = _truthy(base.get("use_jax")) or _truthy(
runtime_env.get("PHANTOM_USE_JAX")
)
backend = str(base.get("backend", "jax" if requested_jax else "sb3")).lower()
backend = str(base.get("backend", "sb3")).lower()
if backend == "auto":
backend = "jax" if requested_jax else "sb3"
if backend == "jax":
requested_jax = True
backend = "sb3"
if backend != "sb3":
backend = "sb3"
no_robust = _truthy(base.get("no_robust"))
if no_robust:
@@ -284,11 +269,6 @@ class TrainSpec:
eps_decay=float(base["eps_decay"]),
arch=str(base["arch"]),
activation=str(base["activation"]),
jax_num_envs=int(base["jax_num_envs"]),
jax_num_steps=int(base["jax_num_steps"]),
jax_num_minibatches=int(base["jax_num_minibatches"]),
jax_update_epochs=int(base["jax_update_epochs"]),
jax_anneal_lr=_truthy(base.get("jax_anneal_lr")),
vf_coef=float(base["vf_coef"]),
max_grad_norm=float(base["max_grad_norm"]),
),
@@ -301,7 +281,6 @@ class TrainSpec:
checkpoint_interval=int(base["checkpoint_interval"]),
model_dir=str(base["model_dir"]),
log_freq=int(base["log_freq"]),
use_jax=requested_jax,
),
eval=EvalSpec(
eval_freq=int(base["eval_freq"]),