fix: supra reward adjustment and sweep

This commit is contained in:
2026-03-16 15:58:05 +01:00
parent 43b952cf2b
commit 3439775fbd
4 changed files with 103 additions and 8 deletions

View File

@@ -54,6 +54,9 @@ def _evaluate_env(agent: Any, env: Any, episodes: int) -> dict[str, float]:
coi_levels: list[float] = []
coi_leakages: list[float] = []
volatilities: list[float] = []
upward_volatilities: list[float] = []
supra_shares: list[float] = []
supra_penalties: list[float] = []
agent_probs: list[float] = []
for _ in range(int(episodes)):
@@ -65,6 +68,9 @@ def _evaluate_env(agent: Any, env: Any, episodes: int) -> dict[str, float]:
ep_coi = 0.0
ep_coi_leakage = 0.0
ep_volatility = 0.0
ep_upward_volatility = 0.0
ep_supra_share = 0.0
ep_supra_penalty = 0.0
ep_agent_prob = 0.0
steps = 0
@@ -78,6 +84,15 @@ def _evaluate_env(agent: Any, env: Any, episodes: int) -> dict[str, float]:
ep_coi += float(econ.get("coi_level", 0.0))
ep_coi_leakage += float(econ.get("coi_leakage", 0.0))
ep_volatility += float(econ.get("volatility", 0.0))
ep_upward_volatility += float(
info.get("upward_volatility", econ.get("upward_volatility", 0.0))
)
ep_supra_share += float(
info.get("supra_share", econ.get("supra_share", 0.0))
)
ep_supra_penalty += float(
info.get("supra_penalty", econ.get("supra_penalty", 0.0))
)
ep_agent_prob += float(econ.get("agent_prob", info.get("agent_prob", 0.0)))
steps += 1
@@ -88,6 +103,9 @@ def _evaluate_env(agent: Any, env: Any, episodes: int) -> dict[str, float]:
coi_levels.append(ep_coi / denom)
coi_leakages.append(ep_coi_leakage / denom)
volatilities.append(ep_volatility / denom)
upward_volatilities.append(ep_upward_volatility / denom)
supra_shares.append(ep_supra_share / denom)
supra_penalties.append(ep_supra_penalty / denom)
agent_probs.append(ep_agent_prob / denom)
return {
@@ -99,6 +117,13 @@ def _evaluate_env(agent: Any, env: Any, episodes: int) -> dict[str, float]:
"eval/coi_level_mean": float(np.mean(coi_levels)) if coi_levels else 0.0,
"eval/coi_leakage_mean": float(np.mean(coi_leakages)) if coi_leakages else 0.0,
"eval/volatility_mean": float(np.mean(volatilities)) if volatilities else 0.0,
"eval/upward_volatility_mean": (
float(np.mean(upward_volatilities)) if upward_volatilities else 0.0
),
"eval/supra_share_mean": float(np.mean(supra_shares)) if supra_shares else 0.0,
"eval/supra_penalty_mean": (
float(np.mean(supra_penalties)) if supra_penalties else 0.0
),
"eval/agent_prob_mean": float(np.mean(agent_probs)) if agent_probs else 0.0,
}