cleaning up jax bs

This commit is contained in:
2026-03-08 19:15:58 +01:00
parent 73246d7dd8
commit 4c658a93a7
27 changed files with 173 additions and 3146 deletions

View File

@@ -7,14 +7,6 @@ from .orchestrators import run_benchmark_cli, run_sweep_agent, run_train_once
from .spec import TrainSpec
def _truthy(value: str | bool | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _parse_tags(raw: str | None) -> list[str]:
if raw is None:
return []
@@ -55,7 +47,7 @@ def _build_parser() -> argparse.ArgumentParser:
parser.add_argument("--group", type=str)
parser.add_argument("--tags", type=str)
parser.add_argument("--backend", choices=["auto", "sb3", "jax"], default="auto")
parser.add_argument("--backend", choices=["auto", "sb3"], default="auto")
parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"])
parser.add_argument("--seed", type=int)
parser.add_argument("--total-timesteps", type=int)
@@ -111,13 +103,6 @@ def _build_parser() -> argparse.ArgumentParser:
parser.add_argument("--eval-freq", type=int)
parser.add_argument("--eval-episodes", type=int)
parser.add_argument("--jax", action="store_true")
parser.add_argument("--jax-num-envs", type=int)
parser.add_argument("--jax-num-steps", type=int)
parser.add_argument("--jax-num-minibatches", type=int)
parser.add_argument("--jax-update-epochs", type=int)
parser.add_argument("--jax-anneal-lr", type=str)
parser.add_argument("--sweep-agent", action="store_true")
parser.add_argument("--sweep-id", type=str)
parser.add_argument("--count", type=int, default=0)
@@ -127,9 +112,6 @@ def _build_parser() -> argparse.ArgumentParser:
def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
jax_anneal_lr = (
_truthy(args.jax_anneal_lr) if args.jax_anneal_lr is not None else None
)
backend = None if args.backend == "auto" else args.backend
overrides = {
@@ -185,12 +167,6 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
"max_grad_norm": args.max_grad_norm,
"eval_freq": args.eval_freq,
"eval_episodes": args.eval_episodes,
"use_jax": args.jax or None,
"jax_num_envs": args.jax_num_envs,
"jax_num_steps": args.jax_num_steps,
"jax_num_minibatches": args.jax_num_minibatches,
"jax_update_epochs": args.jax_update_epochs,
"jax_anneal_lr": jax_anneal_lr,
}
return {key: value for key, value in overrides.items() if value is not None}