mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactoring training spec setup and benchmarking
This commit is contained in:
340
engine/spec.py
Normal file
340
engine/spec.py
Normal file
@@ -0,0 +1,340 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import os
|
||||
from typing import Any, Mapping, Sequence
|
||||
|
||||
|
||||
def _truthy(value: str | bool | None) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
|
||||
alias_map = {
|
||||
"algorithm": "algo",
|
||||
"algorithm.name": "algo",
|
||||
"env.n_products": "n_products",
|
||||
"env.action_levels": "action_levels",
|
||||
"env.action_scale_low": "action_scale_low",
|
||||
"env.action_scale_high": "action_scale_high",
|
||||
"env.price_low": "price_low",
|
||||
"env.price_high": "price_high",
|
||||
"env.max_steps": "max_steps",
|
||||
"env.margin_floor": "margin_floor",
|
||||
"env.margin_floor_patience": "margin_floor_patience",
|
||||
"env.n_sessions": "N",
|
||||
"study.alpha": "alpha",
|
||||
"study.lambda_coi": "lambda_coi",
|
||||
"study.robust_radius": "robust_radius",
|
||||
"study.robust_points": "robust_points",
|
||||
"study.info_value": "info_value",
|
||||
"study.revenue_weight": "revenue_weight",
|
||||
"optimizer.learning_rate": "learning_rate",
|
||||
"optimizer.gamma": "gamma",
|
||||
"optimizer.batch_size": "batch_size",
|
||||
"optimizer.n_steps": "n_steps",
|
||||
"runtime.backend": "backend",
|
||||
"runtime.device": "device",
|
||||
"runtime.seed": "seed",
|
||||
"runtime.total_timesteps": "total_timesteps",
|
||||
"runtime.checkpoint_interval": "checkpoint_interval",
|
||||
"eval.eval_freq": "eval_freq",
|
||||
"eval.eval_episodes": "eval_episodes",
|
||||
}
|
||||
normalized: dict[str, Any] = {}
|
||||
for key, value in raw.items():
|
||||
canonical = alias_map.get(str(key), str(key))
|
||||
normalized[canonical] = value
|
||||
return normalized
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class AlgorithmSpec:
    """Immutable record naming the training algorithm."""

    # Algorithm identifier; TrainSpec.from_flat stores it lower-cased and
    # stripped, and TrainSpec.to_flat_dict exports it under the key "algo".
    name: str = "ppo"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EnvSpec:
    """Immutable bundle of environment settings.

    Field meanings are not established in this module; the comments below
    are reviewer notes where marked as assumptions.
    """

    n_products: int = 10
    # Exported by TrainSpec.to_flat_dict under the flat key "N".
    n_sessions: int = 100

    # presumably the bounds of the price range - confirm against the env
    price_low: float = 10.0
    price_high: float = 150.0

    # presumably a discretized multiplier grid between the two scale bounds
    # - TODO confirm against the environment implementation
    action_levels: int = 9
    action_scale_low: float = 0.8
    action_scale_high: float = 1.2

    max_steps: int = 100
    margin_floor: float = 0.05
    margin_floor_patience: int = 5
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class StudySpec:
    """Immutable bundle of study (objective / robustness) settings."""

    alpha: float = 0.3
    lambda_coi: float = 0.2
    robust_radius: float = 0.15
    robust_points: int = 5
    info_value: float = 1.0
    revenue_weight: float = 0.01
    # When truthy, TrainSpec.from_flat forces lambda_coi=0.0,
    # robust_radius=0.0 and robust_points=1.
    no_robust: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class OptimizerSpec:
    """Immutable bundle of optimizer and learner hyper-parameters.

    One flat record holds the settings for every supported learner. Which
    trainer reads which field is not visible in this module; the group
    comments below are reviewer assumptions to be confirmed.
    """

    learning_rate: float = 3e-4
    gamma: float = 0.99

    # presumably replay-buffer / off-policy settings - TODO confirm
    buffer_size: int = 50_000
    batch_size: int = 256
    tau: float = 0.005
    train_freq: int = 1
    learning_starts: int = 1_000
    target_update_interval: int = 1_000
    exploration_fraction: float = 0.2
    exploration_final_eps: float = 0.05

    # presumably on-policy (PPO-style) settings - TODO confirm
    n_steps: int = 2_048
    n_epochs: int = 10
    gae_lambda: float = 0.95
    clip_range: float = 0.2
    ent_coef: float = 0.0

    # presumably tabular Q-learning settings - TODO confirm
    q_lr: float = 0.1
    q_bins: int = 6
    eps_start: float = 1.0
    eps_end: float = 0.05
    eps_decay: float = 0.9995

    # Network shape selectors; accepted values are defined elsewhere.
    arch: str = "small"
    activation: str = "relu"

    # Settings carrying the ``jax_`` prefix.
    jax_num_envs: int = 16
    jax_num_steps: int = 128
    jax_num_minibatches: int = 4
    jax_update_epochs: int = 4
    jax_anneal_lr: bool = True

    vf_coef: float = 0.5
    max_grad_norm: float = 0.5
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RuntimeSpec:
    """Immutable bundle of run-level settings (where and how long to train)."""

    project: str = "capstone"
    # TrainSpec.from_flat stores this lower-cased and resolves "auto" to
    # either "jax" or "sb3".
    backend: str = "sb3"
    device: str = "auto"
    seed: int = 42
    total_timesteps: int = 50_000
    checkpoint_interval: int = 200_000
    model_dir: str = "engine/models"
    log_freq: int = 100
    # TrainSpec.from_flat sets this to True whenever the backend is "jax".
    use_jax: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class EvalSpec:
    """Immutable bundle of evaluation settings."""

    eval_freq: int = 1_000
    eval_episodes: int = 5
    # Parsed with the module's truthy-flag helper in TrainSpec.from_flat.
    robust_eval_enabled: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TrainSpec:
    """Immutable, sectioned description of one training run.

    The spec groups settings into six frozen sub-specs. ``to_flat_dict``
    flattens them into a single-level dict and ``from_flat`` rebuilds a spec
    from such a dict (plus a couple of environment variables).
    """

    algorithm: AlgorithmSpec = field(default_factory=AlgorithmSpec)
    env: EnvSpec = field(default_factory=EnvSpec)
    study: StudySpec = field(default_factory=StudySpec)
    optimizer: OptimizerSpec = field(default_factory=OptimizerSpec)
    runtime: RuntimeSpec = field(default_factory=RuntimeSpec)
    eval: EvalSpec = field(default_factory=EvalSpec)

    def to_flat_dict(self) -> dict[str, Any]:
        """Return every setting in one flat ``{name: value}`` dict.

        Keys are the bare field names, with two exceptions: the algorithm
        name is exported as ``"algo"`` and ``env.n_sessions`` as ``"N"``.
        """
        return {
            "project": self.runtime.project,
            "algo": self.algorithm.name,
            "seed": self.runtime.seed,
            "total_timesteps": self.runtime.total_timesteps,
            "eval_episodes": self.eval.eval_episodes,
            "eval_freq": self.eval.eval_freq,
            "log_freq": self.runtime.log_freq,
            "model_dir": self.runtime.model_dir,
            "backend": self.runtime.backend,
            "device": self.runtime.device,
            "use_jax": self.runtime.use_jax,
            "checkpoint_interval": self.runtime.checkpoint_interval,
            "n_products": self.env.n_products,
            "N": self.env.n_sessions,
            "price_low": self.env.price_low,
            "price_high": self.env.price_high,
            "action_levels": self.env.action_levels,
            "action_scale_low": self.env.action_scale_low,
            "action_scale_high": self.env.action_scale_high,
            "max_steps": self.env.max_steps,
            "margin_floor": self.env.margin_floor,
            "margin_floor_patience": self.env.margin_floor_patience,
            "alpha": self.study.alpha,
            "lambda_coi": self.study.lambda_coi,
            "robust_radius": self.study.robust_radius,
            "robust_points": self.study.robust_points,
            "info_value": self.study.info_value,
            "revenue_weight": self.study.revenue_weight,
            "no_robust": self.study.no_robust,
            "learning_rate": self.optimizer.learning_rate,
            "gamma": self.optimizer.gamma,
            "buffer_size": self.optimizer.buffer_size,
            "batch_size": self.optimizer.batch_size,
            "tau": self.optimizer.tau,
            "train_freq": self.optimizer.train_freq,
            "learning_starts": self.optimizer.learning_starts,
            "target_update_interval": self.optimizer.target_update_interval,
            "exploration_fraction": self.optimizer.exploration_fraction,
            "exploration_final_eps": self.optimizer.exploration_final_eps,
            "n_steps": self.optimizer.n_steps,
            "n_epochs": self.optimizer.n_epochs,
            "gae_lambda": self.optimizer.gae_lambda,
            "clip_range": self.optimizer.clip_range,
            "ent_coef": self.optimizer.ent_coef,
            "q_lr": self.optimizer.q_lr,
            "q_bins": self.optimizer.q_bins,
            "eps_start": self.optimizer.eps_start,
            "eps_end": self.optimizer.eps_end,
            "eps_decay": self.optimizer.eps_decay,
            "arch": self.optimizer.arch,
            "activation": self.optimizer.activation,
            "jax_num_envs": self.optimizer.jax_num_envs,
            "jax_num_steps": self.optimizer.jax_num_steps,
            "jax_num_minibatches": self.optimizer.jax_num_minibatches,
            "jax_update_epochs": self.optimizer.jax_update_epochs,
            "jax_anneal_lr": self.optimizer.jax_anneal_lr,
            "vf_coef": self.optimizer.vf_coef,
            "max_grad_norm": self.optimizer.max_grad_norm,
            "robust_eval_enabled": self.eval.robust_eval_enabled,
        }

    @classmethod
    def from_flat(
        cls,
        raw: Mapping[str, Any] | None = None,
        *,
        env_vars: Mapping[str, str] | None = None,
    ) -> "TrainSpec":
        """Build a spec from a flat mapping, filling gaps with defaults.

        Resolution rules:

        * Keys of *raw* are canonicalized with ``_normalize_keys``; entries
          whose value is ``None`` are treated as "not supplied". Keys that do
          not correspond to any setting are ignored.
        * ``device``: an explicitly supplied value wins; otherwise a
          non-empty ``PHANTOM_DEVICE`` environment value is used; otherwise
          the dataclass default.
        * JAX is considered requested when ``use_jax`` or the
          ``PHANTOM_USE_JAX`` environment value is truthy.
        * ``backend``: an explicitly supplied value is lower-cased and used;
          when it is absent or ``"auto"`` the backend is ``"jax"`` if JAX was
          requested, else ``"sb3"``. A ``"jax"`` backend always sets
          ``use_jax`` to ``True``.
        * A truthy ``no_robust`` forces ``lambda_coi=0.0``,
          ``robust_radius=0.0`` and ``robust_points=1``.

        Args:
            raw: Flat settings mapping, or ``None`` for all defaults.
            env_vars: Mapping to read environment overrides from; defaults to
                ``os.environ``.

        Returns:
            A fully populated ``TrainSpec``.

        Raises:
            ValueError: If a value cannot be converted by ``int()``/``float()``.
            TypeError: If a value has a type ``int()``/``float()`` rejects.
        """
        # Track what the caller actually supplied. ``base`` is pre-seeded
        # with every default, so ``base.get(key, fallback)`` can never reach
        # its fallback; explicit-vs-default must be decided from ``supplied``.
        supplied = {
            key: value
            for key, value in _normalize_keys(raw or {}).items()
            if value is not None
        }
        base = cls().to_flat_dict()
        base.update(supplied)

        runtime_env = os.environ if env_vars is None else env_vars

        if "device" in supplied:
            base["device"] = str(supplied["device"])
        else:
            # An empty PHANTOM_DEVICE is treated the same as an unset one.
            base["device"] = str(runtime_env.get("PHANTOM_DEVICE") or base["device"])

        requested_jax = _truthy(base.get("use_jax")) or _truthy(
            runtime_env.get("PHANTOM_USE_JAX")
        )
        # No explicit backend behaves like "auto": derive it from the request.
        backend = str(supplied["backend"]).lower() if "backend" in supplied else "auto"
        if backend == "auto":
            backend = "jax" if requested_jax else "sb3"
        if backend == "jax":
            requested_jax = True

        no_robust = _truthy(base.get("no_robust"))
        if no_robust:
            base["lambda_coi"] = 0.0
            base["robust_radius"] = 0.0
            base["robust_points"] = 1

        return cls(
            algorithm=AlgorithmSpec(name=str(base["algo"]).lower().strip()),
            env=EnvSpec(
                n_products=int(base["n_products"]),
                n_sessions=int(base["N"]),
                price_low=float(base["price_low"]),
                price_high=float(base["price_high"]),
                action_levels=int(base["action_levels"]),
                action_scale_low=float(base["action_scale_low"]),
                action_scale_high=float(base["action_scale_high"]),
                max_steps=int(base["max_steps"]),
                margin_floor=float(base["margin_floor"]),
                margin_floor_patience=int(base["margin_floor_patience"]),
            ),
            study=StudySpec(
                alpha=float(base["alpha"]),
                lambda_coi=float(base["lambda_coi"]),
                robust_radius=float(base["robust_radius"]),
                robust_points=int(base["robust_points"]),
                info_value=float(base["info_value"]),
                revenue_weight=float(base["revenue_weight"]),
                no_robust=no_robust,
            ),
            optimizer=OptimizerSpec(
                learning_rate=float(base["learning_rate"]),
                gamma=float(base["gamma"]),
                buffer_size=int(base["buffer_size"]),
                batch_size=int(base["batch_size"]),
                tau=float(base["tau"]),
                train_freq=int(base["train_freq"]),
                learning_starts=int(base["learning_starts"]),
                target_update_interval=int(base["target_update_interval"]),
                exploration_fraction=float(base["exploration_fraction"]),
                exploration_final_eps=float(base["exploration_final_eps"]),
                n_steps=int(base["n_steps"]),
                n_epochs=int(base["n_epochs"]),
                gae_lambda=float(base["gae_lambda"]),
                clip_range=float(base["clip_range"]),
                ent_coef=float(base["ent_coef"]),
                q_lr=float(base["q_lr"]),
                q_bins=int(base["q_bins"]),
                eps_start=float(base["eps_start"]),
                eps_end=float(base["eps_end"]),
                eps_decay=float(base["eps_decay"]),
                arch=str(base["arch"]),
                activation=str(base["activation"]),
                jax_num_envs=int(base["jax_num_envs"]),
                jax_num_steps=int(base["jax_num_steps"]),
                jax_num_minibatches=int(base["jax_num_minibatches"]),
                jax_update_epochs=int(base["jax_update_epochs"]),
                jax_anneal_lr=_truthy(base.get("jax_anneal_lr")),
                vf_coef=float(base["vf_coef"]),
                max_grad_norm=float(base["max_grad_norm"]),
            ),
            runtime=RuntimeSpec(
                project=str(base["project"]),
                backend=backend,
                device=str(base["device"]),
                seed=int(base["seed"]),
                total_timesteps=int(base["total_timesteps"]),
                checkpoint_interval=int(base["checkpoint_interval"]),
                model_dir=str(base["model_dir"]),
                log_freq=int(base["log_freq"]),
                use_jax=requested_jax,
            ),
            eval=EvalSpec(
                eval_freq=int(base["eval_freq"]),
                eval_episodes=int(base["eval_episodes"]),
                robust_eval_enabled=_truthy(base.get("robust_eval_enabled", True)),
            ),
        )
