mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
adding naive jax and libraries and make adjustments
This commit is contained in:
471
engine/jax/train.py
Normal file
471
engine/jax/train.py
Normal file
@@ -0,0 +1,471 @@
|
||||
"""Pure JAX PPO trainer for the PHANTOM environment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import distrax
|
||||
import flax.linen as nn
|
||||
import optax
|
||||
from flax import serialization
|
||||
from flax.linen.initializers import constant, orthogonal
|
||||
from flax.training.train_state import TrainState
|
||||
|
||||
HAS_JAX_STACK = True
|
||||
except ImportError:
|
||||
jax = None # type: ignore[assignment]
|
||||
jnp = None # type: ignore[assignment]
|
||||
distrax = None # type: ignore[assignment]
|
||||
optax = None # type: ignore[assignment]
|
||||
serialization = None # type: ignore[assignment]
|
||||
|
||||
class _ModuleStub:
|
||||
pass
|
||||
|
||||
class _NNStub:
|
||||
Module = _ModuleStub
|
||||
|
||||
@staticmethod
|
||||
def compact(fn):
|
||||
return fn
|
||||
|
||||
nn = _NNStub() # type: ignore[assignment]
|
||||
|
||||
def constant(*_args, **_kwargs): # type: ignore[override]
|
||||
return None
|
||||
|
||||
def orthogonal(*_args, **_kwargs): # type: ignore[override]
|
||||
return None
|
||||
|
||||
class TrainState: # type: ignore[override]
|
||||
pass
|
||||
|
||||
HAS_JAX_STACK = False
|
||||
|
||||
from .env import PHANTOMJAXEnv, make_env_params
|
||||
|
||||
|
||||
class ActorCritic(nn.Module):
    """Actor-critic network with separate 64-64 MLP towers for policy and value.

    Attributes:
        action_dim: Number of discrete actions (size of the logits vector).
        activation: ``"relu"`` selects ReLU; any other value selects tanh.
    """

    action_dim: int
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        """Return ``(distrax.Categorical over actions, value estimate)`` for ``x``.

        The value head's trailing singleton axis is squeezed away.
        """
        nonlinearity = nn.relu if self.activation == "relu" else nn.tanh
        hidden_gain = np.sqrt(2.0)

        def dense(features, gain):
            # Every layer uses orthogonal kernels (with a per-layer gain) and
            # zero-initialized biases.
            return nn.Dense(
                features,
                kernel_init=orthogonal(gain),
                bias_init=constant(0.0),
            )

        # NOTE: layers are created in the same order as before (actor tower,
        # logits head, critic tower, value head) so that automatically
        # assigned submodule names -- and therefore parameter trees -- match.
        pi_hidden = nonlinearity(dense(64, hidden_gain)(x))
        pi_hidden = nonlinearity(dense(64, hidden_gain)(pi_hidden))
        # Small gain on the logits head keeps the initial policy near-uniform.
        logits = dense(self.action_dim, 0.01)(pi_hidden)

        v_hidden = nonlinearity(dense(64, hidden_gain)(x))
        v_hidden = nonlinearity(dense(64, hidden_gain)(v_hidden))
        value = dense(1, 1.0)(v_hidden)

        return distrax.Categorical(logits=logits), jnp.squeeze(value, axis=-1)
|
||||
|
||||
|
||||
class Transition(NamedTuple):
    """One rollout step recorded by the PPO collector.

    Field order is part of the interface (instances are built and unpacked
    as tuples / pytrees), so it must not change.
    """

    # Episode-termination flag returned by the environment step.
    done: jax.Array
    # Action sampled from the policy.
    action: jax.Array
    # Critic value estimate for ``obs``.
    value: jax.Array
    # Reward returned by the environment step.
    reward: jax.Array
    # Log-probability of ``action`` under the policy that sampled it.
    log_prob: jax.Array
    # Observation the action was computed from.
    obs: jax.Array
    # Auxiliary per-step info dictionary returned by the environment.
    info: dict[str, jax.Array]
|
||||
|
||||
|
||||
def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
|
||||
out = {
|
||||
"algo": str(cfg.get("algo", "ppo")).lower(),
|
||||
"seed": int(cfg.get("seed", 42)),
|
||||
"learning_rate": float(cfg.get("learning_rate", 3e-4)),
|
||||
"gamma": float(cfg.get("gamma", 0.99)),
|
||||
"gae_lambda": float(cfg.get("gae_lambda", 0.95)),
|
||||
"clip_range": float(cfg.get("clip_range", 0.2)),
|
||||
"ent_coef": float(cfg.get("ent_coef", 0.01)),
|
||||
"vf_coef": float(cfg.get("vf_coef", 0.5)),
|
||||
"max_grad_norm": float(cfg.get("max_grad_norm", 0.5)),
|
||||
"activation": str(cfg.get("activation", "relu")),
|
||||
"total_timesteps": int(cfg.get("total_timesteps", 50_000)),
|
||||
"eval_episodes": int(cfg.get("eval_episodes", 5)),
|
||||
"model_dir": str(cfg.get("model_dir", "engine/models")),
|
||||
"n_products": int(cfg.get("n_products", 10)),
|
||||
"N": int(cfg.get("N", 100)),
|
||||
"alpha": float(cfg.get("alpha", 0.3)),
|
||||
"lambda_coi": float(cfg.get("lambda_coi", 0.2)),
|
||||
"robust_radius": float(cfg.get("robust_radius", 0.15)),
|
||||
"robust_points": int(cfg.get("robust_points", 5)),
|
||||
"info_value": float(cfg.get("info_value", 1.0)),
|
||||
"price_low": float(cfg.get("price_low", 10.0)),
|
||||
"price_high": float(cfg.get("price_high", 150.0)),
|
||||
"action_levels": int(cfg.get("action_levels", 9)),
|
||||
"action_scale_low": float(cfg.get("action_scale_low", 0.8)),
|
||||
"action_scale_high": float(cfg.get("action_scale_high", 1.2)),
|
||||
"max_episode_steps": int(cfg.get("max_steps", 100)),
|
||||
"max_session_steps": int(cfg.get("max_session_steps", 40)),
|
||||
"margin_floor": float(cfg.get("margin_floor", 0.05)),
|
||||
"margin_floor_patience": int(cfg.get("margin_floor_patience", 5)),
|
||||
"prefer_behavior_data": bool(cfg.get("prefer_behavior_data", True)),
|
||||
"num_envs": int(cfg.get("jax_num_envs", 16)),
|
||||
"num_steps": int(cfg.get("jax_num_steps", 128)),
|
||||
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
|
||||
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
|
||||
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
|
||||
}
|
||||
rollout = out["num_envs"] * out["num_steps"]
|
||||
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
|
||||
out["minibatch_size"] = max(1, rollout // max(out["num_minibatches"], 1))
|
||||
return out
|
||||
|
||||
|
||||
def _select_env_state(done: jax.Array, keep: jax.Array, reset: jax.Array) -> jax.Array:
|
||||
mask = done
|
||||
while mask.ndim < keep.ndim:
|
||||
mask = mask[..., None]
|
||||
return jnp.where(mask, reset, keep)
|
||||
|
||||
|
||||
def make_train(config: dict[str, Any]):
    """Build the PPO training function together with its network and environment.

    Args:
        config: Configuration mapping; it is normalized with ``_jax_cfg``.

    Returns:
        A tuple ``(train, network, env, cfg)`` where ``train(rng)`` runs the
        whole training loop and returns a dict with the final
        ``"runner_state"`` (``(train_state, env_state, last_obs, rng)``) and
        ``"metrics"`` (per-update means, one entry per update), ``network`` is
        the ``ActorCritic`` module, ``env`` the ``PHANTOMJAXEnv`` instance and
        ``cfg`` the normalized configuration.
    """
    cfg = _jax_cfg(config)
    env_params = make_env_params(
        n_products=cfg["n_products"],
        alpha=cfg["alpha"],
        n_sessions=cfg["N"],
        lambda_coi=cfg["lambda_coi"],
        robust_radius=cfg["robust_radius"],
        robust_points=cfg["robust_points"],
        info_value=cfg["info_value"],
        action_levels=cfg["action_levels"],
        action_scale_low=cfg["action_scale_low"],
        action_scale_high=cfg["action_scale_high"],
        price_low=cfg["price_low"],
        price_high=cfg["price_high"],
        max_episode_steps=cfg["max_episode_steps"],
        max_session_steps=cfg["max_session_steps"],
        margin_floor=cfg["margin_floor"],
        margin_floor_patience=cfg["margin_floor_patience"],
        prefer_behavior_data=cfg["prefer_behavior_data"],
    )
    env = PHANTOMJAXEnv(env_params)
    network = ActorCritic(env.action_space_n(), activation=cfg["activation"])

    # Minibatch geometry. The rollout must be reshaped into
    # (num_minibatches, minibatch_size, ...); when the rollout size is not a
    # multiple of the requested minibatch count (or the count is < 1 or larger
    # than the rollout) a plain reshape would fail, so the count is clamped and
    # only a whole number of minibatches worth of shuffled samples is used.
    # For evenly divisible settings these equal the configured values.
    batch_size = cfg["num_envs"] * cfg["num_steps"]
    num_minibatches = max(1, min(cfg["num_minibatches"], batch_size))
    minibatch_size = batch_size // num_minibatches
    used_samples = num_minibatches * minibatch_size
    # Optimizer steps taken per PPO update; guarded against zero so the
    # learning-rate schedule never divides by zero.
    steps_per_update = max(num_minibatches * cfg["update_epochs"], 1)

    def linear_schedule(count: jax.Array) -> jax.Array:
        """Linearly decay the learning rate per completed PPO update."""
        updates_done = count // steps_per_update
        frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
        return cfg["learning_rate"] * frac

    def train(rng: jax.Array):
        """Run the full PPO loop from PRNG key ``rng``."""
        rng, init_key = jax.random.split(rng)
        init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
        params = network.init(init_key, init_obs)

        if cfg["anneal_lr"]:
            tx = optax.chain(
                optax.clip_by_global_norm(cfg["max_grad_norm"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(cfg["max_grad_norm"]),
                optax.adam(cfg["learning_rate"], eps=1e-5),
            )
        train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)

        rng, reset_key = jax.random.split(rng)
        reset_keys = jax.random.split(reset_key, cfg["num_envs"])
        obs, env_state = jax.vmap(env.reset)(reset_keys)

        def _update_step(runner_state, _):
            """Collect one rollout, then run the PPO epochs on it."""

            def _env_step(runner_state, _):
                """Advance every environment by one step and record it."""
                train_state, env_state, last_obs, rng = runner_state
                rng, action_key = jax.random.split(rng)
                policy, value = network.apply(train_state.params, last_obs)
                action = policy.sample(seed=action_key)
                log_prob = policy.log_prob(action)

                rng, step_key = jax.random.split(rng)
                step_keys = jax.random.split(step_key, cfg["num_envs"])
                nxt_obs, nxt_state, reward, done, info = jax.vmap(
                    env.step,
                    in_axes=(0, 0, 0),
                )(step_keys, env_state, action)

                # Auto-reset: a fresh reset is computed for every environment
                # and selected only where the episode just finished.
                rng, reset_key = jax.random.split(rng)
                reset_keys = jax.random.split(reset_key, cfg["num_envs"])
                rst_obs, rst_state = jax.vmap(env.reset)(reset_keys)
                obs_next = jnp.where(done[:, None], rst_obs, nxt_obs)
                env_next = jax.tree_util.tree_map(
                    lambda keep, reset: _select_env_state(done, keep, reset),
                    nxt_state,
                    rst_state,
                )
                transition = Transition(
                    done=done,
                    action=action,
                    value=value,
                    reward=reward,
                    log_prob=log_prob,
                    obs=last_obs,
                    info=info,
                )
                return (train_state, env_next, obs_next, rng), transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step,
                runner_state,
                None,
                length=cfg["num_steps"],
            )

            train_state, env_state, last_obs, rng = runner_state
            # Bootstrap value for the observation following the rollout.
            _, last_value = network.apply(train_state.params, last_obs)

            def _compute_gae(traj_batch, last_value):
                """Return GAE advantages and value targets for the rollout."""

                def _gae_step(carry, transition):
                    gae, next_value = carry
                    # (1 - done) cuts bootstrapping across episode ends.
                    delta = (
                        transition.reward
                        + cfg["gamma"] * next_value * (1.0 - transition.done)
                        - transition.value
                    )
                    gae = (
                        delta
                        + cfg["gamma"]
                        * cfg["gae_lambda"]
                        * (1.0 - transition.done)
                        * gae
                    )
                    return (gae, transition.value), gae

                # Scanned in reverse so each step sees its successor's value.
                _, advantages = jax.lax.scan(
                    _gae_step,
                    (jnp.zeros_like(last_value), last_value),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                targets = advantages + traj_batch.value
                return advantages, targets

            advantages, targets = _compute_gae(traj_batch, last_value)

            def _update_epoch(update_state, _):
                """Shuffle the rollout and take one pass of minibatch updates."""

                def _update_minibatch(train_state, batch_info):
                    traj_b, adv_b, tgt_b = batch_info

                    def _loss_fn(params, traj_b, adv_b, tgt_b):
                        policy, value = network.apply(params, traj_b.obs)
                        log_prob = policy.log_prob(traj_b.action)

                        # Clipped value loss: limit how far the new value may
                        # move from the rollout-time value estimate.
                        value_clipped = traj_b.value + (value - traj_b.value).clip(
                            -cfg["clip_range"], cfg["clip_range"]
                        )
                        value_loss = (
                            0.5
                            * jnp.maximum(
                                jnp.square(value - tgt_b),
                                jnp.square(value_clipped - tgt_b),
                            ).mean()
                        )

                        # Advantages are normalized per minibatch.
                        adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)
                        ratio = jnp.exp(log_prob - traj_b.log_prob)
                        loss_actor = -jnp.minimum(
                            ratio * adv_norm,
                            jnp.clip(
                                ratio,
                                1.0 - cfg["clip_range"],
                                1.0 + cfg["clip_range"],
                            )
                            * adv_norm,
                        ).mean()
                        entropy = policy.entropy().mean()
                        total_loss = (
                            loss_actor
                            + cfg["vf_coef"] * value_loss
                            - cfg["ent_coef"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    (_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b)
                    train_state = train_state.apply_gradients(grads=grads)
                    # Loss values are not reported; a constant placeholder is
                    # emitted as the scan output.
                    return train_state, jnp.asarray(0.0, dtype=jnp.float32)

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, perm_key = jax.random.split(rng)
                # Truncate *after* shuffling so that, when the rollout is not
                # evenly divisible, a different random remainder is dropped
                # each epoch. For divisible sizes nothing is dropped.
                permutation = jax.random.permutation(perm_key, batch_size)[
                    :used_samples
                ]
                batch = (traj_batch, advantages, targets)
                # Merge the leading (num_steps, num_envs) axes into one.
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]),
                    batch,
                )
                shuffled = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0),
                    batch,
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: x.reshape(
                        (num_minibatches, minibatch_size) + x.shape[1:]
                    ),
                    shuffled,
                )
                train_state, _ = jax.lax.scan(
                    _update_minibatch, train_state, minibatches
                )
                return (train_state, traj_batch, advantages, targets, rng), None

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, _ = jax.lax.scan(
                _update_epoch,
                update_state,
                None,
                length=cfg["update_epochs"],
            )
            train_state = update_state[0]
            rng = update_state[-1]

            # Means over the whole rollout (all steps and environments).
            metric = {
                "reward": jnp.mean(traj_batch.reward),
                "revenue": jnp.mean(traj_batch.info["revenue"]),
                "agent_prob": jnp.mean(traj_batch.info["agent_prob"]),
                "alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
                "coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
            }
            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        runner_state = (train_state, env_state, obs, rng)
        runner_state, metric = jax.lax.scan(
            _update_step,
            runner_state,
            None,
            length=cfg["num_updates"],
        )
        return {
            "runner_state": runner_state,
            "metrics": metric,
        }

    return train, network, env, cfg
|
||||
|
||||
|
||||
def evaluate_policy(
    *,
    network: ActorCritic,
    params: Any,
    env: PHANTOMJAXEnv,
    episodes: int,
    seed: int,
) -> dict[str, float]:
    """Roll out the greedy (argmax-logits) policy and summarize the results.

    Each episode runs until the environment reports done or
    ``env.params.max_episode_steps`` steps have been taken.

    Args:
        network: Module whose ``apply(params, obs)`` returns ``(policy, value)``.
        params: Parameters passed to ``network.apply``.
        env: Environment exposing ``reset(key)`` and ``step(key, state, action)``.
        episodes: Number of evaluation episodes.
        seed: Seed for the evaluation PRNG key.

    Returns:
        Mean and standard deviation of per-episode total reward and revenue
        under the keys ``eval/reward``, ``eval/revenue``, ``eval/reward_std``
        and ``eval/revenue_std``.
    """
    rng = jax.random.PRNGKey(seed)
    episode_rewards: list[float] = []
    episode_revenues: list[float] = []

    for _episode in range(int(episodes)):
        rng, reset_key = jax.random.split(rng)
        obs, state = env.reset(reset_key)
        step_limit = int(env.params.max_episode_steps)
        total_reward = 0.0
        total_revenue = 0.0

        for _step in range(step_limit):
            policy, _value = network.apply(params, obs)
            # Deterministic evaluation: take the highest-logit action.
            greedy_action = jnp.argmax(policy.logits)
            rng, step_key = jax.random.split(rng)
            obs, state, reward, done_flag, info = env.step(
                step_key, state, greedy_action
            )
            total_reward += float(np.asarray(reward))
            total_revenue += float(np.asarray(info["revenue"]))
            if bool(np.asarray(done_flag)):
                break

        episode_rewards.append(total_reward)
        episode_revenues.append(total_revenue)

    summary: dict[str, float] = {}
    summary["eval/reward"] = float(np.mean(episode_rewards))
    summary["eval/revenue"] = float(np.mean(episode_revenues))
    summary["eval/reward_std"] = float(np.std(episode_rewards))
    summary["eval/revenue_std"] = float(np.std(episode_revenues))
    return summary
|
||||
|
||||
|
||||
def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
    """Train a PPO agent with the JAX backend, evaluate it and save its parameters.

    Args:
        cfg: Raw configuration mapping (normalized via ``_jax_cfg``).

    Returns:
        ``({"params": trained_params}, metrics)`` where ``metrics`` holds the
        averaged training metrics, ``train/global_step``, the evaluation
        metrics and ``model/path`` (location of the serialized parameters).

    Raises:
        ImportError: If the optional JAX stack could not be imported.
        ValueError: If the configured algorithm is not ``"ppo"``.
    """
    if not HAS_JAX_STACK:
        raise ImportError(
            "JAX PPO path requires jax, flax, optax, and distrax. "
            "Install engine/jax/requirements.txt on this machine first."
        )

    run_cfg = _jax_cfg(cfg)
    if run_cfg["algo"] != "ppo":
        raise ValueError(
            f"JAX backend currently supports algo='ppo' only, got '{run_cfg['algo']}'"
        )

    # make_train normalizes the config itself, so hand it the caller's raw
    # config rather than an already-normalized copy: the normalized dict no
    # longer carries the raw "jax_*" / "max_steps" key names.
    train_fn, network, env, run_cfg = make_train(cfg)
    train_jit = jax.jit(train_fn)
    rng = jax.random.PRNGKey(run_cfg["seed"])
    out = train_jit(rng)

    train_state = out["runner_state"][0]
    metric = out["metrics"]
    # Each training metric holds one value per update; report the average.
    metrics = {
        "train/reward": float(np.mean(np.asarray(metric["reward"]))),
        "train/revenue": float(np.mean(np.asarray(metric["revenue"]))),
        "train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))),
        "train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))),
        "train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))),
        "train/global_step": int(
            run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"]
        ),
    }

    eval_metrics = evaluate_policy(
        network=network,
        params=train_state.params,
        env=env,
        episodes=run_cfg["eval_episodes"],
        # Offset so evaluation does not reuse the training seed.
        seed=run_cfg["seed"] + 7,
    )
    metrics.update(eval_metrics)

    model_dir = Path(run_cfg["model_dir"])
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "phantom_ppo_jax.msgpack"
    model_path.write_bytes(serialization.to_bytes(train_state.params))
    metrics["model/path"] = str(model_path)
    return {"params": train_state.params}, metrics
|
||||
Reference in New Issue
Block a user