mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactored training approaches
This commit is contained in:
@@ -384,8 +384,6 @@ def train_once(cfg: dict) -> dict:
|
||||
"JAX backend requested but JAX is not installed. "
|
||||
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
||||
)
|
||||
if algo == "qtable":
|
||||
raise ValueError("qtable is not supported in JAX backend")
|
||||
try:
|
||||
from .jax.train import train_jax
|
||||
except Exception as exc: # pragma: no cover
|
||||
@@ -409,20 +407,25 @@ def run_wandb(
|
||||
init_kwargs = {"mode": mode}
|
||||
if sweep_mode:
|
||||
run = wandb.init(**init_kwargs)
|
||||
cfg = _cfg(_wandb_cfg_dict())
|
||||
for k, v in overrides.items():
|
||||
if k not in wandb.config:
|
||||
cfg[k] = v
|
||||
else:
|
||||
run = wandb.init(project=project, config=overrides, **init_kwargs)
|
||||
|
||||
try:
|
||||
cfg = _cfg(_wandb_cfg_dict())
|
||||
metrics = train_once(cfg)
|
||||
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
|
||||
wandb.log(metrics, step=step)
|
||||
for k, v in metrics.items():
|
||||
run.summary[k] = v
|
||||
wandb.finish()
|
||||
return metrics
|
||||
if sweep_mode:
|
||||
for k, v in overrides.items():
|
||||
if k not in wandb.config:
|
||||
cfg[k] = v
|
||||
|
||||
metrics = train_once(cfg)
|
||||
step = int(metrics.get("train/global_step", cfg["total_timesteps"]))
|
||||
wandb.log(metrics, step=step)
|
||||
for k, v in metrics.items():
|
||||
run.summary[k] = v
|
||||
return metrics
|
||||
finally:
|
||||
if wandb.run is not None:
|
||||
wandb.finish()
|
||||
|
||||
|
||||
def run_local(overrides: dict) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user