mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
Remove leftover JAX configuration options
This commit is contained in:
@@ -106,11 +106,6 @@ class OptimizerSpec:
|
||||
eps_decay: float = 0.9995
|
||||
arch: str = "small"
|
||||
activation: str = "relu"
|
||||
jax_num_envs: int = 16
|
||||
jax_num_steps: int = 128
|
||||
jax_num_minibatches: int = 4
|
||||
jax_update_epochs: int = 4
|
||||
jax_anneal_lr: bool = True
|
||||
vf_coef: float = 0.5
|
||||
max_grad_norm: float = 0.5
|
||||
|
||||
@@ -125,7 +120,6 @@ class RuntimeSpec:
|
||||
checkpoint_interval: int = 200_000
|
||||
model_dir: str = "engine/models"
|
||||
log_freq: int = 100
|
||||
use_jax: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -156,7 +150,6 @@ class TrainSpec:
|
||||
"model_dir": self.runtime.model_dir,
|
||||
"backend": self.runtime.backend,
|
||||
"device": self.runtime.device,
|
||||
"use_jax": self.runtime.use_jax,
|
||||
"checkpoint_interval": self.runtime.checkpoint_interval,
|
||||
"n_products": self.env.n_products,
|
||||
"N": self.env.n_sessions,
|
||||
@@ -197,11 +190,6 @@ class TrainSpec:
|
||||
"eps_decay": self.optimizer.eps_decay,
|
||||
"arch": self.optimizer.arch,
|
||||
"activation": self.optimizer.activation,
|
||||
"jax_num_envs": self.optimizer.jax_num_envs,
|
||||
"jax_num_steps": self.optimizer.jax_num_steps,
|
||||
"jax_num_minibatches": self.optimizer.jax_num_minibatches,
|
||||
"jax_update_epochs": self.optimizer.jax_update_epochs,
|
||||
"jax_anneal_lr": self.optimizer.jax_anneal_lr,
|
||||
"vf_coef": self.optimizer.vf_coef,
|
||||
"max_grad_norm": self.optimizer.max_grad_norm,
|
||||
"robust_eval_enabled": self.eval.robust_eval_enabled,
|
||||
@@ -223,14 +211,11 @@ class TrainSpec:
|
||||
base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto"))
|
||||
)
|
||||
|
||||
requested_jax = _truthy(base.get("use_jax")) or _truthy(
|
||||
runtime_env.get("PHANTOM_USE_JAX")
|
||||
)
|
||||
backend = str(base.get("backend", "jax" if requested_jax else "sb3")).lower()
|
||||
backend = str(base.get("backend", "sb3")).lower()
|
||||
if backend == "auto":
|
||||
backend = "jax" if requested_jax else "sb3"
|
||||
if backend == "jax":
|
||||
requested_jax = True
|
||||
backend = "sb3"
|
||||
if backend != "sb3":
|
||||
backend = "sb3"
|
||||
|
||||
no_robust = _truthy(base.get("no_robust"))
|
||||
if no_robust:
|
||||
@@ -284,11 +269,6 @@ class TrainSpec:
|
||||
eps_decay=float(base["eps_decay"]),
|
||||
arch=str(base["arch"]),
|
||||
activation=str(base["activation"]),
|
||||
jax_num_envs=int(base["jax_num_envs"]),
|
||||
jax_num_steps=int(base["jax_num_steps"]),
|
||||
jax_num_minibatches=int(base["jax_num_minibatches"]),
|
||||
jax_update_epochs=int(base["jax_update_epochs"]),
|
||||
jax_anneal_lr=_truthy(base.get("jax_anneal_lr")),
|
||||
vf_coef=float(base["vf_coef"]),
|
||||
max_grad_norm=float(base["max_grad_norm"]),
|
||||
),
|
||||
@@ -301,7 +281,6 @@ class TrainSpec:
|
||||
checkpoint_interval=int(base["checkpoint_interval"]),
|
||||
model_dir=str(base["model_dir"]),
|
||||
log_freq=int(base["log_freq"]),
|
||||
use_jax=requested_jax,
|
||||
),
|
||||
eval=EvalSpec(
|
||||
eval_freq=int(base["eval_freq"]),
|
||||
|
||||
Reference in New Issue
Block a user