mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
fixing models for gcp
This commit is contained in:
@@ -7,6 +7,19 @@ from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
HAS_WANDB = True
|
||||
except ImportError:
|
||||
HAS_WANDB = False
|
||||
|
||||
from ..wandb_checkpoint import (
|
||||
checkpoint_artifact_name,
|
||||
download_latest_checkpoint,
|
||||
log_checkpoint_bytes,
|
||||
)
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -142,6 +155,7 @@ def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
|
||||
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
|
||||
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
|
||||
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
|
||||
"checkpoint_interval": int(cfg.get("checkpoint_interval", 10_000)),
|
||||
}
|
||||
rollout = out["num_envs"] * out["num_steps"]
|
||||
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
|
||||
@@ -185,201 +199,198 @@ def make_train(config: dict[str, Any]):
|
||||
frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
|
||||
return cfg["learning_rate"] * frac
|
||||
|
||||
def train(rng: jax.Array):
|
||||
if cfg["anneal_lr"]:
|
||||
tx = optax.chain(
|
||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
||||
)
|
||||
else:
|
||||
tx = optax.chain(
|
||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||
optax.adam(cfg["learning_rate"], eps=1e-5),
|
||||
)
|
||||
|
||||
def init_runner_state(rng: jax.Array):
|
||||
rng, init_key = jax.random.split(rng)
|
||||
init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
|
||||
params = network.init(init_key, init_obs)
|
||||
|
||||
if cfg["anneal_lr"]:
|
||||
tx = optax.chain(
|
||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
||||
)
|
||||
else:
|
||||
tx = optax.chain(
|
||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||
optax.adam(cfg["learning_rate"], eps=1e-5),
|
||||
)
|
||||
train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)
|
||||
|
||||
rng, reset_key = jax.random.split(rng)
|
||||
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
|
||||
obs, env_state = jax.vmap(env.reset)(reset_keys)
|
||||
return train_state, env_state, obs, rng
|
||||
|
||||
def _update_step(runner_state, _):
|
||||
def _env_step(runner_state, _):
|
||||
train_state, env_state, last_obs, rng = runner_state
|
||||
rng, action_key = jax.random.split(rng)
|
||||
policy, value = network.apply(train_state.params, last_obs)
|
||||
action = policy.sample(seed=action_key)
|
||||
log_prob = policy.log_prob(action)
|
||||
|
||||
rng, step_key = jax.random.split(rng)
|
||||
step_keys = jax.random.split(step_key, cfg["num_envs"])
|
||||
nxt_obs, nxt_state, reward, done, info = jax.vmap(
|
||||
env.step,
|
||||
in_axes=(0, 0, 0),
|
||||
)(step_keys, env_state, action)
|
||||
|
||||
rng, reset_key = jax.random.split(rng)
|
||||
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
|
||||
rst_obs, rst_state = jax.vmap(env.reset)(reset_keys)
|
||||
obs_next = jnp.where(done[:, None], rst_obs, nxt_obs)
|
||||
env_next = jax.tree_util.tree_map(
|
||||
lambda keep, reset: _select_env_state(done, keep, reset),
|
||||
nxt_state,
|
||||
rst_state,
|
||||
)
|
||||
transition = Transition(
|
||||
done=done,
|
||||
action=action,
|
||||
value=value,
|
||||
reward=reward,
|
||||
log_prob=log_prob,
|
||||
obs=last_obs,
|
||||
info=info,
|
||||
)
|
||||
return (train_state, env_next, obs_next, rng), transition
|
||||
|
||||
runner_state, traj_batch = jax.lax.scan(
|
||||
_env_step,
|
||||
runner_state,
|
||||
None,
|
||||
length=cfg["num_steps"],
|
||||
)
|
||||
|
||||
def _update_step(runner_state, _):
|
||||
def _env_step(runner_state, _):
|
||||
train_state, env_state, last_obs, rng = runner_state
|
||||
_, last_value = network.apply(train_state.params, last_obs)
|
||||
rng, action_key = jax.random.split(rng)
|
||||
policy, value = network.apply(train_state.params, last_obs)
|
||||
action = policy.sample(seed=action_key)
|
||||
log_prob = policy.log_prob(action)
|
||||
|
||||
def _compute_gae(traj_batch, last_value):
|
||||
def _gae_step(carry, transition):
|
||||
gae, next_value = carry
|
||||
delta = (
|
||||
transition.reward
|
||||
+ cfg["gamma"] * next_value * (1.0 - transition.done)
|
||||
- transition.value
|
||||
)
|
||||
gae = (
|
||||
delta
|
||||
+ cfg["gamma"]
|
||||
* cfg["gae_lambda"]
|
||||
* (1.0 - transition.done)
|
||||
* gae
|
||||
)
|
||||
return (gae, transition.value), gae
|
||||
rng, step_key = jax.random.split(rng)
|
||||
step_keys = jax.random.split(step_key, cfg["num_envs"])
|
||||
nxt_obs, nxt_state, reward, done, info = jax.vmap(
|
||||
env.step,
|
||||
in_axes=(0, 0, 0),
|
||||
)(step_keys, env_state, action)
|
||||
|
||||
_, advantages = jax.lax.scan(
|
||||
_gae_step,
|
||||
(jnp.zeros_like(last_value), last_value),
|
||||
traj_batch,
|
||||
reverse=True,
|
||||
unroll=16,
|
||||
)
|
||||
targets = advantages + traj_batch.value
|
||||
return advantages, targets
|
||||
|
||||
advantages, targets = _compute_gae(traj_batch, last_value)
|
||||
|
||||
def _update_epoch(update_state, _):
|
||||
def _update_minibatch(train_state, batch_info):
|
||||
traj_b, adv_b, tgt_b = batch_info
|
||||
|
||||
def _loss_fn(params, traj_b, adv_b, tgt_b):
|
||||
policy, value = network.apply(params, traj_b.obs)
|
||||
log_prob = policy.log_prob(traj_b.action)
|
||||
|
||||
value_clipped = traj_b.value + (value - traj_b.value).clip(
|
||||
-cfg["clip_range"], cfg["clip_range"]
|
||||
)
|
||||
value_loss = (
|
||||
0.5
|
||||
* jnp.maximum(
|
||||
jnp.square(value - tgt_b),
|
||||
jnp.square(value_clipped - tgt_b),
|
||||
).mean()
|
||||
)
|
||||
|
||||
adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)
|
||||
ratio = jnp.exp(log_prob - traj_b.log_prob)
|
||||
loss_actor = -jnp.minimum(
|
||||
ratio * adv_norm,
|
||||
jnp.clip(
|
||||
ratio,
|
||||
1.0 - cfg["clip_range"],
|
||||
1.0 + cfg["clip_range"],
|
||||
)
|
||||
* adv_norm,
|
||||
).mean()
|
||||
entropy = policy.entropy().mean()
|
||||
total_loss = (
|
||||
loss_actor
|
||||
+ cfg["vf_coef"] * value_loss
|
||||
- cfg["ent_coef"] * entropy
|
||||
)
|
||||
return total_loss, (value_loss, loss_actor, entropy)
|
||||
|
||||
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
|
||||
(_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b)
|
||||
train_state = train_state.apply_gradients(grads=grads)
|
||||
return train_state, jnp.asarray(0.0, dtype=jnp.float32)
|
||||
|
||||
train_state, traj_batch, advantages, targets, rng = update_state
|
||||
rng, perm_key = jax.random.split(rng)
|
||||
batch_size = cfg["num_envs"] * cfg["num_steps"]
|
||||
permutation = jax.random.permutation(perm_key, batch_size)
|
||||
batch = (traj_batch, advantages, targets)
|
||||
batch = jax.tree_util.tree_map(
|
||||
lambda x: x.reshape((batch_size,) + x.shape[2:]),
|
||||
batch,
|
||||
)
|
||||
shuffled = jax.tree_util.tree_map(
|
||||
lambda x: jnp.take(x, permutation, axis=0),
|
||||
batch,
|
||||
)
|
||||
minibatches = jax.tree_util.tree_map(
|
||||
lambda x: x.reshape(
|
||||
(cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:]
|
||||
),
|
||||
shuffled,
|
||||
)
|
||||
train_state, _ = jax.lax.scan(
|
||||
_update_minibatch, train_state, minibatches
|
||||
)
|
||||
return (train_state, traj_batch, advantages, targets, rng), None
|
||||
|
||||
update_state = (train_state, traj_batch, advantages, targets, rng)
|
||||
update_state, _ = jax.lax.scan(
|
||||
_update_epoch,
|
||||
update_state,
|
||||
None,
|
||||
length=cfg["update_epochs"],
|
||||
rng, reset_key = jax.random.split(rng)
|
||||
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
|
||||
rst_obs, rst_state = jax.vmap(env.reset)(reset_keys)
|
||||
obs_next = jnp.where(done[:, None], rst_obs, nxt_obs)
|
||||
env_next = jax.tree_util.tree_map(
|
||||
lambda keep, reset: _select_env_state(done, keep, reset),
|
||||
nxt_state,
|
||||
rst_state,
|
||||
)
|
||||
train_state = update_state[0]
|
||||
rng = update_state[-1]
|
||||
transition = Transition(
|
||||
done=done,
|
||||
action=action,
|
||||
value=value,
|
||||
reward=reward,
|
||||
log_prob=log_prob,
|
||||
obs=last_obs,
|
||||
info=info,
|
||||
)
|
||||
return (train_state, env_next, obs_next, rng), transition
|
||||
|
||||
metric = {
|
||||
"reward": jnp.mean(traj_batch.reward),
|
||||
"revenue": jnp.mean(traj_batch.info["revenue"]),
|
||||
"agent_prob": jnp.mean(traj_batch.info["agent_prob"]),
|
||||
"alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
|
||||
"coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
|
||||
}
|
||||
runner_state = (train_state, env_state, last_obs, rng)
|
||||
return runner_state, metric
|
||||
runner_state, traj_batch = jax.lax.scan(
|
||||
_env_step,
|
||||
runner_state,
|
||||
None,
|
||||
length=cfg["num_steps"],
|
||||
)
|
||||
|
||||
runner_state = (train_state, env_state, obs, rng)
|
||||
train_state, env_state, last_obs, rng = runner_state
|
||||
_, last_value = network.apply(train_state.params, last_obs)
|
||||
|
||||
def _compute_gae(traj_batch, last_value):
|
||||
def _gae_step(carry, transition):
|
||||
gae, next_value = carry
|
||||
delta = (
|
||||
transition.reward
|
||||
+ cfg["gamma"] * next_value * (1.0 - transition.done)
|
||||
- transition.value
|
||||
)
|
||||
gae = (
|
||||
delta
|
||||
+ cfg["gamma"] * cfg["gae_lambda"] * (1.0 - transition.done) * gae
|
||||
)
|
||||
return (gae, transition.value), gae
|
||||
|
||||
_, advantages = jax.lax.scan(
|
||||
_gae_step,
|
||||
(jnp.zeros_like(last_value), last_value),
|
||||
traj_batch,
|
||||
reverse=True,
|
||||
unroll=16,
|
||||
)
|
||||
targets = advantages + traj_batch.value
|
||||
return advantages, targets
|
||||
|
||||
advantages, targets = _compute_gae(traj_batch, last_value)
|
||||
|
||||
def _update_epoch(update_state, _):
|
||||
def _update_minibatch(train_state, batch_info):
|
||||
traj_b, adv_b, tgt_b = batch_info
|
||||
|
||||
def _loss_fn(params, traj_b, adv_b, tgt_b):
|
||||
policy, value = network.apply(params, traj_b.obs)
|
||||
log_prob = policy.log_prob(traj_b.action)
|
||||
|
||||
value_clipped = traj_b.value + (value - traj_b.value).clip(
|
||||
-cfg["clip_range"], cfg["clip_range"]
|
||||
)
|
||||
value_loss = (
|
||||
0.5
|
||||
* jnp.maximum(
|
||||
jnp.square(value - tgt_b),
|
||||
jnp.square(value_clipped - tgt_b),
|
||||
).mean()
|
||||
)
|
||||
|
||||
adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)
|
||||
ratio = jnp.exp(log_prob - traj_b.log_prob)
|
||||
loss_actor = -jnp.minimum(
|
||||
ratio * adv_norm,
|
||||
jnp.clip(
|
||||
ratio,
|
||||
1.0 - cfg["clip_range"],
|
||||
1.0 + cfg["clip_range"],
|
||||
)
|
||||
* adv_norm,
|
||||
).mean()
|
||||
entropy = policy.entropy().mean()
|
||||
total_loss = (
|
||||
loss_actor
|
||||
+ cfg["vf_coef"] * value_loss
|
||||
- cfg["ent_coef"] * entropy
|
||||
)
|
||||
return total_loss, (value_loss, loss_actor, entropy)
|
||||
|
||||
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
|
||||
(_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b)
|
||||
train_state = train_state.apply_gradients(grads=grads)
|
||||
return train_state, jnp.asarray(0.0, dtype=jnp.float32)
|
||||
|
||||
train_state, traj_batch, advantages, targets, rng = update_state
|
||||
rng, perm_key = jax.random.split(rng)
|
||||
batch_size = cfg["num_envs"] * cfg["num_steps"]
|
||||
permutation = jax.random.permutation(perm_key, batch_size)
|
||||
batch = (traj_batch, advantages, targets)
|
||||
batch = jax.tree_util.tree_map(
|
||||
lambda x: x.reshape((batch_size,) + x.shape[2:]),
|
||||
batch,
|
||||
)
|
||||
shuffled = jax.tree_util.tree_map(
|
||||
lambda x: jnp.take(x, permutation, axis=0),
|
||||
batch,
|
||||
)
|
||||
minibatches = jax.tree_util.tree_map(
|
||||
lambda x: x.reshape(
|
||||
(cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:]
|
||||
),
|
||||
shuffled,
|
||||
)
|
||||
train_state, _ = jax.lax.scan(_update_minibatch, train_state, minibatches)
|
||||
return (train_state, traj_batch, advantages, targets, rng), None
|
||||
|
||||
update_state = (train_state, traj_batch, advantages, targets, rng)
|
||||
update_state, _ = jax.lax.scan(
|
||||
_update_epoch,
|
||||
update_state,
|
||||
None,
|
||||
length=cfg["update_epochs"],
|
||||
)
|
||||
train_state = update_state[0]
|
||||
rng = update_state[-1]
|
||||
|
||||
metric = {
|
||||
"reward": jnp.mean(traj_batch.reward),
|
||||
"revenue": jnp.mean(traj_batch.info["revenue"]),
|
||||
"agent_prob": jnp.mean(traj_batch.info["agent_prob"]),
|
||||
"alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
|
||||
"coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
|
||||
}
|
||||
next_runner_state = (train_state, env_state, last_obs, rng)
|
||||
return next_runner_state, metric
|
||||
|
||||
def run_updates(runner_state, *, num_updates: int):
|
||||
updates = max(1, int(num_updates))
|
||||
runner_state, metric = jax.lax.scan(
|
||||
_update_step,
|
||||
runner_state,
|
||||
None,
|
||||
length=cfg["num_updates"],
|
||||
length=updates,
|
||||
)
|
||||
return {
|
||||
"runner_state": runner_state,
|
||||
"metrics": metric,
|
||||
}
|
||||
|
||||
return train, network, env, cfg
|
||||
return init_runner_state, run_updates, network, env, cfg
|
||||
|
||||
|
||||
def evaluate_policy(
|
||||
@@ -436,22 +447,103 @@ def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
|
||||
f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'"
|
||||
)
|
||||
|
||||
train_fn, network, env, run_cfg = make_train(run_cfg)
|
||||
train_jit = jax.jit(train_fn)
|
||||
rng = jax.random.PRNGKey(run_cfg["seed"])
|
||||
out = train_jit(rng)
|
||||
init_runner_state, run_updates, network, env, run_cfg = make_train(run_cfg)
|
||||
run_updates_jit = jax.jit(run_updates, static_argnames=("num_updates",))
|
||||
rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"])
|
||||
total_updates = int(run_cfg["num_updates"])
|
||||
checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000)))
|
||||
segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1))
|
||||
|
||||
train_state = out["runner_state"][0]
|
||||
metric = out["metrics"]
|
||||
rng = jax.random.PRNGKey(run_cfg["seed"])
|
||||
runner_state = init_runner_state(rng)
|
||||
updates_done = 0
|
||||
|
||||
artifact_name = None
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||
artifact_name = checkpoint_artifact_name(
|
||||
run_cfg,
|
||||
backend="jax",
|
||||
sweep_id=sweep_id,
|
||||
)
|
||||
restored = download_latest_checkpoint(
|
||||
artifact_name,
|
||||
file_name="jax_runner_state.msgpack",
|
||||
)
|
||||
if restored is not None:
|
||||
checkpoint_path, metadata = restored
|
||||
template = {
|
||||
"runner_state": runner_state,
|
||||
"updates_done": 0,
|
||||
}
|
||||
payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
|
||||
runner_state = payload["runner_state"]
|
||||
updates_done = int(payload.get("updates_done", 0))
|
||||
if updates_done <= 0:
|
||||
updates_done = int(metadata.get("updates_done", 0))
|
||||
updates_done = max(0, min(updates_done, total_updates))
|
||||
|
||||
metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"]
|
||||
metric_sums = {k: 0.0 for k in metric_keys}
|
||||
metric_count = 0
|
||||
|
||||
while updates_done < total_updates:
|
||||
updates_this_segment = min(segment_updates, total_updates - updates_done)
|
||||
out = run_updates_jit(runner_state, num_updates=updates_this_segment)
|
||||
runner_state = out["runner_state"]
|
||||
metric = out["metrics"]
|
||||
|
||||
segment_values = {
|
||||
k: np.asarray(metric[k], dtype=np.float64) for k in metric_keys
|
||||
}
|
||||
segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0
|
||||
metric_count += segment_count
|
||||
for key in metric_keys:
|
||||
metric_sums[key] += float(segment_values[key].sum())
|
||||
|
||||
updates_done += int(updates_this_segment)
|
||||
global_step = int(updates_done * rollout_steps)
|
||||
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
wandb.log(
|
||||
{
|
||||
"train/reward": float(segment_values["reward"].mean()),
|
||||
"train/revenue": float(segment_values["revenue"].mean()),
|
||||
"train/agent_prob": float(segment_values["agent_prob"].mean()),
|
||||
"train/alpha_adv": float(segment_values["alpha_adv"].mean()),
|
||||
"train/coi_leakage": float(segment_values["coi_leakage"].mean()),
|
||||
"train/global_step": global_step,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
if artifact_name is not None:
|
||||
checkpoint_payload = serialization.to_bytes(
|
||||
{
|
||||
"runner_state": runner_state,
|
||||
"updates_done": updates_done,
|
||||
}
|
||||
)
|
||||
log_checkpoint_bytes(
|
||||
artifact_name,
|
||||
file_name="jax_runner_state.msgpack",
|
||||
payload=checkpoint_payload,
|
||||
metadata={
|
||||
"step": global_step,
|
||||
"updates_done": updates_done,
|
||||
"rollout_steps": rollout_steps,
|
||||
"algo": "ppo",
|
||||
},
|
||||
)
|
||||
|
||||
train_state = runner_state[0]
|
||||
denom = float(metric_count) if metric_count > 0 else 1.0
|
||||
metrics = {
|
||||
"train/reward": float(np.mean(np.asarray(metric["reward"]))),
|
||||
"train/revenue": float(np.mean(np.asarray(metric["revenue"]))),
|
||||
"train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))),
|
||||
"train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))),
|
||||
"train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))),
|
||||
"train/global_step": int(
|
||||
run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"]
|
||||
),
|
||||
"train/reward": float(metric_sums["reward"] / denom),
|
||||
"train/revenue": float(metric_sums["revenue"] / denom),
|
||||
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
|
||||
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
|
||||
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
|
||||
"train/global_step": int(updates_done * rollout_steps),
|
||||
}
|
||||
|
||||
eval_metrics = evaluate_policy(
|
||||
|
||||
@@ -2,7 +2,7 @@ from .demand import estimate_demand, estimate_weighted_demand, generate_demand_f
|
||||
from .behavior import sample_behavior, get_transition_models, trajectory_to_events
|
||||
from .render import DashboardRenderer, style_axis
|
||||
from .wrappers import EconomicMetricsWrapper
|
||||
from .callbacks import MetricsCallback, EvalMetricsCallback
|
||||
from .callbacks import MetricsCallback, EvalMetricsCallback, CheckpointArtifactCallback
|
||||
from .providers import (
|
||||
ProviderBenchmark,
|
||||
ProviderResult,
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
"""Training callbacks for W&B/TensorBoard logging - reads from info dict."""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
||||
import numpy as np
|
||||
|
||||
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
@@ -80,6 +84,65 @@ class MetricsCallback(BaseCallback):
|
||||
self._episode_revenues = []
|
||||
|
||||
|
||||
class CheckpointArtifactCallback(BaseCallback):
    """Save the model every ``interval`` timesteps and log it as a W&B artifact.

    Nothing is saved or uploaded unless wandb is importable and a run is
    active. A final checkpoint is attempted when training ends.
    """

    def __init__(self, cfg: dict, interval: int = 10_000, verbose: int = 0):
        """Copy the config, clamp the interval to >= 1 and create the model dir.

        Args:
            cfg: Training config; the keys read here and in the helpers are
                ``model_dir`` and ``algo``.
            interval: Timesteps between checkpoints; values below 1 become 1.
            verbose: Passed straight to ``BaseCallback``.
        """
        super().__init__(verbose)
        self.cfg = dict(cfg)
        step_gap = int(interval)
        self.interval = step_gap if step_gap > 1 else 1
        self.model_dir = Path(str(self.cfg.get("model_dir", "engine/models")))
        self.model_dir.mkdir(parents=True, exist_ok=True)
        # -1 means "nothing uploaded yet"; the first due timestep is one interval in.
        self._last_saved_step = -1
        self._next_checkpoint = self.interval

    def _artifact_name(self) -> str:
        """Return the artifact name for this config, keyed by sweep id if any."""
        sweep_id = None
        if HAS_WANDB and wandb.run is not None:
            sweep_id = getattr(wandb.run, "sweep_id", None)
        return checkpoint_artifact_name(self.cfg, backend="sb3", sweep_id=sweep_id)

    def _checkpoint_file(self) -> Path:
        """Save the current model under ``model_dir`` and return its ``.zip`` path.

        NOTE(review): assumes ``model.save`` appends ``.zip`` to the given
        stem - confirm against the SB3 version in use.
        """
        algo_name = str(self.cfg.get("algo", "model"))
        stem = self.model_dir / f"phantom_{algo_name}_checkpoint"
        self.model.save(str(stem))
        return stem.with_suffix(".zip")

    def _save_checkpoint(self) -> None:
        """Save and upload a checkpoint unless this timestep was already uploaded."""
        wandb_active = HAS_WANDB and wandb.run is not None
        if not wandb_active:
            return
        current_step = int(self.num_timesteps)
        if current_step <= self._last_saved_step:
            return
        saved_path = self._checkpoint_file()
        details = {
            "step": current_step,
            "algo": str(self.cfg.get("algo", "unknown")),
            "sweep_id": getattr(wandb.run, "sweep_id", None),
        }
        uploaded = log_checkpoint_file(
            self._artifact_name(),
            file_path=saved_path,
            artifact_file_name=saved_path.name,
            metadata=details,
        )
        # Only a truthy result marks the step as saved, so a falsy one is
        # retried on the next trigger.
        if uploaded:
            self._last_saved_step = current_step

    def _on_step(self) -> bool:
        """Checkpoint once the timestep counter reaches the next threshold.

        Always returns True, so this callback never stops training.
        """
        if self.num_timesteps >= self._next_checkpoint:
            self._save_checkpoint()
            # Move the threshold past the current timestep so a jump of several
            # intervals produces a single checkpoint.
            while self._next_checkpoint <= self.num_timesteps:
                self._next_checkpoint += self.interval
        return True

    def _on_training_end(self) -> None:
        """Attempt one last checkpoint when training finishes."""
        self._save_checkpoint()
|
||||
|
||||
|
||||
class EvalMetricsCallback(EvalCallback):
|
||||
"""Deterministic evaluation - true performance without exploration noise."""
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ import os
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
|
||||
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
@@ -78,6 +80,7 @@ DEFAULT_CFG = {
|
||||
"jax_num_minibatches": 4,
|
||||
"jax_update_epochs": 4,
|
||||
"jax_anneal_lr": True,
|
||||
"checkpoint_interval": 10_000,
|
||||
}
|
||||
|
||||
|
||||
@@ -262,6 +265,16 @@ def build_model(cfg: dict, env):
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def _sb3_model_cls(algo: str):
|
||||
if algo == "ppo":
|
||||
return PPO
|
||||
if algo == "a2c":
|
||||
return A2C
|
||||
if algo == "dqn":
|
||||
return DQN
|
||||
raise ValueError(f"unsupported algo '{algo}'")
|
||||
|
||||
|
||||
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
||||
from .lib.discrete import EventQTable
|
||||
|
||||
@@ -305,14 +318,36 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
||||
def train_sb3(cfg: dict) -> tuple[object, dict]:
|
||||
if not HAS_SB3:
|
||||
raise ImportError("stable-baselines3 is required for SB3 models")
|
||||
from .lib.callbacks import MetricsCallback
|
||||
from .lib.callbacks import CheckpointArtifactCallback, MetricsCallback
|
||||
|
||||
env = make_env(cfg)
|
||||
eval_env = make_env(cfg)
|
||||
env = Monitor(env)
|
||||
eval_env = Monitor(eval_env)
|
||||
model = build_model(cfg, env)
|
||||
resume_step = 0
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
|
||||
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
|
||||
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
|
||||
if restored is not None:
|
||||
checkpoint_path, metadata = restored
|
||||
model = _sb3_model_cls(cfg["algo"]).load(
|
||||
checkpoint_path.as_posix(), env=env
|
||||
)
|
||||
resume_step = int(metadata.get("step", getattr(model, "num_timesteps", 0)))
|
||||
model.num_timesteps = max(
|
||||
int(getattr(model, "num_timesteps", 0)), resume_step
|
||||
)
|
||||
|
||||
cbs = [MetricsCallback(log_histograms=True, log_freq=int(cfg["log_freq"]))]
|
||||
cbs.append(
|
||||
CheckpointArtifactCallback(
|
||||
cfg,
|
||||
interval=int(cfg.get("checkpoint_interval", 10_000)),
|
||||
)
|
||||
)
|
||||
cbs.append(
|
||||
EvalCallback(
|
||||
eval_env,
|
||||
@@ -322,7 +357,15 @@ def train_sb3(cfg: dict) -> tuple[object, dict]:
|
||||
verbose=0,
|
||||
)
|
||||
)
|
||||
model.learn(total_timesteps=int(cfg["total_timesteps"]), callback=cbs)
|
||||
target_steps = int(cfg["total_timesteps"])
|
||||
remaining_steps = max(0, target_steps - int(getattr(model, "num_timesteps", 0)))
|
||||
if remaining_steps > 0:
|
||||
model.learn(
|
||||
total_timesteps=remaining_steps,
|
||||
callback=cbs,
|
||||
reset_num_timesteps=False,
|
||||
)
|
||||
|
||||
model_path = Path(cfg["model_dir"])
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
model.save(str(model_path / f"phantom_{cfg['algo']}"))
|
||||
@@ -413,6 +456,7 @@ def main():
|
||||
p.add_argument("--jax-num-minibatches", type=int)
|
||||
p.add_argument("--jax-update-epochs", type=int)
|
||||
p.add_argument("--jax-anneal-lr", type=str)
|
||||
p.add_argument("--checkpoint-interval", type=int)
|
||||
p.add_argument("--sweep-agent", action="store_true")
|
||||
p.add_argument("--sweep-id", type=str)
|
||||
p.add_argument("--count", type=int, default=0)
|
||||
@@ -441,6 +485,7 @@ def main():
|
||||
"jax_num_steps": args.jax_num_steps,
|
||||
"jax_num_minibatches": args.jax_num_minibatches,
|
||||
"jax_update_epochs": args.jax_update_epochs,
|
||||
"checkpoint_interval": args.checkpoint_interval,
|
||||
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
|
||||
if args.jax_anneal_lr is not None
|
||||
else None,
|
||||
|
||||
Reference in New Issue
Block a user