refactoring training spc setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

View File

@@ -624,8 +624,8 @@ def evaluate_policy(
revenues.append(ep_revenue)
return {
"eval/reward": float(np.mean(rewards)),
"eval/revenue": float(np.mean(revenues)),
"eval/reward_mean": float(np.mean(rewards)),
"eval/revenue_mean": float(np.mean(revenues)),
"eval/reward_std": float(np.std(rewards)),
"eval/revenue_std": float(np.std(revenues)),
}
@@ -665,8 +665,8 @@ def _evaluate_q_network(
revenues.append(ep_revenue)
return {
"eval/reward": float(np.mean(rewards)),
"eval/revenue": float(np.mean(revenues)),
"eval/reward_mean": float(np.mean(rewards)),
"eval/revenue_mean": float(np.mean(revenues)),
"eval/reward_std": float(np.std(rewards)),
"eval/revenue_std": float(np.std(revenues)),
}
@@ -713,8 +713,8 @@ def _evaluate_q_table(
revenues.append(ep_revenue)
return {
"eval/reward": float(np.mean(rewards)),
"eval/revenue": float(np.mean(revenues)),
"eval/reward_mean": float(np.mean(rewards)),
"eval/revenue_mean": float(np.mean(revenues)),
"eval/reward_std": float(np.std(rewards)),
"eval/revenue_std": float(np.std(revenues)),
}
@@ -831,8 +831,8 @@ def _train_actor_critic(
if is_primary and HAS_WANDB and wandb.run is not None:
wandb.log(
{
"train/reward": float(segment_values["reward"].mean()),
"train/revenue": float(segment_values["revenue"].mean()),
"train/reward_mean": float(segment_values["reward"].mean()),
"train/revenue_mean": float(segment_values["revenue"].mean()),
"train/agent_prob": float(segment_values["agent_prob"].mean()),
"train/alpha_adv": float(segment_values["alpha_adv"].mean()),
"train/coi_leakage": float(segment_values["coi_leakage"].mean()),
@@ -873,8 +873,8 @@ def _train_actor_critic(
train_state = final_runner[0]
denom = float(metric_count) if metric_count > 0 else 1.0
metrics = {
"train/reward": float(metric_sums["reward"] / denom),
"train/revenue": float(metric_sums["revenue"] / denom),
"train/reward_mean": float(metric_sums["reward"] / denom),
"train/revenue_mean": float(metric_sums["revenue"] / denom),
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
@@ -1052,14 +1052,14 @@ def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
):
wandb.log(
{
"train/reward": metric_sums["reward"] / max(metric_count, 1),
"train/revenue": metric_sums["revenue"] / max(metric_count, 1),
"train/reward_mean": metric_sums["reward"] / max(metric_count, 1),
"train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1),
"train/agent_prob": metric_sums["agent_prob"]
/ max(metric_count, 1),
"train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1),
"train/coi_leakage": metric_sums["coi_leakage"]
/ max(metric_count, 1),
"train/dqn_loss": metric_sums["loss"] / max(loss_count, 1),
"train/loss": metric_sums["loss"] / max(loss_count, 1),
"train/epsilon": epsilon_value,
"train/global_step": global_step,
},
@@ -1090,12 +1090,12 @@ def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
denom = float(metric_count) if metric_count > 0 else 1.0
metrics = {
"train/reward": float(metric_sums["reward"] / denom),
"train/revenue": float(metric_sums["revenue"] / denom),
"train/reward_mean": float(metric_sums["reward"] / denom),
"train/revenue_mean": float(metric_sums["revenue"] / denom),
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
"train/dqn_loss": float(metric_sums["loss"] / max(loss_count, 1)),
"train/loss": float(metric_sums["loss"] / max(loss_count, 1)),
"train/global_step": total_steps,
}
@@ -1236,8 +1236,8 @@ def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]
):
wandb.log(
{
"train/reward": metric_sums["reward"] / max(metric_count, 1),
"train/revenue": metric_sums["revenue"] / max(metric_count, 1),
"train/reward_mean": metric_sums["reward"] / max(metric_count, 1),
"train/revenue_mean": metric_sums["revenue"] / max(metric_count, 1),
"train/agent_prob": metric_sums["agent_prob"]
/ max(metric_count, 1),
"train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1),
@@ -1269,8 +1269,8 @@ def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]
denom = float(metric_count) if metric_count > 0 else 1.0
metrics = {
"train/reward": float(metric_sums["reward"] / denom),
"train/revenue": float(metric_sums["revenue"] / denom),
"train/reward_mean": float(metric_sums["reward"] / denom),
"train/revenue_mean": float(metric_sums["revenue"] / denom),
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),