|
||||
|
||||
|
||||
def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str:
    """Compose the slash-separated name identifying a run.

    Args:
        spec: Spec supplying the algorithm name, backend, device and seed.
        kind: Leading path component of the name.
        scenario: Component placed just before the seed suffix.

    Returns:
        ``"<kind>/<algo>/<backend>/<device>/<scenario>/s<seed>"``.
    """
    runtime = spec.runtime
    return "{}/{}/{}/{}/{}/s{}".format(
        kind,
        spec.algorithm.name,
        runtime.backend,
        runtime.device,
        scenario,
        runtime.seed,
    )
|
||||
|
||||
|
||||
def run_metadata(
    spec: TrainSpec,
    *,
    kind: str,
    scenario: str,
    group: str | None = None,
    tags: Sequence[str] = (),
) -> dict[str, Any]:
    """Assemble the ``run.*`` metadata dict describing a run.

    Args:
        spec: Spec supplying the algorithm name, backend, device and seed.
        kind: Stored (stringified) under ``"run.kind"``.
        scenario: Stored (stringified) under ``"run.scenario"``.
        group: Stored under ``"run.group"`` only when truthy.
        tags: Copied into a new list under ``"run.tags"``.

    Returns:
        A new dict with the ``run.*`` keys.
    """
    runtime = spec.runtime
    entries: list[tuple[str, Any]] = [
        ("run.kind", str(kind)),
        ("run.algo", spec.algorithm.name),
        ("run.backend", runtime.backend),
        ("run.device", runtime.device),
        ("run.scenario", str(scenario)),
        ("run.seed", runtime.seed),
        ("run.tags", list(tags)),
    ]
    # The group entry is omitted entirely for None / empty values.
    if group:
        entries.append(("run.group", group))
    return dict(entries)
|
||||
Reference in New Issue
Block a user