refactoring training spc setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

View File

@@ -0,0 +1 @@
__all__ = ["evaluate", "make_env", "train_jax_backend", "train_qtable", "train_sb3"]

81
engine/backends/common.py Normal file
View File

@@ -0,0 +1,81 @@
from __future__ import annotations
from typing import Any, Mapping
import numpy as np
def make_env(cfg: Mapping[str, Any]):
from gymnasium.wrappers import FlattenObservation
from ..lib.wrappers import EconomicMetricsWrapper
from ..wrapper import PHANTOM
env = PHANTOM(
n_products=int(cfg["n_products"]),
alpha=float(cfg["alpha"]),
N=int(cfg["N"]),
price_bounds=(float(cfg["price_low"]), float(cfg["price_high"])),
lambda_coi=float(cfg["lambda_coi"]),
robust_radius=float(cfg["robust_radius"]),
robust_points=int(cfg["robust_points"]),
info_value=float(cfg["info_value"]),
action_levels=int(cfg["action_levels"]),
action_scale_low=float(cfg["action_scale_low"]),
action_scale_high=float(cfg["action_scale_high"]),
max_steps=int(cfg.get("max_steps", 100)),
margin_floor=float(cfg.get("margin_floor", 0.05)),
margin_floor_patience=int(cfg.get("margin_floor_patience", 5)),
render_mode=None,
)
env = EconomicMetricsWrapper(env)
return FlattenObservation(env)
def _action(agent: Any, obs: Any, deterministic: bool = True):
out = agent.predict(obs, deterministic=deterministic)
action = out[0] if isinstance(out, tuple) else out
if isinstance(action, np.ndarray) and action.size == 1:
return int(action.reshape(-1)[0])
return action
def evaluate(agent: Any, env: Any, episodes: int) -> dict[str, float]:
rewards: list[float] = []
revenues: list[float] = []
margins: list[float] = []
coi_levels: list[float] = []
for _ in range(int(episodes)):
obs, _ = env.reset()
done = False
ep_reward = 0.0
ep_revenue = 0.0
ep_margin = 0.0
ep_coi = 0.0
steps = 0
while not done:
obs, reward, term, trunc, info = env.step(_action(agent, obs, True))
done = bool(term or trunc)
econ = info.get("economics", {})
ep_reward += float(reward)
ep_revenue += float(econ.get("revenue", info.get("revenue", 0.0)))
ep_margin += float(econ.get("margin", 0.0))
ep_coi += float(econ.get("coi_level", 0.0))
steps += 1
rewards.append(ep_reward)
revenues.append(ep_revenue)
denom = max(steps, 1)
margins.append(ep_margin / denom)
coi_levels.append(ep_coi / denom)
return {
"eval/reward_mean": float(np.mean(rewards)) if rewards else 0.0,
"eval/reward_std": float(np.std(rewards)) if rewards else 0.0,
"eval/revenue_mean": float(np.mean(revenues)) if revenues else 0.0,
"eval/revenue_std": float(np.std(revenues)) if revenues else 0.0,
"eval/margin_mean": float(np.mean(margins)) if margins else 0.0,
"eval/coi_level_mean": float(np.mean(coi_levels)) if coi_levels else 0.0,
}

18
engine/backends/jax.py Normal file
View File

@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import Any, Mapping
from ..jax import JAX_AVAILABLE
def train_jax_backend(
cfg: Mapping[str, Any],
) -> tuple[dict[str, Any], dict[str, float | int | str]]:
if not JAX_AVAILABLE:
raise ImportError(
"JAX backend requested but JAX is not installed. "
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
)
from ..jax.train import train_jax
return train_jax(dict(cfg))

53
engine/backends/qtable.py Normal file
View File

@@ -0,0 +1,53 @@
from __future__ import annotations
from typing import Any, Mapping
import numpy as np
from .common import evaluate, make_env
def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]]:
from ..lib.discrete import EventQTable
np.random.seed(int(cfg["seed"]))
env = make_env(cfg)
eval_env = make_env(cfg)
agent = EventQTable(
env.action_space.n,
int(cfg["n_products"]),
(float(cfg["price_low"]), float(cfg["price_high"])),
lr=float(cfg["q_lr"]),
gamma=float(cfg["gamma"]),
n_bins=int(cfg["q_bins"]),
)
total_reward = 0.0
total_revenue = 0.0
steps = 0
epsilon = float(cfg["eps_start"])
obs, _ = env.reset(seed=int(cfg["seed"]))
for _ in range(int(cfg["total_timesteps"])):
action, state = agent.act(obs, epsilon)
nxt, reward, term, trunc, info = env.step(action)
done = bool(term or trunc)
agent.update(state, action, float(reward), agent.encode(nxt), done)
total_reward += float(reward)
total_revenue += float(info.get("economics", {}).get("revenue", 0.0))
steps += 1
epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"]))
obs = env.reset()[0] if done else nxt
metrics: dict[str, float | int] = {
"train/reward_mean": total_reward / max(steps, 1),
"train/revenue_mean": total_revenue / max(steps, 1),
"train/epsilon": float(epsilon),
"train/global_step": int(cfg["total_timesteps"]),
}
metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"])))
env.close()
eval_env.close()
return agent, metrics

228
engine/backends/sb3.py Normal file
View File

