refactoring training spc setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

340
engine/spec.py Normal file
View File

@@ -0,0 +1,340 @@
from __future__ import annotations
from dataclasses import dataclass, field
import os
from typing import Any, Mapping, Sequence
def _truthy(value: str | bool | None) -> bool:
if isinstance(value, bool):
return value
if value is None:
return False
return str(value).strip().lower() in {"1", "true", "yes", "on"}
def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
alias_map = {
"algorithm": "algo",
"algorithm.name": "algo",
"env.n_products": "n_products",
"env.action_levels": "action_levels",
"env.action_scale_low": "action_scale_low",
"env.action_scale_high": "action_scale_high",
"env.price_low": "price_low",
"env.price_high": "price_high",
"env.max_steps": "max_steps",
"env.margin_floor": "margin_floor",
"env.margin_floor_patience": "margin_floor_patience",
"env.n_sessions": "N",
"study.alpha": "alpha",
"study.lambda_coi": "lambda_coi",
"study.robust_radius": "robust_radius",
"study.robust_points": "robust_points",
"study.info_value": "info_value",
"study.revenue_weight": "revenue_weight",
"optimizer.learning_rate": "learning_rate",
"optimizer.gamma": "gamma",
"optimizer.batch_size": "batch_size",
"optimizer.n_steps": "n_steps",
"runtime.backend": "backend",
"runtime.device": "device",
"runtime.seed": "seed",
"runtime.total_timesteps": "total_timesteps",
"runtime.checkpoint_interval": "checkpoint_interval",
"eval.eval_freq": "eval_freq",
"eval.eval_episodes": "eval_episodes",
}
normalized: dict[str, Any] = {}
for key, value in raw.items():
canonical = alias_map.get(str(key), str(key))
normalized[canonical] = value
return normalized
@dataclass(frozen=True)
class AlgorithmSpec:
name: str = "ppo"
@dataclass(frozen=True)
class EnvSpec:
n_products: int = 10
n_sessions: int = 100
price_low: float = 10.0
price_high: float = 150.0
action_levels: int = 9
action_scale_low: float = 0.8
action_scale_high: float = 1.2
max_steps: int = 100
margin_floor: float = 0.05
margin_floor_patience: int = 5
@dataclass(frozen=True)
class StudySpec:
alpha: float = 0.3
lambda_coi: float = 0.2
robust_radius: float = 0.15
robust_points: int = 5
info_value: float = 1.0
revenue_weight: float = 0.01
no_robust: bool = False
@dataclass(frozen=True)
class OptimizerSpec:
learning_rate: float = 3e-4
gamma: float = 0.99
buffer_size: int = 50_000
batch_size: int = 256
tau: float = 0.005
train_freq: int = 1
learning_starts: int = 1_000
target_update_interval: int = 1_000
exploration_fraction: float = 0.2
exploration_final_eps: float = 0.05
n_steps: int = 2_048
n_epochs: int = 10
gae_lambda: float = 0.95
clip_range: float = 0.2
ent_coef: float = 0.0
q_lr: float = 0.1
q_bins: int = 6
eps_start: float = 1.0
eps_end: float = 0.05
eps_decay: float = 0.9995
arch: str = "small"
activation: str = "relu"
jax_num_envs: int = 16
jax_num_steps: int = 128
jax_num_minibatches: int = 4
jax_update_epochs: int = 4
jax_anneal_lr: bool = True
vf_coef: float = 0.5
max_grad_norm: float = 0.5
@dataclass(frozen=True)
class RuntimeSpec:
project: str = "capstone"
backend: str = "sb3"
device: str = "auto"
seed: int = 42
total_timesteps: int = 50_000
checkpoint_interval: int = 200_000
model_dir: str = "engine/models"
log_freq: int = 100
use_jax: bool = False
@dataclass(frozen=True)
class EvalSpec:
eval_freq: int = 1_000
eval_episodes: int = 5
robust_eval_enabled: bool = True
@dataclass(frozen=True)
class TrainSpec:
algorithm: AlgorithmSpec = field(default_factory=AlgorithmSpec)
env: EnvSpec = field(default_factory=EnvSpec)
study: StudySpec = field(default_factory=StudySpec)
optimizer: OptimizerSpec = field(default_factory=OptimizerSpec)
runtime: RuntimeSpec = field(default_factory=RuntimeSpec)
eval: EvalSpec = field(default_factory=EvalSpec)
def to_flat_dict(self) -> dict[str, Any]:
return {
"project": self.runtime.project,
"algo": self.algorithm.name,
"seed": self.runtime.seed,
"total_timesteps": self.runtime.total_timesteps,
"eval_episodes": self.eval.eval_episodes,
"eval_freq": self.eval.eval_freq,
"log_freq": self.runtime.log_freq,
"model_dir": self.runtime.model_dir,
"backend": self.runtime.backend,
"device": self.runtime.device,
"use_jax": self.runtime.use_jax,
"checkpoint_interval": self.runtime.checkpoint_interval,
"n_products": self.env.n_products,
"N": self.env.n_sessions,
"price_low": self.env.price_low,
"price_high": self.env.price_high,
"action_levels": self.env.action_levels,
"action_scale_low": self.env.action_scale_low,
"action_scale_high": self.env.action_scale_high,
"max_steps": self.env.max_steps,
"margin_floor": self.env.margin_floor,
"margin_floor_patience": self.env.margin_floor_patience,
"alpha": self.study.alpha,
"lambda_coi": self.study.lambda_coi,
"robust_radius": self.study.robust_radius,
"robust_points": self.study.robust_points,
"info_value": self.study.info_value,
"revenue_weight": self.study.revenue_weight,
"no_robust": self.study.no_robust,
"learning_rate": self.optimizer.learning_rate,
"gamma": self.optimizer.gamma,
"buffer_size": self.optimizer.buffer_size,
"batch_size": self.optimizer.batch_size,
"tau": self.optimizer.tau,
"train_freq": self.optimizer.train_freq,
"learning_starts": self.optimizer.learning_starts,
"target_update_interval": self.optimizer.target_update_interval,
"exploration_fraction": self.optimizer.exploration_fraction,
"exploration_final_eps": self.optimizer.exploration_final_eps,
"n_steps": self.optimizer.n_steps,
"n_epochs": self.optimizer.n_epochs,
"gae_lambda": self.optimizer.gae_lambda,
"clip_range": self.optimizer.clip_range,
"ent_coef": self.optimizer.ent_coef,
"q_lr": self.optimizer.q_lr,
"q_bins": self.optimizer.q_bins,
"eps_start": self.optimizer.eps_start,
"eps_end": self.optimizer.eps_end,
"eps_decay": self.optimizer.eps_decay,
"arch": self.optimizer.arch,
"activation": self.optimizer.activation,
"jax_num_envs": self.optimizer.jax_num_envs,
"jax_num_steps": self.optimizer.jax_num_steps,
"jax_num_minibatches": self.optimizer.jax_num_minibatches,
"jax_update_epochs": self.optimizer.jax_update_epochs,
"jax_anneal_lr": self.optimizer.jax_anneal_lr,
"vf_coef": self.optimizer.vf_coef,
"max_grad_norm": self.optimizer.max_grad_norm,
"robust_eval_enabled": self.eval.robust_eval_enabled,
}
@classmethod
def from_flat(
cls,
raw: Mapping[str, Any] | None = None,
*,
env_vars: Mapping[str, str] | None = None,
) -> "TrainSpec":
base = cls().to_flat_dict()
incoming = _normalize_keys(raw or {})
base.update({k: v for k, v in incoming.items() if v is not None})
runtime_env = os.environ if env_vars is None else env_vars
base["device"] = str(
base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto"))
)
requested_jax = _truthy(base.get("use_jax")) or _truthy(
runtime_env.get("PHANTOM_USE_JAX")
)
backend = str(base.get("backend", "jax" if requested_jax else "sb3")).lower()
if backend == "auto":
backend = "jax" if requested_jax else "sb3"
if backend == "jax":
requested_jax = True
no_robust = _truthy(base.get("no_robust"))
if no_robust:
base["lambda_coi"] = 0.0
base["robust_radius"] = 0.0
base["robust_points"] = 1
return cls(
algorithm=AlgorithmSpec(name=str(base["algo"]).lower().strip()),
env=EnvSpec(
n_products=int(base["n_products"]),
n_sessions=int(base["N"]),
price_low=float(base["price_low"]),
price_high=float(base["price_high"]),
action_levels=int(base["action_levels"]),
action_scale_low=float(base["action_scale_low"]),
action_scale_high=float(base["action_scale_high"]),
max_steps=int(base["max_steps"]),
margin_floor=float(base["margin_floor"]),
margin_floor_patience=int(base["margin_floor_patience"]),
),
study=StudySpec(
alpha=float(base["alpha"]),
lambda_coi=float(base["lambda_coi"]),
robust_radius=float(base["robust_radius"]),
robust_points=int(base["robust_points"]),
info_value=float(base["info_value"]),
revenue_weight=float(base["revenue_weight"]),
no_robust=no_robust,
),
optimizer=OptimizerSpec(
learning_rate=float(base["learning_rate"]),
gamma=float(base["gamma"]),
buffer_size=int(base["buffer_size"]),
batch_size=int(base["batch_size"]),
tau=float(base["tau"]),
train_freq=int(base["train_freq"]),
learning_starts=int(base["learning_starts"]),
target_update_interval=int(base["target_update_interval"]),
exploration_fraction=float(base["exploration_fraction"]),
exploration_final_eps=float(base["exploration_final_eps"]),
n_steps=int(base["n_steps"]),
n_epochs=int(base["n_epochs"]),
gae_lambda=float(base["gae_lambda"]),
clip_range=float(base["clip_range"]),
ent_coef=float(base["ent_coef"]),
q_lr=float(base["q_lr"]),
q_bins=int(base["q_bins"]),
eps_start=float(base["eps_start"]),
eps_end=float(base["eps_end"]),
eps_decay=float(base["eps_decay"]),
arch=str(base["arch"]),
activation=str(base["activation"]),
jax_num_envs=int(base["jax_num_envs"]),
jax_num_steps=int(base["jax_num_steps"]),
jax_num_minibatches=int(base["jax_num_minibatches"]),
jax_update_epochs=int(base["jax_update_epochs"]),
jax_anneal_lr=_truthy(base.get("jax_anneal_lr")),
vf_coef=float(base["vf_coef"]),
max_grad_norm=float(base["max_grad_norm"]),
),
runtime=RuntimeSpec(
project=str(base["project"]),
backend=backend,
device=str(base["device"]),
seed=int(base["seed"]),
total_timesteps=int(base["total_timesteps"]),
checkpoint_interval=int(base["checkpoint_interval"]),
model_dir=str(base["model_dir"]),
log_freq=int(base["log_freq"]),
use_jax=requested_jax,
),
eval=EvalSpec(
eval_freq=int(base["eval_freq"]),
eval_episodes=int(base["eval_episodes"]),
robust_eval_enabled=_truthy(base.get("robust_eval_enabled", True)),
),
)
def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str:
return (
f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/"
f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}"
)
def run_metadata(
spec: TrainSpec,
*,
kind: str,
scenario: str,
group: str | None = None,
tags: Sequence[str] = (),
) -> dict[str, Any]:
metadata: dict[str, Any] = {
"run.kind": str(kind),
"run.algo": spec.algorithm.name,
"run.backend": spec.runtime.backend,
"run.device": spec.runtime.device,
"run.scenario": str(scenario),
"run.seed": spec.runtime.seed,
"run.tags": list(tags),
}
if group:
metadata["run.group"] = group
return metadata