cleaning up jax bs

This commit is contained in:
2026-03-08 19:15:58 +01:00
parent 73246d7dd8
commit 4c658a93a7
27 changed files with 173 additions and 3146 deletions

View File

@@ -7,7 +7,9 @@ import numpy as np
from .common import evaluate, make_env
def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]]:
def train_qtable(
cfg: Mapping[str, Any],
) -> tuple[object, dict[str, Any]]:
from ..lib.discrete import EventQTable
np.random.seed(int(cfg["seed"]))
@@ -26,8 +28,19 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]
total_revenue = 0.0
steps = 0
epsilon = float(cfg["eps_start"])
log_freq = max(1, int(cfg.get("log_freq", 100)))
obs, _ = env.reset(seed=int(cfg["seed"]))
interval_sums = {
"reward": 0.0,
"revenue": 0.0,
"agent_prob": 0.0,
"alpha_adv": 0.0,
"coi_leakage": 0.0,
}
interval_count = 0
train_events: list[dict[str, float | int]] = []
for _ in range(int(cfg["total_timesteps"])):
action, state = agent.act(obs, epsilon)
nxt, reward, term, trunc, info = env.step(action)
@@ -35,18 +48,57 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]
agent.update(state, action, float(reward), agent.encode(nxt), done)
total_reward += float(reward)
total_revenue += float(info.get("economics", {}).get("revenue", 0.0))
revenue = float(info.get("economics", {}).get("revenue", 0.0))
total_revenue += revenue
steps += 1
interval_sums["reward"] += float(reward)
interval_sums["revenue"] += revenue
interval_sums["agent_prob"] += float(info.get("agent_prob", 0.0))
interval_sums["alpha_adv"] += float(info.get("alpha_adv", 0.0))
interval_sums["coi_leakage"] += float(info.get("coi_leakage", 0.0))
interval_count += 1
if steps % log_freq == 0 and interval_count > 0:
denom = float(interval_count)
train_events.append(
{
"train/reward_mean": interval_sums["reward"] / denom,
"train/revenue_mean": interval_sums["revenue"] / denom,
"train/agent_prob": interval_sums["agent_prob"] / denom,
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
"train/epsilon": float(epsilon),
"train/global_step": int(steps),
}
)
interval_sums = {key: 0.0 for key in interval_sums}
interval_count = 0
epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"]))
obs = env.reset()[0] if done else nxt
metrics: dict[str, float | int] = {
if interval_count > 0:
denom = float(interval_count)
train_events.append(
{
"train/reward_mean": interval_sums["reward"] / denom,
"train/revenue_mean": interval_sums["revenue"] / denom,
"train/agent_prob": interval_sums["agent_prob"] / denom,
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
"train/epsilon": float(epsilon),
"train/global_step": int(steps),
}
)
metrics: dict[str, Any] = {
"train/reward_mean": total_reward / max(steps, 1),
"train/revenue_mean": total_revenue / max(steps, 1),
"train/epsilon": float(epsilon),
"train/global_step": int(cfg["total_timesteps"]),
}
metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"])))
metrics["_train_events"] = train_events
env.close()
eval_env.close()