@@ -0,0 +1,228 @@
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Mapping
from ..lib.callbacks import CheckpointArtifactCallback, MetricsCallback
from ..telemetry.wandb import get_wandb_module
from ..wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
from .common import evaluate, make_env
def _net_arch(name: Any) -> list[int]:
presets = {
"tiny": [32, 32],
"small": [64, 64],
"medium": [128, 128],
"large": [256, 256],
}
if isinstance(name, (list, tuple)):
return [int(v) for v in name]
raw = str(name).lower().strip()
if raw in presets:
return presets[raw]
if "x" in raw:
try:
parsed = [int(v) for v in raw.split("x") if v]
return parsed if parsed else presets["small"]
except ValueError:
return presets["small"]
return presets["small"]
def _activation(name: Any):
try:
import torch.nn as nn
except ImportError:
return None
return {
"relu": nn.ReLU,
"tanh": nn.Tanh,
"elu": nn.ELU,
"leaky_relu": nn.LeakyReLU,
}.get(str(name).lower().strip(), nn.ReLU)
def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]:
kwargs: dict[str, Any] = {"net_arch": _net_arch(cfg.get("arch", "small"))}
activation = _activation(cfg.get("activation", "relu"))
if activation is not None:
kwargs["activation_fn"] = activation
return kwargs
def _sb3_model_cls(algo: str):
try:
from stable_baselines3 import A2C, DQN, PPO
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
if algo == "ppo":
return PPO
if algo == "a2c":
return A2C
if algo == "dqn":
return DQN
raise ValueError(f"unsupported algo '{algo}'")
def build_model(cfg: Mapping[str, Any], env: Any):
try:
from stable_baselines3 import A2C, DQN, PPO
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
algo = str(cfg["algo"])
policy_kwargs = _policy_kwargs(cfg)
device = str(cfg.get("device", "auto"))
seed = int(cfg["seed"])
if algo == "sac":
raise ValueError("sac is not supported with the discrete core env")
if algo == "ppo":
return PPO(
"MlpPolicy",
env,
verbose=1,
device=device,
policy_kwargs=policy_kwargs,
seed=seed,
learning_rate=float(cfg["learning_rate"]),
n_steps=int(cfg["n_steps"]),
batch_size=int(cfg["batch_size"]),
n_epochs=int(cfg["n_epochs"]),
gamma=float(cfg["gamma"]),
gae_lambda=float(cfg["gae_lambda"]),
clip_range=float(cfg["clip_range"]),
ent_coef=float(cfg["ent_coef"]),
)
if algo == "a2c":
return A2C(
"MlpPolicy",
env,
verbose=1,
device=device,
policy_kwargs=policy_kwargs,
seed=seed,
learning_rate=float(cfg["learning_rate"]),
n_steps=max(5, int(cfg["n_steps"]) // 32),
gamma=float(cfg["gamma"]),
gae_lambda=float(cfg["gae_lambda"]),
ent_coef=float(cfg["ent_coef"]),
)
if algo == "dqn":
return DQN(
"MlpPolicy",
env,
verbose=1,
device=device,
policy_kwargs=policy_kwargs,
seed=seed,
learning_rate=float(cfg["learning_rate"]),
buffer_size=int(cfg["buffer_size"]),
batch_size=int(cfg["batch_size"]),
gamma=float(cfg["gamma"]),
train_freq=int(cfg["train_freq"]),
learning_starts=int(cfg["learning_starts"]),
target_update_interval=int(cfg["target_update_interval"]),
exploration_fraction=float(cfg["exploration_fraction"]),
exploration_final_eps=float(cfg["exploration_final_eps"]),
)
raise ValueError(f"unsupported algo '{algo}'")
def _maybe_resume_model(cfg: Mapping[str, Any], env: Any, model: Any):
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return model
sweep_id = getattr(wandb.run, "sweep_id", None)
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
if restored is None:
return model
checkpoint_path, metadata = restored
resumed = _sb3_model_cls(str(cfg["algo"]).lower()).load(
checkpoint_path.as_posix(),
env=env,
)
resume_step = int(metadata.get("step", getattr(resumed, "num_timesteps", 0)))
resumed.num_timesteps = max(int(getattr(resumed, "num_timesteps", 0)), resume_step)
return resumed
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | str]]:
try:
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
except ImportError as exc:
raise ImportError("stable-baselines3 is required for SB3 models") from exc
env = Monitor(make_env(cfg))
eval_env = Monitor(make_env(cfg))
model = build_model(cfg, env)
try:
import torch
print(
"PHANTOM_DEVICE: "
+ json.dumps(
{
"requested": str(cfg.get("device", "auto")),
"torch_cuda_available": bool(torch.cuda.is_available()),
"torch_device_count": int(torch.cuda.device_count()),
"sb3_device": str(getattr(model, "device", "unknown")),
}
)
)
except Exception:
pass
model = _maybe_resume_model(cfg, env, model)
callbacks = [MetricsCallback(log_histograms=False, log_freq=int(cfg["log_freq"]))]
callbacks.append(
CheckpointArtifactCallback(
dict(cfg),
interval=int(cfg.get("checkpoint_interval", 10_000)),
)
)
callbacks.append(
EvalCallback(
eval_env,
eval_freq=int(cfg["eval_freq"]),
n_eval_episodes=int(cfg["eval_episodes"]),
deterministic=True,
verbose=0,
)
)
target_steps = int(cfg["total_timesteps"])
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
if remaining_steps > 0:
model.learn(
total_timesteps=remaining_steps,
callback=callbacks,
reset_num_timesteps=False,
)
model_dir = Path(str(cfg["model_dir"]))
model_dir.mkdir(parents=True, exist_ok=True)
model_path = model_dir / f"phantom_{cfg['algo']}"
model.save(str(model_path))
metrics: dict[str, float | int | str] = evaluate(
model,
eval_env,
int(cfg["eval_episodes"]),
)
metrics["train/global_step"] = int(model.num_timesteps)
metrics["model/path"] = str(model_path.with_suffix(".zip"))
env.close()
eval_env.close()
return model, metrics