mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
cleaning up jax bs
This commit is contained in:
@@ -1 +1 @@
|
||||
__all__ = ["evaluate", "make_env", "train_jax_backend", "train_qtable", "train_sb3"]
|
||||
__all__ = ["evaluate", "make_env", "train_qtable", "train_sb3"]
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Mapping
|
||||
|
||||
from ..jax import JAX_AVAILABLE
|
||||
|
||||
|
||||
def train_jax_backend(
|
||||
cfg: Mapping[str, Any],
|
||||
) -> tuple[dict[str, Any], dict[str, float | int | str]]:
|
||||
if not JAX_AVAILABLE:
|
||||
raise ImportError(
|
||||
"JAX backend requested but JAX is not installed. "
|
||||
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
||||
)
|
||||
from ..jax.train import train_jax
|
||||
|
||||
return train_jax(dict(cfg))
|
||||
@@ -7,7 +7,9 @@ import numpy as np
|
||||
from .common import evaluate, make_env
|
||||
|
||||
|
||||
def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]]:
|
||||
def train_qtable(
|
||||
cfg: Mapping[str, Any],
|
||||
) -> tuple[object, dict[str, Any]]:
|
||||
from ..lib.discrete import EventQTable
|
||||
|
||||
np.random.seed(int(cfg["seed"]))
|
||||
@@ -26,8 +28,19 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]
|
||||
total_revenue = 0.0
|
||||
steps = 0
|
||||
epsilon = float(cfg["eps_start"])
|
||||
log_freq = max(1, int(cfg.get("log_freq", 100)))
|
||||
obs, _ = env.reset(seed=int(cfg["seed"]))
|
||||
|
||||
interval_sums = {
|
||||
"reward": 0.0,
|
||||
"revenue": 0.0,
|
||||
"agent_prob": 0.0,
|
||||
"alpha_adv": 0.0,
|
||||
"coi_leakage": 0.0,
|
||||
}
|
||||
interval_count = 0
|
||||
train_events: list[dict[str, float | int]] = []
|
||||
|
||||
for _ in range(int(cfg["total_timesteps"])):
|
||||
action, state = agent.act(obs, epsilon)
|
||||
nxt, reward, term, trunc, info = env.step(action)
|
||||
@@ -35,18 +48,57 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]
|
||||
agent.update(state, action, float(reward), agent.encode(nxt), done)
|
||||
|
||||
total_reward += float(reward)
|
||||
total_revenue += float(info.get("economics", {}).get("revenue", 0.0))
|
||||
revenue = float(info.get("economics", {}).get("revenue", 0.0))
|
||||
total_revenue += revenue
|
||||
steps += 1
|
||||
interval_sums["reward"] += float(reward)
|
||||
interval_sums["revenue"] += revenue
|
||||
interval_sums["agent_prob"] += float(info.get("agent_prob", 0.0))
|
||||
interval_sums["alpha_adv"] += float(info.get("alpha_adv", 0.0))
|
||||
interval_sums["coi_leakage"] += float(info.get("coi_leakage", 0.0))
|
||||
interval_count += 1
|
||||
|
||||
if steps % log_freq == 0 and interval_count > 0:
|
||||
denom = float(interval_count)
|
||||
train_events.append(
|
||||
{
|
||||
"train/reward_mean": interval_sums["reward"] / denom,
|
||||
"train/revenue_mean": interval_sums["revenue"] / denom,
|
||||
"train/agent_prob": interval_sums["agent_prob"] / denom,
|
||||
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
|
||||
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
|
||||
"train/epsilon": float(epsilon),
|
||||
"train/global_step": int(steps),
|
||||
}
|
||||
)
|
||||
interval_sums = {key: 0.0 for key in interval_sums}
|
||||
interval_count = 0
|
||||
|
||||
epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"]))
|
||||
obs = env.reset()[0] if done else nxt
|
||||
|
||||
metrics: dict[str, float | int] = {
|
||||
if interval_count > 0:
|
||||
denom = float(interval_count)
|
||||
train_events.append(
|
||||
{
|
||||
"train/reward_mean": interval_sums["reward"] / denom,
|
||||
"train/revenue_mean": interval_sums["revenue"] / denom,
|
||||
"train/agent_prob": interval_sums["agent_prob"] / denom,
|
||||
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
|
||||
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
|
||||
"train/epsilon": float(epsilon),
|
||||
"train/global_step": int(steps),
|
||||
}
|
||||
)
|
||||
|
||||
metrics: dict[str, Any] = {
|
||||
"train/reward_mean": total_reward / max(steps, 1),
|
||||
"train/revenue_mean": total_revenue / max(steps, 1),
|
||||
"train/epsilon": float(epsilon),
|
||||
"train/global_step": int(cfg["total_timesteps"]),
|
||||
}
|
||||
metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"])))
|
||||
metrics["_train_events"] = train_events
|
||||
|
||||
env.close()
|
||||
eval_env.close()
|
||||
|
||||
@@ -4,9 +4,7 @@ import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Mapping
|
||||
|
||||
from ..lib.callbacks import CheckpointArtifactCallback, MetricsCallback
|
||||
from ..telemetry.wandb import get_wandb_module
|
||||
from ..wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
||||
from ..lib.callbacks import MetricsCallback
|
||||
from .common import evaluate, make_env
|
||||
|
||||
|
||||
@@ -52,21 +50,6 @@ def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]:
|
||||
return kwargs
|
||||
|
||||
|
||||
def _sb3_model_cls(algo: str):
|
||||
try:
|
||||
from stable_baselines3 import A2C, DQN, PPO
|
||||
except ImportError as exc:
|
||||
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
|
||||
|
||||
if algo == "ppo":
|
||||
return PPO
|
||||
if algo == "a2c":
|
||||
return A2C
|
||||
if algo == "dqn":
|
||||
return DQN
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def build_model(cfg: Mapping[str, Any], env: Any):
|
||||
try:
|
||||
from stable_baselines3 import A2C, DQN, PPO
|
||||
@@ -132,29 +115,7 @@ def build_model(cfg: Mapping[str, Any], env: Any):
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def _maybe_resume_model(cfg: Mapping[str, Any], env: Any, model: Any):
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return model
|
||||
|
||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
|
||||
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
|
||||
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
|
||||
if restored is None:
|
||||
return model
|
||||
|
||||
checkpoint_path, metadata = restored
|
||||
resumed = _sb3_model_cls(str(cfg["algo"]).lower()).load(
|
||||
checkpoint_path.as_posix(),
|
||||
env=env,
|
||||
)
|
||||
resume_step = int(metadata.get("step", getattr(resumed, "num_timesteps", 0)))
|
||||
resumed.num_timesteps = max(int(getattr(resumed, "num_timesteps", 0)), resume_step)
|
||||
return resumed
|
||||
|
||||
|
||||
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | str]]:
|
||||
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||
try:
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
@@ -182,15 +143,10 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
model = _maybe_resume_model(cfg, env, model)
|
||||
|
||||
callbacks = [MetricsCallback(log_histograms=False, log_freq=int(cfg["log_freq"]))]
|
||||
callbacks.append(
|
||||
CheckpointArtifactCallback(
|
||||
dict(cfg),
|
||||
interval=int(cfg.get("checkpoint_interval", 10_000)),
|
||||
)
|
||||
metrics_callback = MetricsCallback(
|
||||
log_histograms=False, log_freq=int(cfg["log_freq"])
|
||||
)
|
||||
callbacks = [metrics_callback]
|
||||
callbacks.append(
|
||||
EvalCallback(
|
||||
eval_env,
|
||||
@@ -215,13 +171,14 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s
|
||||
model_path = model_dir / f"phantom_{cfg['algo']}"
|
||||
model.save(str(model_path))
|
||||
|
||||
metrics: dict[str, float | int | str] = evaluate(
|
||||
metrics: dict[str, Any] = evaluate(
|
||||
model,
|
||||
eval_env,
|
||||
int(cfg["eval_episodes"]),
|
||||
)
|
||||
metrics["train/global_step"] = int(model.num_timesteps)
|
||||
metrics["model/path"] = str(model_path.with_suffix(".zip"))
|
||||
metrics["_train_events"] = list(metrics_callback.events)
|
||||
|
||||
env.close()
|
||||
eval_env.close()
|
||||
|
||||
Reference in New Issue
Block a user