refactored training approaches

This commit is contained in:
2026-02-19 18:23:08 +01:00
parent 5912062dc0
commit 1a9901f118
8 changed files with 947 additions and 308 deletions

View File

@@ -384,8 +384,6 @@ def train_once(cfg: dict) -> dict:
"JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
)
if algo == "qtable":
raise ValueError("qtable is not supported in JAX backend")
try:
from .jax.train import train_jax
except Exception as exc: # pragma: no cover
@@ -409,20 +407,25 @@ def run_wandb(
init_kwargs = {"mode": mode}
if sweep_mode:
run = wandb.init(**init_kwargs)
cfg = _cfg(_wandb_cfg_dict())
for k, v in overrides.items():
if k not in wandb.config:
cfg[k] = v
else:
run = wandb.init(project=project, config=overrides, **init_kwargs)
try:
cfg = _cfg(_wandb_cfg_dict())
metrics = train_once(cfg)
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
wandb.log(metrics, step=step)
for k, v in metrics.items():
run.summary[k] = v
wandb.finish()
return metrics
if sweep_mode:
for k, v in overrides.items():
if k not in wandb.config:
cfg[k] = v
metrics = train_once(cfg)
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
wandb.log(metrics, step=step)
for k, v in metrics.items():
run.summary[k] = v
return metrics
finally:
if wandb.run is not None:
wandb.finish()
def run_local(overrides: dict) -> dict: