feat: training update

This commit is contained in:
2026-02-27 09:33:04 +01:00
parent dac1e58a0d
commit e50d643fbf
3 changed files with 122 additions and 23 deletions

View File

@@ -727,6 +727,8 @@ def _train_actor_critic(
) -> tuple[dict[str, Any], dict[str, float]]:
num_devices = jax.local_device_count()
use_pmap = num_devices > 1
global_devices = max(1, int(jax.device_count()))
process_idx = int(jax.process_index())
init_runner_state, run_updates_raw, network, env, run_cfg = (
_make_actor_critic_train(cfg, algo=algo, use_pmap=use_pmap)
@@ -743,18 +745,26 @@ def _train_actor_critic(
run_fn = jax.jit(run_updates_raw, static_argnames=("num_updates",))
rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"])
rollout_steps_global = rollout_steps * (global_devices if use_pmap else 1)
total_updates = int(run_cfg["num_updates"])
checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000)))
segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1))
segment_updates = max(1, checkpoint_interval // max(rollout_steps_global, 1))
rng = jax.random.PRNGKey(run_cfg["seed"])
# single-device state used as template for serialization and eval
single_runner_state = init_runner_state(rng)
base_rng = jax.random.PRNGKey(run_cfg["seed"])
base_rng = jax.random.fold_in(base_rng, process_idx)
if use_pmap:
init_keys = jax.random.split(base_rng, num_devices)
runner_state = jax.vmap(init_runner_state)(init_keys)
single_runner_state = jax.tree_util.tree_map(lambda x: x[0], runner_state)
else:
single_runner_state = init_runner_state(base_rng)
runner_state = single_runner_state
updates_done = 0
restored_train_state = None
is_primary = jax.process_index() == 0
is_primary = process_idx == 0
artifact_name = None
if is_primary and HAS_WANDB and wandb.run is not None:
if HAS_WANDB and wandb.run is not None:
sweep_id = getattr(wandb.run, "sweep_id", None)
artifact_name = checkpoint_artifact_name(
run_cfg,
@@ -770,16 +780,20 @@ def _train_actor_critic(
template = {"runner_state": single_runner_state, "updates_done": 0}
payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
single_runner_state = payload["runner_state"]
restored_train_state = payload["runner_state"][0]
updates_done = int(payload.get("updates_done", 0))
if updates_done <= 0:
updates_done = int(metadata.get("updates_done", 0))
updates_done = max(0, min(updates_done, total_updates))
if use_pmap:
runner_state = jax.device_put_replicated(
single_runner_state, jax.local_devices()
if use_pmap and restored_train_state is not None:
runner_state = (
jax.device_put_replicated(restored_train_state, jax.local_devices()),
runner_state[1],
runner_state[2],
runner_state[3],
)
else:
elif not use_pmap:
runner_state = single_runner_state
metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"]
@@ -796,13 +810,14 @@ def _train_actor_critic(
metric = out["metrics"]
if use_pmap:
# take device-0 slice; shape is (n_devices, segment_updates)
segment_values = {
key: np.asarray(metric[key][0], dtype=np.float64) for key in metric_keys
key: np.asarray(metric[key], dtype=np.float64).reshape(-1)
for key in metric_keys
}
else:
segment_values = {
key: np.asarray(metric[key], dtype=np.float64) for key in metric_keys
key: np.asarray(metric[key], dtype=np.float64).reshape(-1)
for key in metric_keys
}
segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0
@@ -811,7 +826,7 @@ def _train_actor_critic(
metric_sums[key] += float(segment_values[key].sum())
updates_done += int(updates_this_segment)
global_step = int(updates_done * rollout_steps)
global_step = int(updates_done * rollout_steps_global)
if is_primary and HAS_WANDB and wandb.run is not None:
wandb.log(
@@ -842,7 +857,7 @@ def _train_actor_critic(
metadata={
"step": global_step,
"updates_done": updates_done,
"rollout_steps": rollout_steps,
"rollout_steps": rollout_steps_global,
"algo": algo,
},
)
@@ -863,7 +878,7 @@ def _train_actor_critic(
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
"train/global_step": int(updates_done * rollout_steps),
"train/global_step": int(updates_done * rollout_steps_global),
}
eval_metrics = evaluate_policy(

View File

@@ -9,10 +9,16 @@ import numpy as np
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
try:
import wandb
import wandb as _wandb
HAS_WANDB = True
if hasattr(_wandb, "init") and callable(_wandb.init):
wandb = _wandb
HAS_WANDB = True
else:
wandb = None
HAS_WANDB = False
except ImportError:
wandb = None
HAS_WANDB = False
try:
@@ -80,7 +86,7 @@ DEFAULT_CFG = {
"jax_num_minibatches": 4,
"jax_update_epochs": 4,
"jax_anneal_lr": True,
"checkpoint_interval": 10_000,
"checkpoint_interval": 200_000,
}
@@ -404,6 +410,16 @@ def run_wandb(
) -> dict:
if not HAS_WANDB:
raise ImportError("wandb is required for sweep runs")
if not sweep_mode:
pre_cfg = _cfg(overrides)
if pre_cfg.get("use_jax"):
try:
import jax
if jax.process_count() > 1 and jax.process_index() != 0:
return train_once(pre_cfg)
except Exception:
pass
init_kwargs = {"mode": mode}
if sweep_mode:
run = wandb.init(**init_kwargs)
@@ -431,7 +447,16 @@ def run_wandb(
def run_local(overrides: dict) -> dict:
cfg = _cfg(overrides)
metrics = train_once(cfg)
print(json.dumps(metrics, indent=2))
should_print = True
if cfg.get("use_jax"):
try:
import jax
should_print = jax.process_index() == 0
except Exception:
should_print = True
if should_print:
print(json.dumps(metrics, indent=2))
return metrics
@@ -439,15 +464,26 @@ def main():
p = argparse.ArgumentParser(description="PHANTOM training and W&B sweeps")
p.add_argument("--project", default=DEFAULT_CFG["project"])
p.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable"])
p.add_argument("--seed", type=int)
p.add_argument("--total-timesteps", type=int)
p.add_argument("--alpha", type=float)
p.add_argument("--N", type=int)
p.add_argument("--n-products", type=int)
p.add_argument("--lambda-coi", type=float)
p.add_argument("--info-value", type=float)
p.add_argument("--robust-radius", type=float)
p.add_argument("--robust-points", type=int)
p.add_argument("--learning-rate", type=float)
p.add_argument("--gamma", type=float)
p.add_argument("--gae-lambda", type=float)
p.add_argument("--clip-range", type=float)
p.add_argument("--ent-coef", type=float)
p.add_argument("--revenue-weight", type=float)
p.add_argument("--price-low", type=float)
p.add_argument("--price-high", type=float)
p.add_argument("--action-levels", type=int)
p.add_argument("--action-scale-low", type=float)
p.add_argument("--action-scale-high", type=float)
p.add_argument("--max-steps", type=int)
p.add_argument("--margin-floor", type=float)
p.add_argument("--margin-floor-patience", type=int)
@@ -469,15 +505,26 @@ def main():
overrides = {
"algo": args.algo,
"seed": args.seed,
"total_timesteps": args.total_timesteps,
"alpha": args.alpha,
"N": args.N,
"n_products": args.n_products,
"lambda_coi": args.lambda_coi,
"info_value": args.info_value,
"robust_radius": args.robust_radius,
"robust_points": args.robust_points,
"learning_rate": args.learning_rate,
"gamma": args.gamma,
"gae_lambda": args.gae_lambda,
"clip_range": args.clip_range,
"ent_coef": args.ent_coef,
"revenue_weight": args.revenue_weight,
"price_low": args.price_low,
"price_high": args.price_high,
"action_levels": args.action_levels,
"action_scale_low": args.action_scale_low,
"action_scale_high": args.action_scale_high,
"max_steps": args.max_steps,
"margin_floor": args.margin_floor,
"margin_floor_patience": args.margin_floor_patience,