mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
1305 lines
45 KiB
Python
1305 lines
45 KiB
Python
"""Pure JAX trainers for PHANTOM environment."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from typing import Any, NamedTuple
|
|
|
|
import signal
|
|
import threading
|
|
|
|
import numpy as np
|
|
|
|
# Cooperative stop flag. The actor-critic training loop polls it after every
# update segment and exits early once it is set. NOTE(review): `signal` is
# imported at the top of the file, so a signal handler presumably sets this
# elsewhere — confirm.
_stop_requested = threading.Event()
# Guard for `_init_jax_distributed`: becomes True on its first call so the
# distributed initialization is attempted at most once per process.
_jax_dist_initialized = False
|
|
|
|
|
|
def _init_jax_distributed() -> None:
    """Best-effort ``jax.distributed.initialize()`` for multi-host TPU pods.

    Only the first call attempts initialization. Every later call is a no-op,
    whether or not that first attempt succeeded. Failures (for example JAX not
    being installed, or no distributed environment being configured on a
    single-host run) are not raised. The reason is logged at DEBUG level so it
    can be inspected when diagnosing multi-host setups.
    """
    global _jax_dist_initialized
    if _jax_dist_initialized:
        return
    # Mark before attempting so a failing environment is not retried on every call.
    _jax_dist_initialized = True
    try:
        import jax as _jax

        _jax.distributed.initialize()
    except Exception as exc:  # deliberate best-effort boundary; never fatal
        import logging

        logging.getLogger(__name__).debug(
            "jax.distributed.initialize() skipped: %r", exc
        )
|
|
|
|
|
|
try:
|
|
import wandb
|
|
|
|
HAS_WANDB = True
|
|
except ImportError:
|
|
HAS_WANDB = False
|
|
|
|
from ..wandb_checkpoint import (
|
|
checkpoint_artifact_name,
|
|
download_latest_checkpoint,
|
|
log_checkpoint_bytes,
|
|
)
|
|
|
|
try:
    import jax
    import jax.numpy as jnp
    import distrax
    import flax.linen as nn
    import optax
    from flax import serialization
    from flax.linen.initializers import constant, orthogonal
    from flax.training.train_state import TrainState
except ImportError:
    # The accelerator stack is optional. Bind placeholders so this module still
    # imports: class bases, decorators and initializer calls below refer to
    # these names at definition time.
    jax = jnp = distrax = optax = serialization = None  # type: ignore[assignment]

    class _ModuleStub:
        """Stand-in base class for ``flax.linen.Module``."""

    class _NNStub:
        """Minimal stand-in for ``flax.linen`` (only ``Module`` and ``compact``)."""

        Module = _ModuleStub

        @staticmethod
        def compact(fn):
            return fn

    nn = _NNStub()  # type: ignore[assignment]

    def constant(*_args, **_kwargs):  # type: ignore[override]
        return None

    def orthogonal(*_args, **_kwargs):  # type: ignore[override]
        return None

    class TrainState:  # type: ignore[override]
        """Placeholder for ``flax.training.train_state.TrainState``."""

    HAS_JAX_STACK = False
else:
    HAS_JAX_STACK = True
|
|
|
|
from .env import PHANTOMJAXEnv, make_env_params
|
|
|
|
|
|
class ActorCritic(nn.Module):
    """Actor-critic MLP with separate policy and value towers over the same input.

    Each tower has two 64-unit hidden layers (``relu`` when ``activation`` is
    ``"relu"``, otherwise ``tanh``). The policy head emits ``action_dim`` logits,
    orthogonally initialized with scale 0.01. The value head emits a single
    scalar, orthogonal scale 1.0.

    Returns:
        ``(distrax.Categorical(logits=...), value)``, where the trailing size-1
        axis of the value head is squeezed away.
    """

    action_dim: int
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        act = nn.tanh if self.activation != "relu" else nn.relu
        hidden_kernel = orthogonal(np.sqrt(2.0))
        zero_bias = constant(0.0)

        # Layers are created in the original order (actor hidden x2, logits,
        # critic hidden x2, value), so Flax's auto-generated parameter names stay
        # compatible with existing checkpoints.
        pi = x
        for _ in range(2):
            pi = act(nn.Dense(64, kernel_init=hidden_kernel, bias_init=zero_bias)(pi))
        logits = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=zero_bias
        )(pi)

        v = x
        for _ in range(2):
            v = act(nn.Dense(64, kernel_init=hidden_kernel, bias_init=zero_bias)(v))
        value = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=zero_bias)(v)

        return distrax.Categorical(logits=logits), jnp.squeeze(value, axis=-1)
|
|
|
|
|
|
class QNetwork(nn.Module):
    """MLP mapping an observation to one Q-value per discrete action.

    Two 128-unit hidden layers (``relu`` when ``activation`` is ``"relu"``,
    otherwise ``tanh``) feed a linear head with ``action_dim`` outputs.
    """

    action_dim: int
    activation: str = "relu"

    @nn.compact
    def __call__(self, x):
        act = nn.tanh if self.activation != "relu" else nn.relu
        hidden = x
        # Layers are created in the original order, so Flax's auto-generated
        # parameter names (and saved checkpoints) are unchanged.
        for _ in range(2):
            layer = nn.Dense(
                128,
                kernel_init=orthogonal(np.sqrt(2.0)),
                bias_init=constant(0.0),
            )
            hidden = act(layer(hidden))
        head = nn.Dense(
            self.action_dim,
            kernel_init=orthogonal(1.0),
            bias_init=constant(0.0),
        )
        return head(hidden)
|
|
|
|
|
|
class Transition(NamedTuple):
    """Rollout record produced by the actor-critic trainer's env step.

    One instance is built per vectorized step, with a leading env axis from
    ``jax.vmap(env.step)``. ``jax.lax.scan`` then stacks them along a leading
    time axis, giving ``(num_steps, num_envs, ...)`` leaves.
    """

    done: jax.Array  # done flag returned by env.step for this step
    action: jax.Array  # action sampled from the policy
    value: jax.Array  # critic value estimate for `obs`
    reward: jax.Array  # reward returned by env.step
    log_prob: jax.Array  # log-probability of `action` under the sampling policy
    obs: jax.Array  # observation the action was chosen from
    # env info dict; the trainer reads "revenue", "agent_prob", "alpha_adv",
    # "coi_leakage" from it for metrics
    info: dict[str, jax.Array]
|
|
|
|
|
|
class ReplayBatch(NamedTuple):
    """Batch of transitions gathered from a ``ReplayBuffer`` by ``_replay_sample``."""

    obs: jax.Array  # observations (float32 in the buffer)
    actions: jax.Array  # actions taken (int32 in the buffer)
    rewards: jax.Array  # rewards received (float32)
    next_obs: jax.Array  # observations after the step (float32)
    dones: jax.Array  # done flags stored as float32; used as a 0/1 mask in TD targets
|
|
|
|
|
|
class ReplayBuffer(NamedTuple):
    """Fixed-capacity circular transition store (built by ``_init_replay_buffer``).

    Each array has ``capacity`` rows. A new transition is written to row
    ``ptr % capacity``, overwriting the oldest row once the buffer is full.
    """

    obs: jax.Array  # (capacity, obs_dim) float32
    actions: jax.Array  # (capacity,) int32
    rewards: jax.Array  # (capacity,) float32
    next_obs: jax.Array  # (capacity, obs_dim) float32
    dones: jax.Array  # (capacity,) float32 done flags
    ptr: jax.Array  # int32 number of writes so far; the write row is ptr % capacity
    size: jax.Array  # int32 number of valid rows, capped at capacity
|
|
|
|
|
|
def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
|
|
out = {
|
|
"algo": str(cfg.get("algo", "ppo")).lower(),
|
|
"seed": int(cfg.get("seed", 42)),
|
|
"learning_rate": float(cfg.get("learning_rate", 3e-4)),
|
|
"gamma": float(cfg.get("gamma", 0.99)),
|
|
"gae_lambda": float(cfg.get("gae_lambda", 0.95)),
|
|
"clip_range": float(cfg.get("clip_range", 0.2)),
|
|
"ent_coef": float(cfg.get("ent_coef", 0.01)),
|
|
"vf_coef": float(cfg.get("vf_coef", 0.5)),
|
|
"max_grad_norm": float(cfg.get("max_grad_norm", 0.5)),
|
|
"activation": str(cfg.get("activation", "relu")),
|
|
"total_timesteps": int(cfg.get("total_timesteps", 50_000)),
|
|
"eval_episodes": int(cfg.get("eval_episodes", 5)),
|
|
"model_dir": str(cfg.get("model_dir", "engine/models")),
|
|
"log_freq": int(cfg.get("log_freq", 100)),
|
|
"n_products": int(cfg.get("n_products", 10)),
|
|
"N": int(cfg.get("N", 100)),
|
|
"alpha": float(cfg.get("alpha", 0.3)),
|
|
"lambda_coi": float(cfg.get("lambda_coi", 0.2)),
|
|
"robust_radius": float(cfg.get("robust_radius", 0.15)),
|
|
"robust_points": int(cfg.get("robust_points", 5)),
|
|
"info_value": float(cfg.get("info_value", 1.0)),
|
|
"price_low": float(cfg.get("price_low", 10.0)),
|
|
"price_high": float(cfg.get("price_high", 150.0)),
|
|
"action_levels": int(cfg.get("action_levels", 9)),
|
|
"action_scale_low": float(cfg.get("action_scale_low", 0.8)),
|
|
"action_scale_high": float(cfg.get("action_scale_high", 1.2)),
|
|
"max_episode_steps": int(cfg.get("max_steps", 100)),
|
|
"max_session_steps": int(cfg.get("max_session_steps", 40)),
|
|
"margin_floor": float(cfg.get("margin_floor", 0.05)),
|
|
"margin_floor_patience": int(cfg.get("margin_floor_patience", 5)),
|
|
"prefer_behavior_data": bool(cfg.get("prefer_behavior_data", True)),
|
|
"num_envs": int(cfg.get("jax_num_envs", 16)),
|
|
"num_steps": int(cfg.get("jax_num_steps", 128)),
|
|
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
|
|
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
|
|
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
|
|
"checkpoint_interval": int(cfg.get("checkpoint_interval", 10_000)),
|
|
"buffer_size": int(cfg.get("buffer_size", 50_000)),
|
|
"batch_size": int(cfg.get("batch_size", 256)),
|
|
"train_freq": int(cfg.get("train_freq", 1)),
|
|
"learning_starts": int(cfg.get("learning_starts", 1_000)),
|
|
"target_update_interval": int(cfg.get("target_update_interval", 1_000)),
|
|
"exploration_fraction": float(cfg.get("exploration_fraction", 0.2)),
|
|
"exploration_final_eps": float(cfg.get("exploration_final_eps", 0.05)),
|
|
"eps_start": float(cfg.get("eps_start", 1.0)),
|
|
"eps_end": float(cfg.get("eps_end", 0.05)),
|
|
"eps_decay": float(cfg.get("eps_decay", 0.9995)),
|
|
"q_lr": float(cfg.get("q_lr", 0.1)),
|
|
"q_bins": int(cfg.get("q_bins", 6)),
|
|
}
|
|
rollout = out["num_envs"] * out["num_steps"]
|
|
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
|
|
out["minibatch_size"] = max(1, rollout // max(out["num_minibatches"], 1))
|
|
return out
|
|
|
|
|
|
def _scalar(value: Any) -> float:
|
|
return float(np.asarray(value))
|
|
|
|
|
|
def _scalar_int(value: Any) -> int:
|
|
return int(np.asarray(value))
|
|
|
|
|
|
def _make_env(cfg: dict[str, Any]) -> PHANTOMJAXEnv:
    """Build a ``PHANTOMJAXEnv`` from a normalized trainer config (see ``_jax_cfg``).

    Every ``make_env_params`` keyword is read from the config key of the same
    name, except ``n_sessions``, which comes from ``cfg["N"]``.
    """
    # (make_env_params keyword, config key), in the original argument order.
    param_sources = (
        ("n_products", "n_products"),
        ("alpha", "alpha"),
        ("n_sessions", "N"),
        ("lambda_coi", "lambda_coi"),
        ("robust_radius", "robust_radius"),
        ("robust_points", "robust_points"),
        ("info_value", "info_value"),
        ("action_levels", "action_levels"),
        ("action_scale_low", "action_scale_low"),
        ("action_scale_high", "action_scale_high"),
        ("price_low", "price_low"),
        ("price_high", "price_high"),
        ("max_episode_steps", "max_episode_steps"),
        ("max_session_steps", "max_session_steps"),
        ("margin_floor", "margin_floor"),
        ("margin_floor_patience", "margin_floor_patience"),
        ("prefer_behavior_data", "prefer_behavior_data"),
    )
    params = make_env_params(**{kw: cfg[key] for kw, key in param_sources})
    return PHANTOMJAXEnv(params)
|
|
|
|
|
|
def _select_env_state(done: jax.Array, keep: jax.Array, reset: jax.Array) -> jax.Array:
|
|
mask = done
|
|
while mask.ndim < keep.ndim:
|
|
mask = mask[..., None]
|
|
return jnp.where(mask, reset, keep)
|
|
|
|
|
|
def _epsilon_by_fraction(step: int, cfg: dict[str, Any]) -> float:
|
|
start = float(cfg["eps_start"])
|
|
end = float(cfg["exploration_final_eps"])
|
|
frac = float(cfg["exploration_fraction"])
|
|
total = max(1, int(cfg["total_timesteps"]))
|
|
decay_steps = max(1, int(total * frac))
|
|
if step >= decay_steps:
|
|
return end
|
|
slope = (end - start) / decay_steps
|
|
return float(start + slope * step)
|
|
|
|
|
|
def _digitize_scalar(value: jax.Array, bins: jax.Array) -> jax.Array:
|
|
return jnp.sum(value > bins).astype(jnp.int32)
|
|
|
|
|
|
def _encode_qtable_state(
    obs: jax.Array,
    *,
    n_products: int,
    demand_bins: jax.Array,
    price_bins: jax.Array,
) -> tuple[jax.Array, jax.Array, jax.Array]:
    """Discretize an observation into a 3-part tabular Q state.

    The first ``n_products`` entries of ``obs`` are treated as demand and the
    next ``n_products`` as prices. The state is ``(mean demand bucket, demand
    std bucket, mean price bucket)``. Both demand statistics use
    ``demand_bins``; the price mean uses ``price_bins``.
    """
    demand, prices = obs[:n_products], obs[n_products : 2 * n_products]
    summary = (
        (jnp.mean(demand), demand_bins),
        (jnp.std(demand), demand_bins),
        (jnp.mean(prices), price_bins),
    )
    d_mean_idx, d_std_idx, p_mean_idx = (
        _digitize_scalar(stat, edges) for stat, edges in summary
    )
    return d_mean_idx, d_std_idx, p_mean_idx
|
|
|
|
|
|
def _init_replay_buffer(capacity: int, obs_dim: int) -> ReplayBuffer:
    """Allocate an empty, zero-filled replay buffer with at least one slot."""
    slots = max(1, int(capacity))
    row_shape = (slots, obs_dim)
    column_shape = (slots,)
    return ReplayBuffer(
        obs=jnp.zeros(row_shape, dtype=jnp.float32),
        actions=jnp.zeros(column_shape, dtype=jnp.int32),
        rewards=jnp.zeros(column_shape, dtype=jnp.float32),
        next_obs=jnp.zeros(row_shape, dtype=jnp.float32),
        dones=jnp.zeros(column_shape, dtype=jnp.float32),
        # Write cursor and fill level both start at zero.
        ptr=jnp.asarray(0, dtype=jnp.int32),
        size=jnp.asarray(0, dtype=jnp.int32),
    )
|
|
|
|
|
|
def _replay_size(buffer: ReplayBuffer) -> int:
    """Number of valid transitions currently stored in ``buffer``, as a Python int."""
    stored = buffer.size
    return _scalar_int(stored)
|
|
|
|
|
|
def _replay_add(
    buffer: ReplayBuffer,
    obs: jax.Array,
    action: jax.Array,
    reward: jax.Array,
    next_obs: jax.Array,
    done: jax.Array,
) -> ReplayBuffer:
    """Return a new buffer with one transition written at row ``ptr % capacity``.

    Values are cast to the buffer's column dtypes (float32 for everything except
    int32 actions). ``ptr`` advances by one and ``size`` grows until it reaches
    capacity. The input buffer is not mutated; JAX ``.at[].set`` is functional.
    """
    capacity = int(buffer.obs.shape[0])
    slot = buffer.ptr % capacity

    def write(column: jax.Array, value: jax.Array, dtype) -> jax.Array:
        return column.at[slot].set(value.astype(dtype))

    return buffer._replace(
        obs=write(buffer.obs, obs, jnp.float32),
        actions=write(buffer.actions, action, jnp.int32),
        rewards=write(buffer.rewards, reward, jnp.float32),
        next_obs=write(buffer.next_obs, next_obs, jnp.float32),
        dones=write(buffer.dones, done, jnp.float32),
        ptr=buffer.ptr + 1,
        size=jnp.minimum(buffer.size + 1, jnp.asarray(capacity, dtype=jnp.int32)),
    )
|
|
|
|
|
|
def _replay_sample(
    buffer: ReplayBuffer, key: jax.Array, batch_size: int
) -> ReplayBatch:
    """Draw ``batch_size`` transitions uniformly, with replacement, from the filled rows.

    An empty buffer is treated as having one valid row, so its zero-initialized
    row 0 is returned.
    """
    upper = jnp.maximum(buffer.size, 1)
    rows = jax.random.randint(key, shape=(batch_size,), minval=0, maxval=upper)
    # ReplayBatch's field names are a subset of ReplayBuffer's, in batch order.
    return ReplayBatch(*(getattr(buffer, field)[rows] for field in ReplayBatch._fields))
|
|
|
|
|
|
def _make_actor_critic_train(
    config: dict[str, Any], *, algo: str, use_pmap: bool = False
):
    """Build the pure-JAX PPO/A2C training functions for the PHANTOM env.

    Args:
        config: Normalized trainer config (as produced by ``_jax_cfg``). It must
            contain the rollout keys ``num_envs``, ``num_steps``,
            ``num_minibatches``, ``minibatch_size``, ``update_epochs`` and
            ``num_updates``, plus the optimizer and loss coefficients read below.
        algo: ``"ppo"`` selects the clipped PPO policy/value objective. Any
            other value uses the unclipped A2C-style losses.
        use_pmap: When True, gradients are averaged over the ``"devices"`` axis
            with ``jax.lax.pmean``. The caller must then run ``run_updates`` under
            ``jax.pmap(..., axis_name="devices")``.

    Returns:
        ``(init_runner_state, run_updates, network, env, cfg)``:

        * ``init_runner_state(rng)`` returns ``(train_state, env_state, obs, rng)``.
        * ``run_updates(runner_state, num_updates)`` returns a dict holding the
          advanced ``"runner_state"`` and ``"metrics"`` stacked per update.

    Raises:
        ValueError: If the rollout batch (``num_envs * num_steps``) cannot be
            split exactly into ``num_minibatches`` minibatches of
            ``minibatch_size`` samples.
    """
    cfg = dict(config)
    cfg["algo"] = algo

    # The minibatch reshape in `_update_epoch` needs an exact split. `_jax_cfg`
    # floor-divides and allows num_minibatches == 0, which would otherwise fail
    # later with an opaque reshape error at trace time.
    rollout_batch = int(cfg["num_envs"]) * int(cfg["num_steps"])
    num_minibatches = int(cfg["num_minibatches"])
    if (
        num_minibatches < 1
        or num_minibatches * int(cfg["minibatch_size"]) != rollout_batch
    ):
        raise ValueError(
            f"num_envs * num_steps ({rollout_batch}) must split exactly into "
            f"num_minibatches ({num_minibatches}) x minibatch_size "
            f"({cfg['minibatch_size']})"
        )

    env = _make_env(cfg)
    network = ActorCritic(env.action_space_n(), activation=cfg["activation"])

    def linear_schedule(count: jax.Array) -> jax.Array:
        # `count` is the optimizer step count, which advances
        # num_minibatches * update_epochs times per update. The learning rate
        # falls linearly towards 0 over num_updates updates.
        updates_done = count // (cfg["num_minibatches"] * cfg["update_epochs"])
        frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
        return cfg["learning_rate"] * frac

    if cfg["anneal_lr"]:
        tx = optax.chain(
            optax.clip_by_global_norm(cfg["max_grad_norm"]),
            optax.adam(learning_rate=linear_schedule, eps=1e-5),
        )
    else:
        tx = optax.chain(
            optax.clip_by_global_norm(cfg["max_grad_norm"]),
            optax.adam(cfg["learning_rate"], eps=1e-5),
        )

    def init_runner_state(rng: jax.Array):
        """Initialize params/optimizer and reset ``num_envs`` environments."""
        rng, init_key = jax.random.split(rng)
        init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
        params = network.init(init_key, init_obs)
        train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)

        rng, reset_key = jax.random.split(rng)
        reset_keys = jax.random.split(reset_key, cfg["num_envs"])
        obs, env_state = jax.vmap(env.reset)(reset_keys)
        return train_state, env_state, obs, rng

    def _update_step(runner_state, _):
        """Collect one rollout, compute GAE targets, and run the update epochs."""

        def _env_step(runner_state, _):
            train_state, env_state, last_obs, rng = runner_state
            rng, action_key = jax.random.split(rng)
            policy, value = network.apply(train_state.params, last_obs)
            action = policy.sample(seed=action_key)
            log_prob = policy.log_prob(action)

            rng, step_key = jax.random.split(rng)
            step_keys = jax.random.split(step_key, cfg["num_envs"])
            nxt_obs, nxt_state, reward, done, info = jax.vmap(
                env.step,
                in_axes=(0, 0, 0),
            )(step_keys, env_state, action)

            # Auto-reset: envs that reported done continue from a fresh reset.
            rng, reset_key = jax.random.split(rng)
            reset_keys = jax.random.split(reset_key, cfg["num_envs"])
            rst_obs, rst_state = jax.vmap(env.reset)(reset_keys)
            obs_next = jnp.where(done[:, None], rst_obs, nxt_obs)
            env_next = jax.tree_util.tree_map(
                lambda keep, reset: _select_env_state(done, keep, reset),
                nxt_state,
                rst_state,
            )
            transition = Transition(
                done=done,
                action=action,
                value=value,
                reward=reward,
                log_prob=log_prob,
                obs=last_obs,
                info=info,
            )
            return (train_state, env_next, obs_next, rng), transition

        runner_state, traj_batch = jax.lax.scan(
            _env_step,
            runner_state,
            None,
            length=cfg["num_steps"],
        )

        train_state, env_state, last_obs, rng = runner_state
        _, last_value = network.apply(train_state.params, last_obs)

        def _compute_gae(traj_batch, last_value):
            def _gae_step(carry, transition):
                gae, next_value = carry
                # `done` zeroes both the bootstrap value and the carried GAE,
                # because the next observation came from a reset.
                delta = (
                    transition.reward
                    + cfg["gamma"] * next_value * (1.0 - transition.done)
                    - transition.value
                )
                gae = (
                    delta
                    + cfg["gamma"] * cfg["gae_lambda"] * (1.0 - transition.done) * gae
                )
                return (gae, transition.value), gae

            _, advantages = jax.lax.scan(
                _gae_step,
                (jnp.zeros_like(last_value), last_value),
                traj_batch,
                reverse=True,
                unroll=16,
            )
            targets = advantages + traj_batch.value
            return advantages, targets

        advantages, targets = _compute_gae(traj_batch, last_value)

        def _update_epoch(update_state, _):
            def _update_minibatch(train_state, batch_info):
                traj_b, adv_b, tgt_b = batch_info

                def _loss_fn(params, traj_b, adv_b, tgt_b):
                    policy, value = network.apply(params, traj_b.obs)
                    log_prob = policy.log_prob(traj_b.action)
                    adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)

                    if algo == "ppo":
                        value_clipped = traj_b.value + (value - traj_b.value).clip(
                            -cfg["clip_range"], cfg["clip_range"]
                        )
                        value_loss = (
                            0.5
                            * jnp.maximum(
                                jnp.square(value - tgt_b),
                                jnp.square(value_clipped - tgt_b),
                            ).mean()
                        )
                        ratio = jnp.exp(log_prob - traj_b.log_prob)
                        policy_loss = -jnp.minimum(
                            ratio * adv_norm,
                            jnp.clip(
                                ratio,
                                1.0 - cfg["clip_range"],
                                1.0 + cfg["clip_range"],
                            )
                            * adv_norm,
                        ).mean()
                    else:
                        value_loss = 0.5 * jnp.mean(jnp.square(value - tgt_b))
                        policy_loss = -(log_prob * adv_norm).mean()

                    entropy = policy.entropy().mean()
                    total_loss = (
                        policy_loss
                        + cfg["vf_coef"] * value_loss
                        - cfg["ent_coef"] * entropy
                    )
                    return total_loss, (value_loss, policy_loss, entropy)

                grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                (_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b)
                if use_pmap:
                    grads = jax.lax.pmean(grads, axis_name="devices")
                train_state = train_state.apply_gradients(grads=grads)
                # Placeholder per-minibatch output; losses are not reported.
                return train_state, jnp.asarray(0.0, dtype=jnp.float32)

            train_state, traj_batch, advantages, targets, rng = update_state
            rng, perm_key = jax.random.split(rng)
            batch_size = cfg["num_envs"] * cfg["num_steps"]
            permutation = jax.random.permutation(perm_key, batch_size)

            # Flatten (num_steps, num_envs, ...) -> (batch_size, ...), shuffle,
            # then split into (num_minibatches, minibatch_size, ...).
            batch = (traj_batch, advantages, targets)
            batch = jax.tree_util.tree_map(
                lambda x: x.reshape((batch_size,) + x.shape[2:]),
                batch,
            )
            shuffled = jax.tree_util.tree_map(
                lambda x: jnp.take(x, permutation, axis=0),
                batch,
            )
            minibatches = jax.tree_util.tree_map(
                lambda x: x.reshape(
                    (cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:]
                ),
                shuffled,
            )
            train_state, _ = jax.lax.scan(_update_minibatch, train_state, minibatches)
            return (train_state, traj_batch, advantages, targets, rng), None

        update_state = (train_state, traj_batch, advantages, targets, rng)
        update_state, _ = jax.lax.scan(
            _update_epoch,
            update_state,
            None,
            length=cfg["update_epochs"],
        )
        train_state = update_state[0]
        rng = update_state[-1]

        metric = {
            "reward": jnp.mean(traj_batch.reward),
            "revenue": jnp.mean(traj_batch.info["revenue"]),
            "agent_prob": jnp.mean(traj_batch.info["agent_prob"]),
            "alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
            "coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
        }
        next_runner_state = (train_state, env_state, last_obs, rng)
        return next_runner_state, metric

    def run_updates(runner_state, num_updates: int):
        """Run ``max(1, num_updates)`` update steps via ``lax.scan``."""
        updates = max(1, int(num_updates))
        runner_state, metric = jax.lax.scan(
            _update_step,
            runner_state,
            None,
            length=updates,
        )
        return {
            "runner_state": runner_state,
            "metrics": metric,
        }

    return init_runner_state, run_updates, network, env, cfg
|
|
|
|
|
|
def make_train(config: dict[str, Any]):
    """Normalize ``config`` and build the actor-critic training functions.

    Returns the ``(init_runner_state, run_updates, network, env, cfg)`` tuple
    produced by ``_make_actor_critic_train``.

    Raises:
        ValueError: If the normalized ``algo`` (lower-cased, default ``"ppo"``)
            is neither ``"ppo"`` nor ``"a2c"``.
    """
    cfg = _jax_cfg(config)
    algo = cfg["algo"]
    if algo in {"ppo", "a2c"}:
        return _make_actor_critic_train(cfg, algo=algo)
    raise ValueError(f"make_train supports actor-critic algos only, got '{algo}'")
|
|
|
|
|
|
def evaluate_policy(
    *,
    network: ActorCritic,
    params: Any,
    env: PHANTOMJAXEnv,
    episodes: int,
    seed: int,
) -> dict[str, float]:
    """Evaluate the greedy actor-critic policy on ``env``.

    Each episode resets ``env`` and repeatedly takes the argmax over the policy
    logits until the env reports done or ``env.params.max_episode_steps`` steps
    have run. Per-episode reward and revenue totals are summarized by their mean
    and population standard deviation. With ``episodes <= 0`` the statistics
    are NaN, because ``np.mean`` of an empty list is NaN.
    """
    returns: list[float] = []
    revenue_totals: list[float] = []
    key = jax.random.PRNGKey(seed)

    for _episode in range(int(episodes)):
        key, reset_key = jax.random.split(key)
        obs, state = env.reset(reset_key)
        episode_return = 0.0
        episode_revenue = 0.0

        for _t in range(int(env.params.max_episode_steps)):
            policy, _value = network.apply(params, obs)
            greedy_action = jnp.argmax(policy.logits)
            key, step_key = jax.random.split(key)
            obs, state, reward, done_flag, info = env.step(
                step_key, state, greedy_action
            )
            episode_return += _scalar(reward)
            episode_revenue += _scalar(info["revenue"])
            if bool(np.asarray(done_flag)):
                break

        returns.append(episode_return)
        revenue_totals.append(episode_revenue)

    return {
        "eval/reward": float(np.mean(returns)),
        "eval/revenue": float(np.mean(revenue_totals)),
        "eval/reward_std": float(np.std(returns)),
        "eval/revenue_std": float(np.std(revenue_totals)),
    }
|
|
|
|
|
|
def _evaluate_q_network(
    *,
    network: QNetwork,
    params: Any,
    env: PHANTOMJAXEnv,
    episodes: int,
    seed: int,
) -> dict[str, float]:
    """Evaluate the greedy (argmax-Q) policy of a Q-network on ``env``.

    Each episode resets ``env`` and acts greedily until the env reports done or
    ``env.params.max_episode_steps`` steps have run. Returns the mean and
    population standard deviation of per-episode reward and revenue totals.
    """
    returns: list[float] = []
    revenue_totals: list[float] = []
    key = jax.random.PRNGKey(seed)

    for _episode in range(int(episodes)):
        key, reset_key = jax.random.split(key)
        obs, state = env.reset(reset_key)
        episode_return = 0.0
        episode_revenue = 0.0

        for _t in range(int(env.params.max_episode_steps)):
            greedy_action = jnp.argmax(network.apply(params, obs))
            key, step_key = jax.random.split(key)
            obs, state, reward, done_flag, info = env.step(
                step_key, state, greedy_action
            )
            episode_return += _scalar(reward)
            episode_revenue += _scalar(info["revenue"])
            if bool(np.asarray(done_flag)):
                break

        returns.append(episode_return)
        revenue_totals.append(episode_revenue)

    return {
        "eval/reward": float(np.mean(returns)),
        "eval/revenue": float(np.mean(revenue_totals)),
        "eval/reward_std": float(np.std(returns)),
        "eval/revenue_std": float(np.std(revenue_totals)),
    }
|
|
|
|
|
|
def _evaluate_q_table(
    *,
    q_table: jax.Array,
    env: PHANTOMJAXEnv,
    episodes: int,
    seed: int,
    n_products: int,
    demand_bins: jax.Array,
    price_bins: jax.Array,
) -> dict[str, float]:
    """Evaluate the greedy policy of a tabular Q function on ``env``.

    Each observation is discretized with ``_encode_qtable_state`` into a
    3-part index into ``q_table``, and the argmax action of that row is taken.
    Episodes run until done or ``env.params.max_episode_steps`` steps. Returns
    the mean and population standard deviation of per-episode reward and
    revenue totals.
    """
    returns: list[float] = []
    revenue_totals: list[float] = []
    key = jax.random.PRNGKey(seed)

    for _episode in range(int(episodes)):
        key, reset_key = jax.random.split(key)
        obs, state = env.reset(reset_key)
        episode_return = 0.0
        episode_revenue = 0.0

        for _t in range(int(env.params.max_episode_steps)):
            table_index = _encode_qtable_state(
                obs,
                n_products=n_products,
                demand_bins=demand_bins,
                price_bins=price_bins,
            )
            greedy_action = jnp.argmax(q_table[table_index])
            key, step_key = jax.random.split(key)
            obs, state, reward, done_flag, info = env.step(
                step_key, state, greedy_action
            )
            episode_return += _scalar(reward)
            episode_revenue += _scalar(info["revenue"])
            if bool(np.asarray(done_flag)):
                break

        returns.append(episode_return)
        revenue_totals.append(episode_revenue)

    return {
        "eval/reward": float(np.mean(returns)),
        "eval/revenue": float(np.mean(revenue_totals)),
        "eval/reward_std": float(np.std(returns)),
        "eval/revenue_std": float(np.std(revenue_totals)),
    }
|
|
|
|
|
|
def _first_replica(tree: Any) -> Any:
    """Take device 0's slice of every leaf of a pmap-replicated pytree."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


