updating engine training for training

This commit is contained in:
2026-03-15 21:14:11 +01:00
parent 19b47aa699
commit 52b4dcdce3
13 changed files with 544 additions and 160 deletions

View File

@@ -32,10 +32,17 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
"study.robust_radius": "robust_radius",
"study.robust_points": "robust_points",
"study.robust_rollouts": "robust_rollouts",
"study.ambiguity_radius": "robust_radius",
"study.ambiguity_points": "robust_points",
"study.ambiguity_rollouts": "robust_rollouts",
"study.info_value": "info_value",
"study.eta_ux": "eta_ux",
"study.reward_profit_weight": "reward_profit_weight",
"study.revenue_weight": "revenue_weight",
"ambiguity_radius": "robust_radius",
"ambiguity_points": "robust_points",
"ambiguity_rollouts": "robust_rollouts",
"baseline_mode": "no_robust",
"stress_eval_enabled": "robust_eval_enabled",
"optimizer.learning_rate": "learning_rate",
"optimizer.gamma": "gamma",
"optimizer.batch_size": "batch_size",
@@ -45,6 +52,7 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
"runtime.seed": "seed",
"runtime.total_timesteps": "total_timesteps",
"runtime.checkpoint_interval": "checkpoint_interval",
"runtime.hist_freq": "hist_freq",
"eval.eval_freq": "eval_freq",
"eval.eval_episodes": "eval_episodes",
}
@@ -86,7 +94,6 @@ class StudySpec:
info_value: float = 1.0
eta_ux: float = 0.5
reward_profit_weight: float = 1.0
revenue_weight: float = 0.01
no_robust: bool = False
@@ -128,6 +135,7 @@ class RuntimeSpec:
checkpoint_interval: int = 200_000
model_dir: str = "engine/models"
log_freq: int = 100
hist_freq: int = 500
@dataclass(frozen=True)
@@ -159,6 +167,7 @@ class TrainSpec:
"backend": self.runtime.backend,
"device": self.runtime.device,
"checkpoint_interval": self.runtime.checkpoint_interval,
"hist_freq": self.runtime.hist_freq,
"n_products": self.env.n_products,
"N": self.env.n_sessions,
"price_low": self.env.price_low,
@@ -179,7 +188,6 @@ class TrainSpec:
"info_value": self.study.info_value,
"eta_ux": self.study.eta_ux,
"reward_profit_weight": self.study.reward_profit_weight,
"revenue_weight": self.study.revenue_weight,
"no_robust": self.study.no_robust,
"learning_rate": self.optimizer.learning_rate,
"gamma": self.optimizer.gamma,
@@ -262,7 +270,6 @@ class TrainSpec:
info_value=float(base["info_value"]),
eta_ux=float(base["eta_ux"]),
reward_profit_weight=float(base["reward_profit_weight"]),
revenue_weight=float(base["revenue_weight"]),
no_robust=no_robust,
),
optimizer=OptimizerSpec(
@@ -300,6 +307,7 @@ class TrainSpec:
checkpoint_interval=int(base["checkpoint_interval"]),
model_dir=str(base["model_dir"]),
log_freq=int(base["log_freq"]),
hist_freq=int(base["hist_freq"]),
),
eval=EvalSpec(
eval_freq=int(base["eval_freq"]),
@@ -310,9 +318,11 @@ class TrainSpec:
def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str:
alpha_token = f"{float(spec.study.alpha):.2f}".rstrip("0").rstrip(".")
mode = "baseline" if bool(spec.study.no_robust) else "defended"
return (
f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/"
f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}"
f"{spec.runtime.device}/{scenario}/a{alpha_token}/{mode}/s{spec.runtime.seed}"
)
@@ -324,6 +334,7 @@ def run_metadata(
group: str | None = None,
tags: Sequence[str] = (),
) -> dict[str, Any]:
mode = "baseline" if bool(spec.study.no_robust) else "defended"
metadata: dict[str, Any] = {
"run.kind": str(kind),
"run.algo": spec.algorithm.name,
@@ -332,6 +343,10 @@ def run_metadata(
"run.scenario": str(scenario),
"run.seed": spec.runtime.seed,
"run.tags": list(tags),
"study/alpha": float(spec.study.alpha),
"study/mode": mode,
"study/baseline_mode": float(bool(spec.study.no_robust)),
"tiers": spec.algorithm.name,
}
if group:
metadata["run.group"] = group