mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
cleaning up jax bs
This commit is contained in:
@@ -7,14 +7,6 @@ from .orchestrators import run_benchmark_cli, run_sweep_agent, run_train_once
|
||||
from .spec import TrainSpec
|
||||
|
||||
|
||||
def _truthy(value: str | bool | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _parse_tags(raw: str | None) -> list[str]:
|
||||
if raw is None:
|
||||
return []
|
||||
@@ -55,7 +47,7 @@ def _build_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--group", type=str)
|
||||
parser.add_argument("--tags", type=str)
|
||||
|
||||
parser.add_argument("--backend", choices=["auto", "sb3", "jax"], default="auto")
|
||||
parser.add_argument("--backend", choices=["auto", "sb3"], default="auto")
|
||||
parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"])
|
||||
parser.add_argument("--seed", type=int)
|
||||
parser.add_argument("--total-timesteps", type=int)
|
||||
@@ -111,13 +103,6 @@ def _build_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--eval-freq", type=int)
|
||||
parser.add_argument("--eval-episodes", type=int)
|
||||
|
||||
parser.add_argument("--jax", action="store_true")
|
||||
parser.add_argument("--jax-num-envs", type=int)
|
||||
parser.add_argument("--jax-num-steps", type=int)
|
||||
parser.add_argument("--jax-num-minibatches", type=int)
|
||||
parser.add_argument("--jax-update-epochs", type=int)
|
||||
parser.add_argument("--jax-anneal-lr", type=str)
|
||||
|
||||
parser.add_argument("--sweep-agent", action="store_true")
|
||||
parser.add_argument("--sweep-id", type=str)
|
||||
parser.add_argument("--count", type=int, default=0)
|
||||
@@ -127,9 +112,6 @@ def _build_parser() -> argparse.ArgumentParser:
|
||||
|
||||
|
||||
def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
|
||||
jax_anneal_lr = (
|
||||
_truthy(args.jax_anneal_lr) if args.jax_anneal_lr is not None else None
|
||||
)
|
||||
backend = None if args.backend == "auto" else args.backend
|
||||
|
||||
overrides = {
|
||||
@@ -185,12 +167,6 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
|
||||
"max_grad_norm": args.max_grad_norm,
|
||||
"eval_freq": args.eval_freq,
|
||||
"eval_episodes": args.eval_episodes,
|
||||
"use_jax": args.jax or None,
|
||||
"jax_num_envs": args.jax_num_envs,
|
||||
"jax_num_steps": args.jax_num_steps,
|
||||
"jax_num_minibatches": args.jax_num_minibatches,
|
||||
"jax_update_epochs": args.jax_update_epochs,
|
||||
"jax_anneal_lr": jax_anneal_lr,
|
||||
}
|
||||
return {key: value for key, value in overrides.items() if value is not None}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user