def _replicate_runner_state(single_runner_state: Any, devices: list[Any]) -> Any:
    """Lay out a single-device ``(train_state, env_state, obs, rng)`` for ``jax.pmap``.

    Train state, env state and observations are replicated unchanged, so every
    device starts from identical parameters. Each device gets its own PRNG key
    split from the stored one. If the key were replicated too, every device
    would sample the same actions and env resets. The cross-device ``pmean``
    would then average identical gradients, and the extra devices would add no
    new experience.
    """
    train_state, env_state, obs, rng = single_runner_state
    replicated = jax.device_put_replicated((train_state, env_state, obs), devices)
    # Shape (n_devices, ...): pmap shards the leading axis, one key per device.
    device_keys = jax.random.split(rng, len(devices))
    return (*replicated, device_keys)


def _train_actor_critic(
    cfg: dict[str, Any],
    *,
    algo: str,
) -> tuple[dict[str, Any], dict[str, float]]:
    """Train a PPO/A2C agent in checkpointed segments, then evaluate and save it.

    With more than one local device, updates run under ``jax.pmap`` and
    gradients are averaged across devices. Each device collects its own
    rollouts from an independent PRNG stream. On the primary process with an
    active W&B run, training resumes from the latest checkpoint artifact. Each
    segment (about ``checkpoint_interval`` env steps) is logged and
    checkpointed. The loop exits early once ``_stop_requested`` is set.

    Returns:
        ``({"params": params}, metrics)``. ``metrics`` holds training averages,
        greedy-evaluation statistics and, on the primary process, the saved
        model path.
    """
    num_devices = jax.local_device_count()
    use_pmap = num_devices > 1

    init_runner_state, run_updates_raw, network, env, run_cfg = (
        _make_actor_critic_train(cfg, algo=algo, use_pmap=use_pmap)
    )

    if use_pmap:
        run_fn = jax.pmap(
            run_updates_raw,
            axis_name="devices",
            static_broadcasted_argnums=(1,),
            devices=jax.local_devices(),
        )
    else:
        run_fn = jax.jit(run_updates_raw, static_argnames=("num_updates",))

    rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"])
    total_updates = int(run_cfg["num_updates"])
    checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000)))
    segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1))
    checkpoint_file = f"jax_{algo}_runner_state.msgpack"

    rng = jax.random.PRNGKey(run_cfg["seed"])
    # Single-device state: template for (de)serialization and the source for replication.
    single_runner_state = init_runner_state(rng)
    updates_done = 0

    is_primary = jax.process_index() == 0
    artifact_name = None
    if is_primary and HAS_WANDB and wandb.run is not None:
        sweep_id = getattr(wandb.run, "sweep_id", None)
        artifact_name = checkpoint_artifact_name(
            run_cfg,
            backend="jax",
            sweep_id=sweep_id,
        )
        restored = download_latest_checkpoint(
            artifact_name,
            file_name=checkpoint_file,
        )
        if restored is not None:
            checkpoint_path, metadata = restored
            template = {"runner_state": single_runner_state, "updates_done": 0}
            payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
            single_runner_state = payload["runner_state"]
            updates_done = int(payload.get("updates_done", 0))
            if updates_done <= 0:
                updates_done = int(metadata.get("updates_done", 0))
            updates_done = max(0, min(updates_done, total_updates))

    if use_pmap:
        runner_state = _replicate_runner_state(single_runner_state, jax.local_devices())
    else:
        runner_state = single_runner_state

    metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"]
    metric_sums = {k: 0.0 for k in metric_keys}
    metric_count = 0

    while updates_done < total_updates:
        updates_this_segment = min(segment_updates, total_updates - updates_done)
        out = run_fn(runner_state, updates_this_segment)
        runner_state = out["runner_state"]

        segment_values = {
            key: np.asarray(out["metrics"][key], dtype=np.float64)
            for key in metric_keys
        }
        if use_pmap:
            # pmap stacks metrics as (n_devices, updates). Average the replicas so
            # each entry summarizes one update across all devices.
            segment_values = {
                key: values.mean(axis=0) for key, values in segment_values.items()
            }

        segment_count = int(segment_values["reward"].shape[0])
        metric_count += segment_count
        for key in metric_keys:
            metric_sums[key] += float(segment_values[key].sum())

        updates_done += int(updates_this_segment)
        # Env steps per replica (under pmap each device collects this many).
        global_step = int(updates_done * rollout_steps)

        if is_primary and HAS_WANDB and wandb.run is not None:
            wandb.log(
                {
                    "train/reward": float(segment_values["reward"].mean()),
                    "train/revenue": float(segment_values["revenue"].mean()),
                    "train/agent_prob": float(segment_values["agent_prob"].mean()),
                    "train/alpha_adv": float(segment_values["alpha_adv"].mean()),
                    "train/coi_leakage": float(segment_values["coi_leakage"].mean()),
                    "train/global_step": global_step,
                },
                step=global_step,
            )
            if artifact_name is not None:
                # Save device 0's state so the checkpoint restores on any device count.
                state_to_save = (
                    _first_replica(runner_state) if use_pmap else runner_state
                )
                checkpoint_payload = serialization.to_bytes(
                    {"runner_state": state_to_save, "updates_done": updates_done}
                )
                log_checkpoint_bytes(
                    artifact_name,
                    file_name=checkpoint_file,
                    payload=checkpoint_payload,
                    metadata={
                        "step": global_step,
                        "updates_done": updates_done,
                        "rollout_steps": rollout_steps,
                        "algo": algo,
                    },
                )

        if _stop_requested.is_set():
            break

    # Parameters are identical across replicas (pmean'd gradients); use device 0's.
    final_runner = _first_replica(runner_state) if use_pmap else runner_state
    train_state = final_runner[0]
    denom = float(metric_count) if metric_count > 0 else 1.0
    metrics = {
        "train/reward": float(metric_sums["reward"] / denom),
        "train/revenue": float(metric_sums["revenue"] / denom),
        "train/agent_prob": float(metric_sums["agent_prob"] / denom),
        "train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
        "train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
        "train/global_step": int(updates_done * rollout_steps),
    }

    eval_metrics = evaluate_policy(
        network=network,
        params=train_state.params,
        env=env,
        episodes=run_cfg["eval_episodes"],
        seed=run_cfg["seed"] + 7,
    )
    metrics.update(eval_metrics)

    if is_primary:
        model_dir = Path(run_cfg["model_dir"])
        model_dir.mkdir(parents=True, exist_ok=True)
        model_path = model_dir / f"phantom_{algo}_jax.msgpack"
        model_path.write_bytes(serialization.to_bytes(train_state.params))
        metrics["model/path"] = str(model_path)

    return {"params": train_state.params}, metrics
