refactoring training spc setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

View File

@@ -38,19 +38,19 @@ class MetricsCallback(BaseCallback):
t = self.num_timesteps
payload = {
"economics/revenue": econ["revenue"],
"economics/margin": econ["margin"],
"coi/level": econ["coi_level"],
"economics/regret": econ["regret"],
"train/revenue_step": econ["revenue"],
"train/margin_step": econ["margin"],
"train/coi_level": econ["coi_level"],
"train/regret_step": econ["regret"],
}
if "coi_mix" in econ:
payload["coi/mix"] = econ["coi_mix"]
payload["train/coi_mix"] = econ["coi_mix"]
if "coi_base" in econ:
payload["coi/base"] = econ["coi_base"]
payload["train/coi_base"] = econ["coi_base"]
if "coi_leakage" in econ:
payload["coi/leakage"] = econ["coi_leakage"]
payload["train/coi_leakage"] = econ["coi_leakage"]
if "coi_penalty" in econ:
payload["coi/penalty"] = econ["coi_penalty"]
payload["train/coi_penalty"] = econ["coi_penalty"]
wandb.log(payload, step=t)
self._episode_revenues.append(econ["revenue"])
@@ -76,8 +76,8 @@ class MetricsCallback(BaseCallback):
return
wandb.log(
{
"episode/mean_revenue": np.mean(self._episode_revenues),
"episode/total_revenue": np.sum(self._episode_revenues),
"train/revenue_rollout_mean": np.mean(self._episode_revenues),
"train/revenue_rollout_total": np.sum(self._episode_revenues),
},
step=self.num_timesteps,
)
@@ -164,8 +164,8 @@ class EvalMetricsCallback(EvalCallback):
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
wandb.log(
{
"eval/mean_reward": self.last_mean_reward,
"eval/mean_revenue": np.mean(self._eval_revenues)
"eval/reward_mean": self.last_mean_reward,
"eval/revenue_mean": np.mean(self._eval_revenues)
if self._eval_revenues
else 0,
},