Mirror of https://github.com/velocitatem/PHANTOM.git (last synced 2026-05-31 08:33:36 +00:00).
feat: training update
This commit is contained in:
@@ -727,6 +727,8 @@ def _train_actor_critic(
|
||||
) -> tuple[dict[str, Any], dict[str, float]]:
|
||||
num_devices = jax.local_device_count()
|
||||
use_pmap = num_devices > 1
|
||||
global_devices = max(1, int(jax.device_count()))
|
||||
process_idx = int(jax.process_index())
|
||||
|
||||
init_runner_state, run_updates_raw, network, env, run_cfg = (
|
||||
_make_actor_critic_train(cfg, algo=algo, use_pmap=use_pmap)
|
||||
@@ -743,18 +745,26 @@ def _train_actor_critic(
|
||||
run_fn = jax.jit(run_updates_raw, static_argnames=("num_updates",))
|
||||
|
||||
rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"])
|
||||
rollout_steps_global = rollout_steps * (global_devices if use_pmap else 1)
|
||||
total_updates = int(run_cfg["num_updates"])
|
||||
checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000)))
|
||||
segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1))
|
||||
segment_updates = max(1, checkpoint_interval // max(rollout_steps_global, 1))
|
||||
|
||||
rng = jax.random.PRNGKey(run_cfg["seed"])
|
||||
# single-device state used as template for serialization and eval
|
||||
single_runner_state = init_runner_state(rng)
|
||||
base_rng = jax.random.PRNGKey(run_cfg["seed"])
|
||||
base_rng = jax.random.fold_in(base_rng, process_idx)
|
||||
if use_pmap:
|
||||
init_keys = jax.random.split(base_rng, num_devices)
|
||||
runner_state = jax.vmap(init_runner_state)(init_keys)
|
||||
single_runner_state = jax.tree_util.tree_map(lambda x: x[0], runner_state)
|
||||
else:
|
||||
single_runner_state = init_runner_state(base_rng)
|
||||
runner_state = single_runner_state
|
||||
updates_done = 0
|
||||
restored_train_state = None
|
||||
|
||||
is_primary = jax.process_index() == 0
|
||||
is_primary = process_idx == 0
|
||||
artifact_name = None
|
||||
if is_primary and HAS_WANDB and wandb.run is not None:
|
||||
if HAS_WANDB and wandb.run is not None:
|
||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||
artifact_name = checkpoint_artifact_name(
|
||||
run_cfg,
|
||||
@@ -770,16 +780,20 @@ def _train_actor_critic(
|
||||
template = {"runner_state": single_runner_state, "updates_done": 0}
|
||||
payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
|
||||
single_runner_state = payload["runner_state"]
|
||||
restored_train_state = payload["runner_state"][0]
|
||||
updates_done = int(payload.get("updates_done", 0))
|
||||
if updates_done <= 0:
|
||||
updates_done = int(metadata.get("updates_done", 0))
|
||||
updates_done = max(0, min(updates_done, total_updates))
|
||||
|
||||
if use_pmap:
|
||||
runner_state = jax.device_put_replicated(
|
||||
single_runner_state, jax.local_devices()
|
||||
if use_pmap and restored_train_state is not None:
|
||||
runner_state = (
|
||||
jax.device_put_replicated(restored_train_state, jax.local_devices()),
|
||||
runner_state[1],
|
||||
runner_state[2],
|
||||
runner_state[3],
|
||||
)
|
||||
else:
|
||||
elif not use_pmap:
|
||||
runner_state = single_runner_state
|
||||
|
||||
metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"]
|
||||
@@ -796,13 +810,14 @@ def _train_actor_critic(
|
||||
metric = out["metrics"]
|
||||
|
||||
if use_pmap:
|
||||
# take device-0 slice; shape is (n_devices, segment_updates)
|
||||
segment_values = {
|
||||
key: np.asarray(metric[key][0], dtype=np.float64) for key in metric_keys
|
||||
key: np.asarray(metric[key], dtype=np.float64).reshape(-1)
|
||||
for key in metric_keys
|
||||
}
|
||||
else:
|
||||
segment_values = {
|
||||
key: np.asarray(metric[key], dtype=np.float64) for key in metric_keys
|
||||
key: np.asarray(metric[key], dtype=np.float64).reshape(-1)
|
||||
for key in metric_keys
|
||||
}
|
||||
|
||||
segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0
|
||||
@@ -811,7 +826,7 @@ def _train_actor_critic(
|
||||
metric_sums[key] += float(segment_values[key].sum())
|
||||
|
||||
updates_done += int(updates_this_segment)
|
||||
global_step = int(updates_done * rollout_steps)
|
||||
global_step = int(updates_done * rollout_steps_global)
|
||||
|
||||
if is_primary and HAS_WANDB and wandb.run is not None:
|
||||
wandb.log(
|
||||
@@ -842,7 +857,7 @@ def _train_actor_critic(
|
||||
metadata={
|
||||
"step": global_step,
|
||||
"updates_done": updates_done,
|
||||
"rollout_steps": rollout_steps,
|
||||
"rollout_steps": rollout_steps_global,
|
||||
"algo": algo,
|
||||
},
|
||||
)
|
||||
@@ -863,7 +878,7 @@ def _train_actor_critic(
|
||||
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
|
||||
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
|
||||
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
|
||||
"train/global_step": int(updates_done * rollout_steps),
|
||||
"train/global_step": int(updates_done * rollout_steps_global),
|
||||
}
|
||||
|
||||
eval_metrics = evaluate_policy(
|
||||
|
||||
Reference in New Issue
Block a user