mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
adding naive jax and libraries and make adjustments
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from gymnasium.wrappers import FlattenObservation
|
||||
|
||||
try:
|
||||
import wandb
|
||||
@@ -20,9 +22,7 @@ try:
|
||||
except ImportError:
|
||||
HAS_SB3 = False
|
||||
|
||||
from .wrapper import PHANTOM
|
||||
from .lib import EconomicMetricsWrapper, MetricsCallback
|
||||
from .lib.discrete import EventQTable
|
||||
from .jax import JAX_AVAILABLE
|
||||
|
||||
|
||||
DEFAULT_CFG = {
|
||||
@@ -69,14 +69,34 @@ DEFAULT_CFG = {
|
||||
"arch": "small",
|
||||
"activation": "relu",
|
||||
"q_bins": 6,
|
||||
"max_steps": 100,
|
||||
"margin_floor": 0.05,
|
||||
"margin_floor_patience": 5,
|
||||
"use_jax": False,
|
||||
"jax_num_envs": 16,
|
||||
"jax_num_steps": 128,
|
||||
"jax_num_minibatches": 4,
|
||||
"jax_update_epochs": 4,
|
||||
"jax_anneal_lr": True,
|
||||
}
|
||||
|
||||
|
||||
def _truthy(value: str | bool | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _cfg(raw: dict | None = None) -> dict:
|
||||
cfg = dict(DEFAULT_CFG)
|
||||
if raw:
|
||||
cfg.update({k: v for k, v in raw.items() if v is not None})
|
||||
cfg["algo"] = str(cfg["algo"]).lower()
|
||||
cfg["use_jax"] = _truthy(cfg.get("use_jax")) or _truthy(
|
||||
os.environ.get("PHANTOM_USE_JAX")
|
||||
)
|
||||
return cfg
|
||||
|
||||
|
||||
@@ -89,6 +109,11 @@ def _wandb_cfg_dict() -> dict:
|
||||
|
||||
|
||||
def make_env(cfg: dict):
|
||||
from gymnasium.wrappers import FlattenObservation
|
||||
|
||||
from .wrapper import PHANTOM
|
||||
from .lib.wrappers import EconomicMetricsWrapper
|
||||
|
||||
env = PHANTOM(
|
||||
n_products=int(cfg["n_products"]),
|
||||
alpha=float(cfg["alpha"]),
|
||||
@@ -101,6 +126,9 @@ def make_env(cfg: dict):
|
||||
action_levels=int(cfg["action_levels"]),
|
||||
action_scale_low=float(cfg["action_scale_low"]),
|
||||
action_scale_high=float(cfg["action_scale_high"]),
|
||||
max_steps=int(cfg.get("max_steps", 100)),
|
||||
margin_floor=float(cfg.get("margin_floor", 0.05)),
|
||||
margin_floor_patience=int(cfg.get("margin_floor_patience", 5)),
|
||||
render_mode=None,
|
||||
)
|
||||
env = EconomicMetricsWrapper(env)
|
||||
@@ -235,6 +263,8 @@ def build_model(cfg: dict, env):
|
||||
|
||||
|
||||
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
||||
from .lib.discrete import EventQTable
|
||||
|
||||
np.random.seed(int(cfg["seed"]))
|
||||
env = make_env(cfg)
|
||||
eval_env = make_env(cfg)
|
||||
@@ -275,6 +305,8 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
||||
def train_sb3(cfg: dict) -> tuple[object, dict]:
|
||||
if not HAS_SB3:
|
||||
raise ImportError("stable-baselines3 is required for SB3 models")
|
||||
from .lib.callbacks import MetricsCallback
|
||||
|
||||
env = make_env(cfg)
|
||||
eval_env = make_env(cfg)
|
||||
env = Monitor(env)
|
||||
@@ -303,7 +335,20 @@ def train_sb3(cfg: dict) -> tuple[object, dict]:
|
||||
|
||||
def train_once(cfg: dict) -> dict:
|
||||
algo = cfg["algo"]
|
||||
if algo == "qtable":
|
||||
if cfg.get("use_jax"):
|
||||
if not JAX_AVAILABLE:
|
||||
raise ImportError(
|
||||
"JAX backend requested but JAX is not installed. "
|
||||
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
||||
)
|
||||
if algo == "qtable":
|
||||
raise ValueError("qtable is not supported in JAX backend")
|
||||
try:
|
||||
from .jax.train import train_jax
|
||||
except Exception as exc: # pragma: no cover
|
||||
raise ImportError(f"Failed to import JAX trainer: {exc}") from exc
|
||||
_, metrics = train_jax(cfg)
|
||||
elif algo == "qtable":
|
||||
_, metrics = train_qtable(cfg)
|
||||
else:
|
||||
_, metrics = train_sb3(cfg)
|
||||
@@ -357,8 +402,17 @@ def main():
|
||||
p.add_argument("--learning-rate", type=float)
|
||||
p.add_argument("--gamma", type=float)
|
||||
p.add_argument("--revenue-weight", type=float)
|
||||
p.add_argument("--max-steps", type=int)
|
||||
p.add_argument("--margin-floor", type=float)
|
||||
p.add_argument("--margin-floor-patience", type=int)
|
||||
p.add_argument("--arch", type=str)
|
||||
p.add_argument("--activation", type=str)
|
||||
p.add_argument("--jax", action="store_true")
|
||||
p.add_argument("--jax-num-envs", type=int)
|
||||
p.add_argument("--jax-num-steps", type=int)
|
||||
p.add_argument("--jax-num-minibatches", type=int)
|
||||
p.add_argument("--jax-update-epochs", type=int)
|
||||
p.add_argument("--jax-anneal-lr", type=str)
|
||||
p.add_argument("--sweep-agent", action="store_true")
|
||||
p.add_argument("--sweep-id", type=str)
|
||||
p.add_argument("--count", type=int, default=0)
|
||||
@@ -377,8 +431,19 @@ def main():
|
||||
"learning_rate": args.learning_rate,
|
||||
"gamma": args.gamma,
|
||||
"revenue_weight": args.revenue_weight,
|
||||
"max_steps": args.max_steps,
|
||||
"margin_floor": args.margin_floor,
|
||||
"margin_floor_patience": args.margin_floor_patience,
|
||||
"arch": args.arch,
|
||||
"activation": args.activation,
|
||||
"use_jax": args.jax,
|
||||
"jax_num_envs": args.jax_num_envs,
|
||||
"jax_num_steps": args.jax_num_steps,
|
||||
"jax_num_minibatches": args.jax_num_minibatches,
|
||||
"jax_update_epochs": args.jax_update_epochs,
|
||||
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
|
||||
if args.jax_anneal_lr is not None
|
||||
else None,
|
||||
}
|
||||
overrides = {k: v for k, v in overrides.items() if v is not None}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user