mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
fixing models for gcp
This commit is contained in:
@@ -7,6 +7,19 @@ from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import wandb
|
||||
|
||||
HAS_WANDB = True
|
||||
except ImportError:
|
||||
HAS_WANDB = False
|
||||
|
||||
from ..wandb_checkpoint import (
|
||||
checkpoint_artifact_name,
|
||||
download_latest_checkpoint,
|
||||
log_checkpoint_bytes,
|
||||
)
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -142,6 +155,7 @@ def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
|
||||
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
|
||||
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
|
||||
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
|
||||
"checkpoint_interval": int(cfg.get("checkpoint_interval", 10_000)),
|
||||
}
|
||||
rollout = out["num_envs"] * out["num_steps"]
|
||||
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
|
||||
@@ -185,201 +199,198 @@ def make_train(config: dict[str, Any]):
|
||||
frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
|
||||
return cfg["learning_rate"] * frac
|
||||
|
||||
def train(rng: jax.Array):
|
||||
if cfg["anneal_lr"]:
|
||||
tx = optax.chain(
|
||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
||||
)
|
||||
else:
|
||||
tx = optax.chain(
|
||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||
optax.adam(cfg["learning_rate"], eps=1e-5),
|
||||
)
|
||||
|
||||
def init_runner_state(rng: jax.Array):
|
||||
rng, init_key = jax.random.split(rng)
|
||||
init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
|
||||
params = network.init(init_key, init_obs)
|
||||
|
||||
if cfg["anneal_lr"]:
|
||||
tx = optax.chain(
|
||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||
optax.adam(learning_rate=linear_schedule, eps=1e-5),
|
||||
)
|
||||
else:
|
||||
tx = optax.chain(
|
||||
optax.clip_by_global_norm(cfg["max_grad_norm"]),
|
||||
optax.adam(cfg["learning_rate"], eps=1e-5),
|
||||
)
|
||||
train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)
|
||||
|
||||
rng, reset_key = jax.random.split(rng)
|
||||
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
|
||||
obs, env_state = jax.vmap(env.reset)(reset_keys)
|
||||
return train_state, env_state, obs, rng
|
||||
|
||||
def _update_step(runner_state, _):
|
||||
def _env_step(runner_state, _):
|
||||
train_state, env_state, last_obs, rng = runner_state
|
||||
rng, action_key = jax.random.split(rng)
|
||||
policy, value = network.apply(train_state.params, last_obs)
|
||||
action = policy.sample(seed=action_key)
|
||||
log_prob = policy.log_prob(action)
|
||||
|
||||
rng, step_key = jax.random.split(rng)
|
||||
step_keys = jax.random.split(step_key, cfg["num_envs"])
|
||||
nxt_obs, nxt_state, reward, done, info = jax.vmap(
|
||||
env.step,
|
||||
in_axes=(0, 0, 0),
|
||||
)(step_keys, env_state, action)
|
||||
|
||||
rng, reset_key = jax.random.split(rng)
|
||||
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
|
||||
rst_obs, rst_state = jax.vmap(env.reset)(reset_keys)
|
||||
obs_next = jnp.where(done[:, None], rst_obs, nxt_obs)
|
||||
env_next = jax.tree_util.tree_map(
|
||||
lambda keep, reset: _select_env_state(done, keep, reset),
|
||||
nxt_state,
|
||||
rst_state,
|
||||
)
|
||||
transition = Transition(
|
||||
done=done,
|
||||
action=action,
|
||||
value=value,
|
||||
reward=reward,
|
||||
log_prob=log_prob,
|
||||
obs=last_obs,
|
||||
info=info,
|
||||
)
|
||||
return (train_state, env_next, obs_next, rng), transition
|
||||
|
||||
runner_state, traj_batch = jax.lax.scan(
|
||||
_env_step,
|
||||
runner_state,
|
||||
None,
|
||||
length=cfg["num_steps"],
|
||||
)
|
||||
|
||||
def _update_step(runner_state, _):
|
||||
def _env_step(runner_state, _):
|
||||
train_state, env_state, last_obs, rng = runner_state
|
||||
_, last_value = network.apply(train_state.params, last_obs)
|
||||
rng, action_key = jax.random.split(rng)
|
||||
policy, value = network.apply(train_state.params, last_obs)
|
||||
action = policy.sample(seed=action_key)
|
||||
log_prob = policy.log_prob(action)
|
||||
|
||||
def _compute_gae(traj_batch, last_value):
|
||||
def _gae_step(carry, transition):
|
||||
gae, next_value = carry
|
||||
delta = (
|
||||
transition.reward
|
||||
+ cfg["gamma"] * next_value * (1.0 - transition.done)
|
||||
- transition.value
|
||||
)
|
||||
gae = (
|
||||
delta
|
||||
+ cfg["gamma"]
|
||||
* cfg["gae_lambda"]
|
||||
* (1.0 - transition.done)
|
||||
* gae
|
||||
)
|
||||
return (gae, transition.value), gae
|
||||
rng, step_key = jax.random.split(rng)
|
||||
step_keys = jax.random.split(step_key, cfg["num_envs"])
|
||||
nxt_obs, nxt_state, reward, done, info = jax.vmap(
|
||||
env.step,
|
||||
in_axes=(0, 0, 0),
|
||||
)(step_keys, env_state, action)
|
||||
|
||||
_, advantages = jax.lax.scan(
|
||||
_gae_step,
|
||||
(jnp.zeros_like(last_value), last_value),
|
||||
traj_batch,
|
||||
reverse=True,
|
||||
unroll=16,
|
||||
)
|
||||
targets = advantages + traj_batch.value
|
||||
return advantages, targets
|
||||
|
||||
advantages, targets = _compute_gae(traj_batch, last_value)
|
||||
|
||||
def _update_epoch(update_state, _):
|
||||
def _update_minibatch(train_state, batch_info):
|
||||
traj_b, adv_b, tgt_b = batch_info
|
||||
|
||||
def _loss_fn(params, traj_b, adv_b, tgt_b):
|
||||
policy, value = network.apply(params, traj_b.obs)
|
||||
log_prob = policy.log_prob(traj_b.action)
|
||||
|
||||
value_clipped = traj_b.value + (value - traj_b.value).clip(
|
||||
-cfg["clip_range"], cfg["clip_range"]
|
||||
)
|
||||
value_loss = (
|
||||
0.5
|
||||
* jnp.maximum(
|
||||
jnp.square(value - tgt_b),
|
||||
jnp.square(value_clipped - tgt_b),
|
||||
).mean()
|
||||
)
|
||||
|
||||
adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)
|
||||
ratio = jnp.exp(log_prob - traj_b.log_prob)
|
||||
loss_actor = -jnp.minimum(
|
||||
ratio * adv_norm,
|
||||
jnp.clip(
|
||||
ratio,
|
||||
1.0 - cfg["clip_range"],
|
||||
1.0 + cfg["clip_range"],
|
||||
)
|
||||
* adv_norm,
|
||||
).mean()
|
||||
entropy = policy.entropy().mean()
|
||||
total_loss = (
|
||||
loss_actor
|
||||
+ cfg["vf_coef"] * value_loss
|
||||
- cfg["ent_coef"] * entropy
|
||||
)
|
||||
return total_loss, (value_loss, loss_actor, entropy)
|
||||
|
||||
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
|
||||
(_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b)
|
||||
train_state = train_state.apply_gradients(grads=grads)
|
||||
return train_state, jnp.asarray(0.0, dtype=jnp.float32)
|
||||
|
||||
train_state, traj_batch, advantages, targets, rng = update_state
|
||||
rng, perm_key = jax.random.split(rng)
|
||||
batch_size = cfg["num_envs"] * cfg["num_steps"]
|
||||
permutation = jax.random.permutation(perm_key, batch_size)
|
||||
batch = (traj_batch, advantages, targets)
|
||||
batch = jax.tree_util.tree_map(
|
||||
lambda x: x.reshape((batch_size,) + x.shape[2:]),
|
||||
batch,
|
||||
)
|
||||
shuffled = jax.tree_util.tree_map(
|
||||
lambda x: jnp.take(x, permutation, axis=0),
|
||||
batch,
|
||||
)
|
||||
minibatches = jax.tree_util.tree_map(
|
||||
lambda x: x.reshape(
|
||||
(cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:]
|
||||
),
|
||||
shuffled,
|
||||
)
|
||||
train_state, _ = jax.lax.scan(
|
||||
_update_minibatch, train_state, minibatches
|
||||
)
|
||||
return (train_state, traj_batch, advantages, targets, rng), None
|
||||
|
||||
update_state = (train_state, traj_batch, advantages, targets, rng)
|
||||
update_state, _ = jax.lax.scan(
|
||||
_update_epoch,
|
||||
update_state,
|
||||
None,
|
||||
length=cfg["update_epochs"],
|
||||
rng, reset_key = jax.random.split(rng)
|
||||
reset_keys = jax.random.split(reset_key, cfg["num_envs"])
|
||||
rst_obs, rst_state = jax.vmap(env.reset)(reset_keys)
|
||||
obs_next = jnp.where(done[:, None], rst_obs, nxt_obs)
|
||||
env_next = jax.tree_util.tree_map(
|
||||
lambda keep, reset: _select_env_state(done, keep, reset),
|
||||
nxt_state,
|
||||
rst_state,
|
||||
)
|
||||
train_state = update_state[0]
|
||||
rng = update_state[-1]
|
||||
transition = Transition(
|
||||
done=done,
|
||||
action=action,
|
||||
value=value,
|
||||
reward=reward,
|
||||
log_prob=log_prob,
|
||||
obs=last_obs,
|
||||
info=info,
|
||||
)
|
||||
return (train_state, env_next, obs_next, rng), transition
|
||||
|
||||
metric = {
|
||||
"reward": jnp.mean(traj_batch.reward),
|
||||
"revenue": jnp.mean(traj_batch.info["revenue"]),
|
||||
"agent_prob": jnp.mean(traj_batch.info["agent_prob"]),
|
||||
"alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
|
||||
"coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
|
||||
}
|
||||
runner_state = (train_state, env_state, last_obs, rng)
|
||||
return runner_state, metric
|
||||
runner_state, traj_batch = jax.lax.scan(
|
||||
_env_step,
|
||||
runner_state,
|
||||
None,
|
||||
length=cfg["num_steps"],
|
||||
)
|
||||
|
||||
runner_state = (train_state, env_state, obs, rng)
|
||||
train_state, env_state, last_obs, rng = runner_state
|
||||
_, last_value = network.apply(train_state.params, last_obs)
|
||||
|
||||
def _compute_gae(traj_batch, last_value):
|
||||
def _gae_step(carry, transition):
|
||||
gae, next_value = carry
|
||||
delta = (
|
||||
transition.reward
|
||||
+ cfg["gamma"] * next_value * (1.0 - transition.done)
|
||||
- transition.value
|
||||
)
|
||||
gae = (
|
||||
delta
|
||||
+ cfg["gamma"] * cfg["gae_lambda"] * (1.0 - transition.done) * gae
|
||||
)
|
||||
return (gae, transition.value), gae
|
||||
|
||||
_, advantages = jax.lax.scan(
|
||||
_gae_step,
|
||||
(jnp.zeros_like(last_value), last_value),
|
||||
traj_batch,
|
||||
reverse=True,
|
||||
unroll=16,
|
||||
)
|
||||
targets = advantages + traj_batch.value
|
||||
return advantages, targets
|
||||
|
||||
advantages, targets = _compute_gae(traj_batch, last_value)
|
||||
|
||||
def _update_epoch(update_state, _):
|
||||
def _update_minibatch(train_state, batch_info):
|
||||
traj_b, adv_b, tgt_b = batch_info
|
||||
|
||||
def _loss_fn(params, traj_b, adv_b, tgt_b):
|
||||
policy, value = network.apply(params, traj_b.obs)
|
||||
log_prob = policy.log_prob(traj_b.action)
|
||||
|
||||
value_clipped = traj_b.value + (value - traj_b.value).clip(
|
||||
-cfg["clip_range"], cfg["clip_range"]
|
||||
)
|
||||
value_loss = (
|
||||
0.5
|
||||
* jnp.maximum(
|
||||
jnp.square(value - tgt_b),
|
||||
jnp.square(value_clipped - tgt_b),
|
||||
).mean()
|
||||
)
|
||||
|
||||
adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)
|
||||
ratio = jnp.exp(log_prob - traj_b.log_prob)
|
||||
loss_actor = -jnp.minimum(
|
||||
ratio * adv_norm,
|
||||
jnp.clip(
|
||||
ratio,
|
||||
1.0 - cfg["clip_range"],
|
||||
1.0 + cfg["clip_range"],
|
||||
)
|
||||
* adv_norm,
|
||||
).mean()
|
||||
entropy = policy.entropy().mean()
|
||||
total_loss = (
|
||||
loss_actor
|
||||
+ cfg["vf_coef"] * value_loss
|
||||
- cfg["ent_coef"] * entropy
|
||||
)
|
||||
return total_loss, (value_loss, loss_actor, entropy)
|
||||
|
||||
grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
|
||||
(_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b)
|
||||
train_state = train_state.apply_gradients(grads=grads)
|
||||
return train_state, jnp.asarray(0.0, dtype=jnp.float32)
|
||||
|
||||
train_state, traj_batch, advantages, targets, rng = update_state
|
||||
rng, perm_key = jax.random.split(rng)
|
||||
batch_size = cfg["num_envs"] * cfg["num_steps"]
|
||||
permutation = jax.random.permutation(perm_key, batch_size)
|
||||
batch = (traj_batch, advantages, targets)
|
||||
batch = jax.tree_util.tree_map(
|
||||
lambda x: x.reshape((batch_size,) + x.shape[2:]),
|
||||
batch,
|
||||
)
|
||||
shuffled = jax.tree_util.tree_map(
|
||||
lambda x: jnp.take(x, permutation, axis=0),
|
||||
batch,
|
||||
)
|
||||
minibatches = jax.tree_util.tree_map(
|
||||
lambda x: x.reshape(
|
||||
(cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:]
|
||||
),
|
||||
shuffled,
|
||||
)
|
||||
train_state, _ = jax.lax.scan(_update_minibatch, train_state, minibatches)
|
||||
return (train_state, traj_batch, advantages, targets, rng), None
|
||||
|
||||
update_state = (train_state, traj_batch, advantages, targets, rng)
|
||||
update_state, _ = jax.lax.scan(
|
||||
_update_epoch,
|
||||
update_state,
|
||||
None,
|
||||
length=cfg["update_epochs"],
|
||||
)
|
||||
train_state = update_state[0]
|
||||
rng = update_state[-1]
|
||||
|
||||
metric = {
|
||||
"reward": jnp.mean(traj_batch.reward),
|
||||
"revenue": jnp.mean(traj_batch.info["revenue"]),
|
||||
"agent_prob": jnp.mean(traj_batch.info["agent_prob"]),
|
||||
"alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
|
||||
"coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
|
||||
}
|
||||
next_runner_state = (train_state, env_state, last_obs, rng)
|
||||
return next_runner_state, metric
|
||||
|
||||
def run_updates(runner_state, *, num_updates: int):
|
||||
updates = max(1, int(num_updates))
|
||||
runner_state, metric = jax.lax.scan(
|
||||
_update_step,
|
||||
runner_state,
|
||||
None,
|
||||
length=cfg["num_updates"],
|
||||
length=updates,
|
||||
)
|
||||
return {
|
||||
"runner_state": runner_state,
|
||||
"metrics": metric,
|
||||
}
|
||||
|
||||
return train, network, env, cfg
|
||||
return init_runner_state, run_updates, network, env, cfg
|
||||
|
||||
|
||||
def evaluate_policy(
|
||||
@@ -436,22 +447,103 @@ def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
|
||||
f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'"
|
||||
)
|
||||
|
||||
train_fn, network, env, run_cfg = make_train(run_cfg)
|
||||
train_jit = jax.jit(train_fn)
|
||||
rng = jax.random.PRNGKey(run_cfg["seed"])
|
||||
out = train_jit(rng)
|
||||
init_runner_state, run_updates, network, env, run_cfg = make_train(run_cfg)
|
||||
run_updates_jit = jax.jit(run_updates, static_argnames=("num_updates",))
|
||||
rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"])
|
||||
total_updates = int(run_cfg["num_updates"])
|
||||
checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000)))
|
||||
segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1))
|
||||
|
||||
train_state = out["runner_state"][0]
|
||||
metric = out["metrics"]
|
||||
rng = jax.random.PRNGKey(run_cfg["seed"])
|
||||
runner_state = init_runner_state(rng)
|
||||
updates_done = 0
|
||||
|
||||
artifact_name = None
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||
artifact_name = checkpoint_artifact_name(
|
||||
run_cfg,
|
||||
backend="jax",
|
||||
sweep_id=sweep_id,
|
||||
)
|
||||
restored = download_latest_checkpoint(
|
||||
artifact_name,
|
||||
file_name="jax_runner_state.msgpack",
|
||||
)
|
||||
if restored is not None:
|
||||
checkpoint_path, metadata = restored
|
||||
template = {
|
||||
"runner_state": runner_state,
|
||||
"updates_done": 0,
|
||||
}
|
||||
payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
|
||||
runner_state = payload["runner_state"]
|
||||
updates_done = int(payload.get("updates_done", 0))
|
||||
if updates_done <= 0:
|
||||
updates_done = int(metadata.get("updates_done", 0))
|
||||
updates_done = max(0, min(updates_done, total_updates))
|
||||
|
||||
metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"]
|
||||
metric_sums = {k: 0.0 for k in metric_keys}
|
||||
metric_count = 0
|
||||
|
||||
while updates_done < total_updates:
|
||||
updates_this_segment = min(segment_updates, total_updates - updates_done)
|
||||
out = run_updates_jit(runner_state, num_updates=updates_this_segment)
|
||||
runner_state = out["runner_state"]
|
||||
metric = out["metrics"]
|
||||
|
||||
segment_values = {
|
||||
k: np.asarray(metric[k], dtype=np.float64) for k in metric_keys
|
||||
}
|
||||
segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0
|
||||
metric_count += segment_count
|
||||
for key in metric_keys:
|
||||
metric_sums[key] += float(segment_values[key].sum())
|
||||
|
||||
updates_done += int(updates_this_segment)
|
||||
global_step = int(updates_done * rollout_steps)
|
||||
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
wandb.log(
|
||||
{
|
||||
"train/reward": float(segment_values["reward"].mean()),
|
||||
"train/revenue": float(segment_values["revenue"].mean()),
|
||||
"train/agent_prob": float(segment_values["agent_prob"].mean()),
|
||||
"train/alpha_adv": float(segment_values["alpha_adv"].mean()),
|
||||
"train/coi_leakage": float(segment_values["coi_leakage"].mean()),
|
||||
"train/global_step": global_step,
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
if artifact_name is not None:
|
||||
checkpoint_payload = serialization.to_bytes(
|
||||
{
|
||||
"runner_state": runner_state,
|
||||
"updates_done": updates_done,
|
||||
}
|
||||
)
|
||||
log_checkpoint_bytes(
|
||||
artifact_name,
|
||||
file_name="jax_runner_state.msgpack",
|
||||
payload=checkpoint_payload,
|
||||
metadata={
|
||||
"step": global_step,
|
||||
"updates_done": updates_done,
|
||||
"rollout_steps": rollout_steps,
|
||||
"algo": "ppo",
|
||||
},
|
||||
)
|
||||
|
||||
train_state = runner_state[0]
|
||||
denom = float(metric_count) if metric_count > 0 else 1.0
|
||||
metrics = {
|
||||
"train/reward": float(np.mean(np.asarray(metric["reward"]))),
|
||||
"train/revenue": float(np.mean(np.asarray(metric["revenue"]))),
|
||||
"train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))),
|
||||
"train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))),
|
||||
"train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))),
|
||||
"train/global_step": int(
|
||||
run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"]
|
||||
),
|
||||
"train/reward": float(metric_sums["reward"] / denom),
|
||||
"train/revenue": float(metric_sums["revenue"] / denom),
|
||||
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
|
||||
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
|
||||
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
|
||||
"train/global_step": int(updates_done * rollout_steps),
|
||||
}
|
||||
|
||||
eval_metrics = evaluate_policy(
|
||||
|
||||
Reference in New Issue
Block a user