|
|
|
|
|
|
def _train_dqn(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
    """Train a DQN agent on the PHANTOM environment, one env step at a time.

    The environment loop runs in Python; the gradient update and greedy
    action selection are jitted. When a wandb run is active, training resumes
    from the latest ``jax_dqn_state.msgpack`` checkpoint artifact (if any) and
    uploads a new checkpoint every ``checkpoint_interval`` steps. When
    ``_stop_requested`` is set (``train_jax`` wires it to SIGTERM), the loop
    uploads one last checkpoint (if checkpointing is active) and exits early.

    Args:
        cfg: Resolved run config. Keys read directly here include
            ``activation``, ``seed``, ``learning_rate``, ``buffer_size``,
            ``eps_start``, ``gamma``, ``total_timesteps``,
            ``checkpoint_interval``, ``batch_size``, ``learning_starts``,
            ``train_freq``, ``target_update_interval``, ``log_freq``,
            ``eval_episodes`` and ``model_dir``.

    Returns:
        ``({"params": ..., "target_params": ...}, metrics)``. ``metrics``
        holds training averages accumulated since this call started (i.e.
        since the resume point), the last completed step, the metrics from
        ``_evaluate_q_network``, and ``model/path`` of the saved online
        network parameters.
    """
    run_cfg = dict(cfg)
    env = _make_env(run_cfg)
    action_dim = env.action_space_n()
    obs_dim = env.observation_dim()

    q_net = QNetwork(action_dim=action_dim, activation=run_cfg["activation"])

    init_obs = jnp.zeros((obs_dim,), dtype=jnp.float32)
    rng = jax.random.PRNGKey(run_cfg["seed"])
    rng, init_key = jax.random.split(rng)
    params = q_net.init(init_key, init_obs)
    tx = optax.adam(run_cfg["learning_rate"])
    train_state = TrainState.create(apply_fn=q_net.apply, params=params, tx=tx)
    target_params = train_state.params

    # NOTE: the replay buffer is not part of the checkpoint, so a resumed run
    # starts with an empty buffer and cannot train until it holds at least
    # ``batch_size`` transitions.
    buffer = _init_replay_buffer(run_cfg["buffer_size"], obs_dim)

    rng, reset_key = jax.random.split(rng)
    obs, env_state = env.reset(reset_key)

    start_step = 0
    epsilon_value = float(run_cfg["eps_start"])
    artifact_name = None

    if HAS_WANDB and wandb.run is not None:
        sweep_id = getattr(wandb.run, "sweep_id", None)
        artifact_name = checkpoint_artifact_name(
            run_cfg,
            backend="jax",
            sweep_id=sweep_id,
        )
        restored = download_latest_checkpoint(
            artifact_name,
            file_name="jax_dqn_state.msgpack",
        )
        if restored is not None:
            checkpoint_path, metadata = restored
            # ``from_bytes`` needs a template with the same pytree structure
            # as the payload written by the checkpoint branch in the loop.
            template = {
                "params": train_state.params,
                "target_params": target_params,
                "opt_state": train_state.opt_state,
                "global_step": 0,
                "epsilon": epsilon_value,
            }
            payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
            train_state = train_state.replace(
                params=payload["params"],
                opt_state=payload["opt_state"],
            )
            target_params = payload["target_params"]
            start_step = int(payload.get("global_step", metadata.get("step", 0)))
            start_step = max(0, min(start_step, int(run_cfg["total_timesteps"])))
            # Superseded on the first iteration by ``_epsilon_by_fraction``;
            # kept so the restored value is available if the loop never runs.
            epsilon_value = float(payload.get("epsilon", epsilon_value))

    @jax.jit
    def dqn_update(
        state: TrainState,
        target: Any,
        batch: ReplayBatch,
    ) -> tuple[TrainState, jax.Array]:
        """Apply one MSE TD-error gradient step against the target network."""

        def loss_fn(model_params):
            q_values = q_net.apply(model_params, batch.obs)
            chosen = jnp.take_along_axis(
                q_values,
                batch.actions[:, None],
                axis=1,
            ).squeeze(-1)
            next_q = q_net.apply(target, batch.next_obs)
            next_max = jnp.max(next_q, axis=1)
            # Terminal transitions (dones == 1) do not bootstrap.
            td_target = (
                batch.rewards + run_cfg["gamma"] * (1.0 - batch.dones) * next_max
            )
            td_error = chosen - jax.lax.stop_gradient(td_target)
            return jnp.mean(jnp.square(td_error))

        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        next_state = state.apply_gradients(grads=grads)
        return next_state, loss

    @jax.jit
    def greedy_action(model_params: Any, observation: jax.Array) -> jax.Array:
        """Return the arg-max action of the online network for one observation."""
        return jnp.argmax(q_net.apply(model_params, observation))

    metric_sums = {
        "reward": 0.0,
        "revenue": 0.0,
        "agent_prob": 0.0,
        "alpha_adv": 0.0,
        "coi_leakage": 0.0,
        "loss": 0.0,
    }
    metric_count = 0
    loss_count = 0

    total_steps = int(run_cfg["total_timesteps"])
    checkpoint_interval = max(1, int(run_cfg["checkpoint_interval"]))
    batch_size = max(1, int(run_cfg["batch_size"]))
    # Last fully completed step; reported instead of ``total_steps`` so an
    # early stop is visible in the returned metrics.
    last_step = start_step

    for global_step in range(start_step + 1, total_steps + 1):
        epsilon_value = _epsilon_by_fraction(global_step - 1, run_cfg)

        rng, eps_key, action_key, step_key, reset_key, sample_key = jax.random.split(
            rng, 6
        )
        do_explore = bool(np.asarray(jax.random.uniform(eps_key) < epsilon_value))
        if do_explore:
            action = jax.random.randint(
                action_key, shape=(), minval=0, maxval=action_dim
            )
        else:
            action = greedy_action(train_state.params, obs)

        next_obs, next_state, reward, done, info = env.step(step_key, env_state, action)
        buffer = _replay_add(
            buffer,
            obs,
            action,
            reward,
            next_obs,
            done.astype(jnp.float32),
        )

        metric_count += 1
        metric_sums["reward"] += _scalar(reward)
        metric_sums["revenue"] += _scalar(info["revenue"])
        metric_sums["agent_prob"] += _scalar(info["agent_prob"])
        metric_sums["alpha_adv"] += _scalar(info["alpha_adv"])
        metric_sums["coi_leakage"] += _scalar(info["coi_leakage"])

        if bool(np.asarray(done)):
            obs, env_state = env.reset(reset_key)
        else:
            obs, env_state = next_obs, next_state

        ready = (
            global_step >= int(run_cfg["learning_starts"])
            and global_step % int(run_cfg["train_freq"]) == 0
            and _replay_size(buffer) >= batch_size
        )
        if ready:
            batch = _replay_sample(buffer, sample_key, batch_size)
            train_state, loss = dqn_update(train_state, target_params, batch)
            metric_sums["loss"] += _scalar(loss)
            loss_count += 1

        if global_step % int(run_cfg["target_update_interval"]) == 0:
            target_params = train_state.params

        last_step = global_step

        if (
            HAS_WANDB
            and wandb.run is not None
            and global_step % int(run_cfg["log_freq"]) == 0
        ):
            # Running averages since the start of this call, not a window.
            wandb.log(
                {
                    "train/reward": metric_sums["reward"] / max(metric_count, 1),
                    "train/revenue": metric_sums["revenue"] / max(metric_count, 1),
                    "train/agent_prob": metric_sums["agent_prob"]
                    / max(metric_count, 1),
                    "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1),
                    "train/coi_leakage": metric_sums["coi_leakage"]
                    / max(metric_count, 1),
                    "train/dqn_loss": metric_sums["loss"] / max(loss_count, 1),
                    "train/epsilon": epsilon_value,
                    "train/global_step": global_step,
                },
                step=global_step,
            )

        # Read the flag once so the checkpoint decision and the break agree.
        stop_now = _stop_requested.is_set()
        if artifact_name is not None and (
            global_step % checkpoint_interval == 0 or stop_now
        ):
            # Also checkpoint on stop so progress since the last periodic
            # checkpoint survives preemption.
            payload = serialization.to_bytes(
                {
                    "params": train_state.params,
                    "target_params": target_params,
                    "opt_state": train_state.opt_state,
                    "global_step": global_step,
                    "epsilon": epsilon_value,
                }
            )
            log_checkpoint_bytes(
                artifact_name,
                file_name="jax_dqn_state.msgpack",
                payload=payload,
                metadata={
                    "step": global_step,
                    "algo": "dqn",
                },
            )
        if stop_now:
            break

    denom = float(metric_count) if metric_count > 0 else 1.0
    metrics = {
        "train/reward": float(metric_sums["reward"] / denom),
        "train/revenue": float(metric_sums["revenue"] / denom),
        "train/agent_prob": float(metric_sums["agent_prob"] / denom),
        "train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
        "train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
        "train/dqn_loss": float(metric_sums["loss"] / max(loss_count, 1)),
        "train/global_step": last_step,
    }

    eval_metrics = _evaluate_q_network(
        network=q_net,
        params=train_state.params,
        env=env,
        episodes=run_cfg["eval_episodes"],
        seed=run_cfg["seed"] + 7,
    )
    metrics.update(eval_metrics)

    model_dir = Path(run_cfg["model_dir"])
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "phantom_dqn_jax.msgpack"
    model_path.write_bytes(serialization.to_bytes(train_state.params))
    metrics["model/path"] = str(model_path)
    return {
        "params": train_state.params,
        "target_params": target_params,
    }, metrics
|
|
|
|
|
|
def _train_qtable(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
    """Train a tabular Q-learning agent on a discretised PHANTOM observation.

    Observations are mapped to a 3-index state by ``_encode_qtable_state``
    using ``n_bins - 1`` interior cut points for demand (over ``[0, 100]``)
    and for price (over ``[price_low, price_high]``). The table has shape
    ``(n_bins, n_bins, n_bins, action_dim)``. Epsilon decays multiplicatively
    by ``eps_decay`` per step, floored at ``eps_end``.

    When a wandb run is active, training resumes from the latest
    ``jax_qtable_state.msgpack`` checkpoint (if any) and uploads a checkpoint
    every ``checkpoint_interval`` steps. When ``_stop_requested`` is set
    (``train_jax`` wires it to SIGTERM), the loop uploads one last checkpoint
    (if checkpointing is active) and exits early.

    Args:
        cfg: Resolved run config. Keys read here include ``q_bins``,
            ``n_products``, ``price_low``, ``price_high``, ``seed``,
            ``eps_start``, ``eps_end``, ``eps_decay``, ``gamma``, ``q_lr``,
            ``total_timesteps``, ``checkpoint_interval``, ``log_freq``,
            ``eval_episodes`` and ``model_dir``.

    Returns:
        ``({"q_table": q_table}, metrics)``. ``metrics`` holds training
        averages accumulated since this call started, the last completed
        step, the metrics from ``_evaluate_q_table``, and ``model/path`` of
        the saved table.
    """
    run_cfg = dict(cfg)
    env = _make_env(run_cfg)
    action_dim = env.action_space_n()
    n_bins = max(2, int(run_cfg["q_bins"]))
    n_products = int(run_cfg["n_products"])

    q_table = jnp.zeros((n_bins, n_bins, n_bins, action_dim), dtype=jnp.float32)
    # Interior cut points only (endpoints dropped) -> n_bins buckets per axis.
    demand_bins = jnp.linspace(0.0, 100.0, n_bins + 1, dtype=jnp.float32)[1:-1]
    price_bins = jnp.linspace(
        float(run_cfg["price_low"]),
        float(run_cfg["price_high"]),
        n_bins + 1,
        dtype=jnp.float32,
    )[1:-1]

    rng = jax.random.PRNGKey(run_cfg["seed"])
    rng, reset_key = jax.random.split(rng)
    obs, env_state = env.reset(reset_key)

    epsilon_value = float(run_cfg["eps_start"])
    start_step = 0
    artifact_name = None

    if HAS_WANDB and wandb.run is not None:
        sweep_id = getattr(wandb.run, "sweep_id", None)
        artifact_name = checkpoint_artifact_name(
            run_cfg,
            backend="jax",
            sweep_id=sweep_id,
        )
        restored = download_latest_checkpoint(
            artifact_name,
            file_name="jax_qtable_state.msgpack",
        )
        if restored is not None:
            checkpoint_path, metadata = restored
            # Template must mirror the payload written by the checkpoint branch.
            template = {
                "q_table": q_table,
                "global_step": 0,
                "epsilon": epsilon_value,
            }
            payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
            q_table = payload["q_table"]
            start_step = int(payload.get("global_step", metadata.get("step", 0)))
            start_step = max(0, min(start_step, int(run_cfg["total_timesteps"])))
            epsilon_value = float(payload.get("epsilon", epsilon_value))

    metric_sums = {
        "reward": 0.0,
        "revenue": 0.0,
        "agent_prob": 0.0,
        "alpha_adv": 0.0,
        "coi_leakage": 0.0,
    }
    metric_count = 0

    total_steps = int(run_cfg["total_timesteps"])
    checkpoint_interval = max(1, int(run_cfg["checkpoint_interval"]))
    # Last fully completed step; reported instead of ``total_steps`` so an
    # early stop is visible in the returned metrics.
    last_step = start_step

    for global_step in range(start_step + 1, total_steps + 1):
        s0, s1, s2 = _encode_qtable_state(
            obs,
            n_products=n_products,
            demand_bins=demand_bins,
            price_bins=price_bins,
        )
        state_q = q_table[s0, s1, s2]

        rng, eps_key, action_key, step_key, reset_key = jax.random.split(rng, 5)
        do_explore = bool(np.asarray(jax.random.uniform(eps_key) < epsilon_value))
        if do_explore:
            action = jax.random.randint(
                action_key, shape=(), minval=0, maxval=action_dim
            )
        else:
            action = jnp.argmax(state_q)

        next_obs, next_state, reward, done, info = env.step(step_key, env_state, action)
        ns0, ns1, ns2 = _encode_qtable_state(
            next_obs,
            n_products=n_products,
            demand_bins=demand_bins,
            price_bins=price_bins,
        )

        # One-step Q-learning update; terminal transitions do not bootstrap.
        best_next = jnp.max(q_table[ns0, ns1, ns2])
        done_f = done.astype(jnp.float32)
        td_target = reward + run_cfg["gamma"] * (1.0 - done_f) * best_next
        old_value = q_table[s0, s1, s2, action]
        new_value = old_value + run_cfg["q_lr"] * (td_target - old_value)
        q_table = q_table.at[s0, s1, s2, action].set(new_value)

        epsilon_value = max(
            float(run_cfg["eps_end"]),
            epsilon_value * float(run_cfg["eps_decay"]),
        )

        metric_count += 1
        metric_sums["reward"] += _scalar(reward)
        metric_sums["revenue"] += _scalar(info["revenue"])
        metric_sums["agent_prob"] += _scalar(info["agent_prob"])
        metric_sums["alpha_adv"] += _scalar(info["alpha_adv"])
        metric_sums["coi_leakage"] += _scalar(info["coi_leakage"])

        if bool(np.asarray(done)):
            obs, env_state = env.reset(reset_key)
        else:
            obs, env_state = next_obs, next_state

        last_step = global_step

        if (
            HAS_WANDB
            and wandb.run is not None
            and global_step % int(run_cfg["log_freq"]) == 0
        ):
            # Running averages since the start of this call, not a window.
            wandb.log(
                {
                    "train/reward": metric_sums["reward"] / max(metric_count, 1),
                    "train/revenue": metric_sums["revenue"] / max(metric_count, 1),
                    "train/agent_prob": metric_sums["agent_prob"]
                    / max(metric_count, 1),
                    "train/alpha_adv": metric_sums["alpha_adv"] / max(metric_count, 1),
                    "train/coi_leakage": metric_sums["coi_leakage"]
                    / max(metric_count, 1),
                    "train/epsilon": epsilon_value,
                    "train/global_step": global_step,
                },
                step=global_step,
            )

        # Honour stop requests (SIGTERM via train_jax); without this check the
        # installed handler would make the process ignore SIGTERM entirely.
        stop_now = _stop_requested.is_set()
        if artifact_name is not None and (
            global_step % checkpoint_interval == 0 or stop_now
        ):
            payload = serialization.to_bytes(
                {
                    "q_table": q_table,
                    "global_step": global_step,
                    "epsilon": epsilon_value,
                }
            )
            log_checkpoint_bytes(
                artifact_name,
                file_name="jax_qtable_state.msgpack",
                payload=payload,
                metadata={
                    "step": global_step,
                    "algo": "qtable",
                },
            )
        if stop_now:
            break

    denom = float(metric_count) if metric_count > 0 else 1.0
    metrics = {
        "train/reward": float(metric_sums["reward"] / denom),
        "train/revenue": float(metric_sums["revenue"] / denom),
        "train/agent_prob": float(metric_sums["agent_prob"] / denom),
        "train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
        "train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
        "train/global_step": last_step,
    }

    eval_metrics = _evaluate_q_table(
        q_table=q_table,
        env=env,
        episodes=run_cfg["eval_episodes"],
        seed=run_cfg["seed"] + 7,
        n_products=n_products,
        demand_bins=demand_bins,
        price_bins=price_bins,
    )
    metrics.update(eval_metrics)

    model_dir = Path(run_cfg["model_dir"])
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "phantom_qtable_jax.msgpack"
    model_path.write_bytes(serialization.to_bytes(q_table))
    metrics["model/path"] = str(model_path)
    return {"q_table": q_table}, metrics
|
|
|
|
|
|
def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
    """Dispatch a PHANTOM training run to the JAX trainer selected by ``algo``.

    Initialises JAX distributed (best effort), resets the cooperative stop
    flag and normalises ``cfg`` via ``_jax_cfg``. When called from the main
    thread, SIGTERM is temporarily rerouted to set ``_stop_requested`` so the
    trainers can exit their loops early. The previous SIGTERM handler is
    restored when training returns or raises, so the process reacts to
    SIGTERM normally afterwards. The stop flag is left as-is on return, so a
    caller can inspect it; it is cleared at the start of the next call.

    Args:
        cfg: Raw run configuration; ``_jax_cfg(cfg)["algo"]`` picks the trainer
            (``"ppo"``, ``"a2c"``, ``"dqn"`` or ``"qtable"``).

    Returns:
        The ``(state, metrics)`` tuple produced by the selected trainer.

    Raises:
        ImportError: If jax, flax, optax or distrax are not installed.
        ValueError: If ``algo`` is not one of the supported values.
    """
    if not HAS_JAX_STACK:
        raise ImportError(
            "JAX path requires jax, flax, optax, and distrax. "
            "Install engine/jax/requirements.txt on this machine first."
        )

    _init_jax_distributed()
    _stop_requested.clear()
    run_cfg = _jax_cfg(cfg)
    algo = run_cfg["algo"]

    # signal.signal may only be called from the main thread.
    install_handler = threading.current_thread() is threading.main_thread()
    previous_handler = None
    if install_handler:
        previous_handler = signal.signal(
            signal.SIGTERM, lambda *_: _stop_requested.set()
        )

    try:
        if algo in {"ppo", "a2c"}:
            return _train_actor_critic(run_cfg, algo=algo)
        if algo == "dqn":
            return _train_dqn(run_cfg)
        if algo == "qtable":
            return _train_qtable(run_cfg)
        raise ValueError(f"Unsupported JAX algo '{algo}'")
    finally:
        if install_handler:
            # signal.signal returns None when the prior handler was not
            # installed from Python; fall back to the default action then.
            signal.signal(
                signal.SIGTERM,
                previous_handler if previous_handler is not None else signal.SIG_DFL,
            )
|