mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactoring training spc setup and benchmarking
This commit is contained in:
@@ -38,19 +38,19 @@ class MetricsCallback(BaseCallback):
|
||||
t = self.num_timesteps
|
||||
|
||||
payload = {
|
||||
"economics/revenue": econ["revenue"],
|
||||
"economics/margin": econ["margin"],
|
||||
"coi/level": econ["coi_level"],
|
||||
"economics/regret": econ["regret"],
|
||||
"train/revenue_step": econ["revenue"],
|
||||
"train/margin_step": econ["margin"],
|
||||
"train/coi_level": econ["coi_level"],
|
||||
"train/regret_step": econ["regret"],
|
||||
}
|
||||
if "coi_mix" in econ:
|
||||
payload["coi/mix"] = econ["coi_mix"]
|
||||
payload["train/coi_mix"] = econ["coi_mix"]
|
||||
if "coi_base" in econ:
|
||||
payload["coi/base"] = econ["coi_base"]
|
||||
payload["train/coi_base"] = econ["coi_base"]
|
||||
if "coi_leakage" in econ:
|
||||
payload["coi/leakage"] = econ["coi_leakage"]
|
||||
payload["train/coi_leakage"] = econ["coi_leakage"]
|
||||
if "coi_penalty" in econ:
|
||||
payload["coi/penalty"] = econ["coi_penalty"]
|
||||
payload["train/coi_penalty"] = econ["coi_penalty"]
|
||||
wandb.log(payload, step=t)
|
||||
|
||||
self._episode_revenues.append(econ["revenue"])
|
||||
@@ -76,8 +76,8 @@ class MetricsCallback(BaseCallback):
|
||||
return
|
||||
wandb.log(
|
||||
{
|
||||
"episode/mean_revenue": np.mean(self._episode_revenues),
|
||||
"episode/total_revenue": np.sum(self._episode_revenues),
|
||||
"train/revenue_rollout_mean": np.mean(self._episode_revenues),
|
||||
"train/revenue_rollout_total": np.sum(self._episode_revenues),
|
||||
},
|
||||
step=self.num_timesteps,
|
||||
)
|
||||
@@ -164,8 +164,8 @@ class EvalMetricsCallback(EvalCallback):
|
||||
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
|
||||
wandb.log(
|
||||
{
|
||||
"eval/mean_reward": self.last_mean_reward,
|
||||
"eval/mean_revenue": np.mean(self._eval_revenues)
|
||||
"eval/reward_mean": self.last_mean_reward,
|
||||
"eval/revenue_mean": np.mean(self._eval_revenues)
|
||||
if self._eval_revenues
|
||||
else 0,
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user