mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
updating engine training for training
This commit is contained in:
@@ -32,10 +32,17 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"study.robust_radius": "robust_radius",
|
||||
"study.robust_points": "robust_points",
|
||||
"study.robust_rollouts": "robust_rollouts",
|
||||
"study.ambiguity_radius": "robust_radius",
|
||||
"study.ambiguity_points": "robust_points",
|
||||
"study.ambiguity_rollouts": "robust_rollouts",
|
||||
"study.info_value": "info_value",
|
||||
"study.eta_ux": "eta_ux",
|
||||
"study.reward_profit_weight": "reward_profit_weight",
|
||||
"study.revenue_weight": "revenue_weight",
|
||||
"ambiguity_radius": "robust_radius",
|
||||
"ambiguity_points": "robust_points",
|
||||
"ambiguity_rollouts": "robust_rollouts",
|
||||
"baseline_mode": "no_robust",
|
||||
"stress_eval_enabled": "robust_eval_enabled",
|
||||
"optimizer.learning_rate": "learning_rate",
|
||||
"optimizer.gamma": "gamma",
|
||||
"optimizer.batch_size": "batch_size",
|
||||
@@ -45,6 +52,7 @@ def _normalize_keys(raw: Mapping[str, Any]) -> dict[str, Any]:
|
||||
"runtime.seed": "seed",
|
||||
"runtime.total_timesteps": "total_timesteps",
|
||||
"runtime.checkpoint_interval": "checkpoint_interval",
|
||||
"runtime.hist_freq": "hist_freq",
|
||||
"eval.eval_freq": "eval_freq",
|
||||
"eval.eval_episodes": "eval_episodes",
|
||||
}
|
||||
@@ -86,7 +94,6 @@ class StudySpec:
|
||||
info_value: float = 1.0
|
||||
eta_ux: float = 0.5
|
||||
reward_profit_weight: float = 1.0
|
||||
revenue_weight: float = 0.01
|
||||
no_robust: bool = False
|
||||
|
||||
|
||||
@@ -128,6 +135,7 @@ class RuntimeSpec:
|
||||
checkpoint_interval: int = 200_000
|
||||
model_dir: str = "engine/models"
|
||||
log_freq: int = 100
|
||||
hist_freq: int = 500
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -159,6 +167,7 @@ class TrainSpec:
|
||||
"backend": self.runtime.backend,
|
||||
"device": self.runtime.device,
|
||||
"checkpoint_interval": self.runtime.checkpoint_interval,
|
||||
"hist_freq": self.runtime.hist_freq,
|
||||
"n_products": self.env.n_products,
|
||||
"N": self.env.n_sessions,
|
||||
"price_low": self.env.price_low,
|
||||
@@ -179,7 +188,6 @@ class TrainSpec:
|
||||
"info_value": self.study.info_value,
|
||||
"eta_ux": self.study.eta_ux,
|
||||
"reward_profit_weight": self.study.reward_profit_weight,
|
||||
"revenue_weight": self.study.revenue_weight,
|
||||
"no_robust": self.study.no_robust,
|
||||
"learning_rate": self.optimizer.learning_rate,
|
||||
"gamma": self.optimizer.gamma,
|
||||
@@ -262,7 +270,6 @@ class TrainSpec:
|
||||
info_value=float(base["info_value"]),
|
||||
eta_ux=float(base["eta_ux"]),
|
||||
reward_profit_weight=float(base["reward_profit_weight"]),
|
||||
revenue_weight=float(base["revenue_weight"]),
|
||||
no_robust=no_robust,
|
||||
),
|
||||
optimizer=OptimizerSpec(
|
||||
@@ -300,6 +307,7 @@ class TrainSpec:
|
||||
checkpoint_interval=int(base["checkpoint_interval"]),
|
||||
model_dir=str(base["model_dir"]),
|
||||
log_freq=int(base["log_freq"]),
|
||||
hist_freq=int(base["hist_freq"]),
|
||||
),
|
||||
eval=EvalSpec(
|
||||
eval_freq=int(base["eval_freq"]),
|
||||
@@ -310,9 +318,11 @@ class TrainSpec:
|
||||
|
||||
|
||||
def run_name(spec: TrainSpec, *, kind: str, scenario: str) -> str:
|
||||
alpha_token = f"{float(spec.study.alpha):.2f}".rstrip("0").rstrip(".")
|
||||
mode = "baseline" if bool(spec.study.no_robust) else "defended"
|
||||
return (
|
||||
f"{kind}/{spec.algorithm.name}/{spec.runtime.backend}/"
|
||||
f"{spec.runtime.device}/{scenario}/s{spec.runtime.seed}"
|
||||
f"{spec.runtime.device}/{scenario}/a{alpha_token}/{mode}/s{spec.runtime.seed}"
|
||||
)
|
||||
|
||||
|
||||
@@ -324,6 +334,7 @@ def run_metadata(
|
||||
group: str | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
) -> dict[str, Any]:
|
||||
mode = "baseline" if bool(spec.study.no_robust) else "defended"
|
||||
metadata: dict[str, Any] = {
|
||||
"run.kind": str(kind),
|
||||
"run.algo": spec.algorithm.name,
|
||||
@@ -332,6 +343,10 @@ def run_metadata(
|
||||
"run.scenario": str(scenario),
|
||||
"run.seed": spec.runtime.seed,
|
||||
"run.tags": list(tags),
|
||||
"study/alpha": float(spec.study.alpha),
|
||||
"study/mode": mode,
|
||||
"study/baseline_mode": float(bool(spec.study.no_robust)),
|
||||
"tiers": spec.algorithm.name,
|
||||
}
|
||||
if group:
|
||||
metadata["run.group"] = group
|
||||
|
||||
Reference in New Issue
Block a user