mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
feat: training update
This commit is contained in:
@@ -9,10 +9,16 @@ import numpy as np
|
||||
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
||||
|
||||
try:
|
||||
import wandb
|
||||
import wandb as _wandb
|
||||
|
||||
HAS_WANDB = True
|
||||
if hasattr(_wandb, "init") and callable(_wandb.init):
|
||||
wandb = _wandb
|
||||
HAS_WANDB = True
|
||||
else:
|
||||
wandb = None
|
||||
HAS_WANDB = False
|
||||
except ImportError:
|
||||
wandb = None
|
||||
HAS_WANDB = False
|
||||
|
||||
try:
|
||||
@@ -80,7 +86,7 @@ DEFAULT_CFG = {
|
||||
"jax_num_minibatches": 4,
|
||||
"jax_update_epochs": 4,
|
||||
"jax_anneal_lr": True,
|
||||
"checkpoint_interval": 10_000,
|
||||
"checkpoint_interval": 200_000,
|
||||
}
|
||||
|
||||
|
||||
@@ -404,6 +410,16 @@ def run_wandb(
|
||||
) -> dict:
|
||||
if not HAS_WANDB:
|
||||
raise ImportError("wandb is required for sweep runs")
|
||||
if not sweep_mode:
|
||||
pre_cfg = _cfg(overrides)
|
||||
if pre_cfg.get("use_jax"):
|
||||
try:
|
||||
import jax
|
||||
|
||||
if jax.process_count() > 1 and jax.process_index() != 0:
|
||||
return train_once(pre_cfg)
|
||||
except Exception:
|
||||
pass
|
||||
init_kwargs = {"mode": mode}
|
||||
if sweep_mode:
|
||||
run = wandb.init(**init_kwargs)
|
||||
@@ -431,7 +447,16 @@ def run_wandb(
|
||||
def run_local(overrides: dict) -> dict:
|
||||
cfg = _cfg(overrides)
|
||||
metrics = train_once(cfg)
|
||||
print(json.dumps(metrics, indent=2))
|
||||
should_print = True
|
||||
if cfg.get("use_jax"):
|
||||
try:
|
||||
import jax
|
||||
|
||||
should_print = jax.process_index() == 0
|
||||
except Exception:
|
||||
should_print = True
|
||||
if should_print:
|
||||
print(json.dumps(metrics, indent=2))
|
||||
return metrics
|
||||
|
||||
|
||||
@@ -439,15 +464,26 @@ def main():
|
||||
p = argparse.ArgumentParser(description="PHANTOM training and W&B sweeps")
|
||||
p.add_argument("--project", default=DEFAULT_CFG["project"])
|
||||
p.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable"])
|
||||
p.add_argument("--seed", type=int)
|
||||
p.add_argument("--total-timesteps", type=int)
|
||||
p.add_argument("--alpha", type=float)
|
||||
p.add_argument("--N", type=int)
|
||||
p.add_argument("--n-products", type=int)
|
||||
p.add_argument("--lambda-coi", type=float)
|
||||
p.add_argument("--info-value", type=float)
|
||||
p.add_argument("--robust-radius", type=float)
|
||||
p.add_argument("--robust-points", type=int)
|
||||
p.add_argument("--learning-rate", type=float)
|
||||
p.add_argument("--gamma", type=float)
|
||||
p.add_argument("--gae-lambda", type=float)
|
||||
p.add_argument("--clip-range", type=float)
|
||||
p.add_argument("--ent-coef", type=float)
|
||||
p.add_argument("--revenue-weight", type=float)
|
||||
p.add_argument("--price-low", type=float)
|
||||
p.add_argument("--price-high", type=float)
|
||||
p.add_argument("--action-levels", type=int)
|
||||
p.add_argument("--action-scale-low", type=float)
|
||||
p.add_argument("--action-scale-high", type=float)
|
||||
p.add_argument("--max-steps", type=int)
|
||||
p.add_argument("--margin-floor", type=float)
|
||||
p.add_argument("--margin-floor-patience", type=int)
|
||||
@@ -469,15 +505,26 @@ def main():
|
||||
|
||||
overrides = {
|
||||
"algo": args.algo,
|
||||
"seed": args.seed,
|
||||
"total_timesteps": args.total_timesteps,
|
||||
"alpha": args.alpha,
|
||||
"N": args.N,
|
||||
"n_products": args.n_products,
|
||||
"lambda_coi": args.lambda_coi,
|
||||
"info_value": args.info_value,
|
||||
"robust_radius": args.robust_radius,
|
||||
"robust_points": args.robust_points,
|
||||
"learning_rate": args.learning_rate,
|
||||
"gamma": args.gamma,
|
||||
"gae_lambda": args.gae_lambda,
|
||||
"clip_range": args.clip_range,
|
||||
"ent_coef": args.ent_coef,
|
||||
"revenue_weight": args.revenue_weight,
|
||||
"price_low": args.price_low,
|
||||
"price_high": args.price_high,
|
||||
"action_levels": args.action_levels,
|
||||
"action_scale_low": args.action_scale_low,
|
||||
"action_scale_high": args.action_scale_high,
|
||||
"max_steps": args.max_steps,
|
||||
"margin_floor": args.margin_floor,
|
||||
"margin_floor_patience": args.margin_floor_patience,
|
||||
|
||||
Reference in New Issue
Block a user