adding naive jax and libraries and make adjustments

This commit is contained in:
2026-02-17 14:48:18 +01:00
parent 66c4a0cd1d
commit 802f31b4a1
17 changed files with 2331 additions and 6 deletions

View File

@@ -1,8 +1,10 @@
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
import numpy as np
from gymnasium.wrappers import FlattenObservation
try:
import wandb
@@ -20,9 +22,7 @@ try:
except ImportError:
HAS_SB3 = False
from .wrapper import PHANTOM
from .lib import EconomicMetricsWrapper, MetricsCallback
from .lib.discrete import EventQTable
from .jax import JAX_AVAILABLE
DEFAULT_CFG = {
@@ -69,14 +69,34 @@ DEFAULT_CFG = {
"arch": "small",
"activation": "relu",
"q_bins": 6,
"max_steps": 100,
"margin_floor": 0.05,
"margin_floor_patience": 5,
"use_jax": False,
"jax_num_envs": 16,
"jax_num_steps": 128,
"jax_num_minibatches": 4,
"jax_update_epochs": 4,
"jax_anneal_lr": True,
}
def _truthy(value: str | bool | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _cfg(raw: dict | None = None) -> dict:
cfg = dict(DEFAULT_CFG)
if raw:
cfg.update({k: v for k, v in raw.items() if v is not None})
cfg["algo"] = str(cfg["algo"]).lower()
cfg["use_jax"] = _truthy(cfg.get("use_jax")) or _truthy(
os.environ.get("PHANTOM_USE_JAX")
)
return cfg
@@ -89,6 +109,11 @@ def _wandb_cfg_dict() -> dict:
def make_env(cfg: dict):
from gymnasium.wrappers import FlattenObservation
from .wrapper import PHANTOM
from .lib.wrappers import EconomicMetricsWrapper
env = PHANTOM(
n_products=int(cfg["n_products"]),
alpha=float(cfg["alpha"]),
@@ -101,6 +126,9 @@ def make_env(cfg: dict):
action_levels=int(cfg["action_levels"]),
action_scale_low=float(cfg["action_scale_low"]),
action_scale_high=float(cfg["action_scale_high"]),
max_steps=int(cfg.get("max_steps", 100)),
margin_floor=float(cfg.get("margin_floor", 0.05)),
margin_floor_patience=int(cfg.get("margin_floor_patience", 5)),
render_mode=None,
)
env = EconomicMetricsWrapper(env)
@@ -235,6 +263,8 @@ def build_model(cfg: dict, env):
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
from .lib.discrete import EventQTable
np.random.seed(int(cfg["seed"]))
env = make_env(cfg)
eval_env = make_env(cfg)
@@ -275,6 +305,8 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
def train_sb3(cfg: dict) -> tuple[object, dict]:
if not HAS_SB3:
raise ImportError("stable-baselines3 is required for SB3 models")
from .lib.callbacks import MetricsCallback
env = make_env(cfg)
eval_env = make_env(cfg)
env = Monitor(env)
@@ -303,7 +335,20 @@ def train_sb3(cfg: dict) -> tuple[object, dict]:
def train_once(cfg: dict) -> dict:
algo = cfg["algo"]
if algo == "qtable":
if cfg.get("use_jax"):
if not JAX_AVAILABLE:
raise ImportError(
"JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
)
if algo == "qtable":
raise ValueError("qtable is not supported in JAX backend")
try:
from .jax.train import train_jax
except Exception as exc: # pragma: no cover
raise ImportError(f"Failed to import JAX trainer: {exc}") from exc
_, metrics = train_jax(cfg)
elif algo == "qtable":
_, metrics = train_qtable(cfg)
else:
_, metrics = train_sb3(cfg)
@@ -357,8 +402,17 @@ def main():
p.add_argument("--learning-rate", type=float)
p.add_argument("--gamma", type=float)
p.add_argument("--revenue-weight", type=float)
p.add_argument("--max-steps", type=int)
p.add_argument("--margin-floor", type=float)
p.add_argument("--margin-floor-patience", type=int)
p.add_argument("--arch", type=str)
p.add_argument("--activation", type=str)
p.add_argument("--jax", action="store_true")
p.add_argument("--jax-num-envs", type=int)
p.add_argument("--jax-num-steps", type=int)
p.add_argument("--jax-num-minibatches", type=int)
p.add_argument("--jax-update-epochs", type=int)
p.add_argument("--jax-anneal-lr", type=str)
p.add_argument("--sweep-agent", action="store_true")
p.add_argument("--sweep-id", type=str)
p.add_argument("--count", type=int, default=0)
@@ -377,8 +431,19 @@ def main():
"learning_rate": args.learning_rate,
"gamma": args.gamma,
"revenue_weight": args.revenue_weight,
"max_steps": args.max_steps,
"margin_floor": args.margin_floor,
"margin_floor_patience": args.margin_floor_patience,
"arch": args.arch,
"activation": args.activation,
"use_jax": args.jax,
"jax_num_envs": args.jax_num_envs,
"jax_num_steps": args.jax_num_steps,
"jax_num_minibatches": args.jax_num_minibatches,
"jax_update_epochs": args.jax_update_epochs,
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
if args.jax_anneal_lr is not None
else None,
}
overrides = {k: v for k, v in overrides.items() if v is not None}