mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
feat: training update
This commit is contained in:
43
Makefile
43
Makefile
@@ -26,6 +26,8 @@ RETRY_SECONDS ?= 20
|
|||||||
TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer
|
TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer
|
||||||
TPU_NAME ?=
|
TPU_NAME ?=
|
||||||
TPU_ZONE ?= us-central2-b
|
TPU_ZONE ?= us-central2-b
|
||||||
|
TPU_PROJECT ?= phantom-trc
|
||||||
|
TPU_REPO_DIR ?= /tmp/PHANTOM
|
||||||
|
|
||||||
SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a
|
SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a
|
||||||
|
|
||||||
@@ -33,7 +35,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" ||
|
|||||||
|
|
||||||
.PHONY: help
|
.PHONY: help
|
||||||
help:
|
help:
|
||||||
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | stats.lines"
|
@echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines"
|
||||||
@echo "docker.train.publish"
|
@echo "docker.train.publish"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Local wandb run:"
|
@echo "Local wandb run:"
|
||||||
@@ -165,12 +167,47 @@ train.tpu.pod:
|
|||||||
@test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
|
@test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1)
|
||||||
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
|
@$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1)
|
||||||
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh $(TPU_NAME):/tmp/tpu_pod_run.sh \
|
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh $(TPU_NAME):/tmp/tpu_pod_run.sh \
|
||||||
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all
|
--zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all
|
||||||
@$(SWEEP_ENV_LOAD); \
|
@$(SWEEP_ENV_LOAD); \
|
||||||
gcloud compute tpus tpu-vm ssh $(TPU_NAME) \
|
gcloud compute tpus tpu-vm ssh $(TPU_NAME) \
|
||||||
--zone=$(TPU_ZONE) --project=phantom-trc --worker=all \
|
--zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all \
|
||||||
--command="WANDB_API_KEY='$$WANDB_API_KEY' SWEEP_ID='$(SWEEP_ID)' AGENT_COUNT='$(AGENT_COUNT)' sh /tmp/tpu_pod_run.sh"
|
--command="WANDB_API_KEY='$$WANDB_API_KEY' SWEEP_ID='$(SWEEP_ID)' AGENT_COUNT='$(AGENT_COUNT)' sh /tmp/tpu_pod_run.sh"
|
||||||
|
|
||||||
|
# Stage a TPU VM for a training run.
#   1. Runs scripts/tpu_sync_repo.sh with the TPU coordinates plus the local
#      ($(CURDIR)) and remote ($(TPU_REPO_DIR)) repo directories exported in its
#      environment (presumably this syncs the checkout -- confirm in the script).
#   2. Copies scripts/tpu_vm_train.sh to /tmp/tpu_vm_train.sh on all workers.
# TPU_NAME must be non-empty; TPU_ZONE / TPU_PROJECT / TPU_REPO_DIR are used as set.
.PHONY: train.tpu.vm.prepare
train.tpu.vm.prepare:
	@[ -n "$(TPU_NAME)" ] || { echo "TPU_NAME required, e.g. TPU_NAME=TPUlong"; exit 1; }
	TPU_NAME="$(TPU_NAME)" \
	TPU_ZONE="$(TPU_ZONE)" \
	TPU_PROJECT="$(TPU_PROJECT)" \
	LOCAL_REPO_DIR="$(CURDIR)" \
	REMOTE_REPO_DIR="$(TPU_REPO_DIR)" \
		sh scripts/tpu_sync_repo.sh
	gcloud compute tpus tpu-vm scp \
		scripts/tpu_vm_train.sh $(TPU_NAME):/tmp/tpu_vm_train.sh \
		--zone=$(TPU_ZONE) \
		--project=$(TPU_PROJECT) \
		--worker=all
|
||||||
|
|
||||||
|
# Launch /tmp/tpu_vm_train.sh on all workers of the TPU VM over ssh.
# Requires TPU_NAME and LOCAL_TRAIN_ARGS. The remote command receives
# REPO_DIR, TRAIN_ARGS and WANDB_API_KEY as environment assignments.
# SWEEP_ENV_LOAD runs first in the same shell so that WANDB_API_KEY can come
# from $(SWEEP_ENV_FILE) when that file exists; an unset key is passed as ''.
.PHONY: train.tpu.vm.run
train.tpu.vm.run:
	@[ -n "$(TPU_NAME)" ] || { echo "TPU_NAME required, e.g. TPU_NAME=TPUlong"; exit 1; }
	@[ -n "$(LOCAL_TRAIN_ARGS)" ] || { echo "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000"; exit 1; }
	@$(SWEEP_ENV_LOAD); \
	gcloud compute tpus tpu-vm ssh $(TPU_NAME) \
		--zone=$(TPU_ZONE) \
		--project=$(TPU_PROJECT) \
		--worker=all \
		--command="REPO_DIR='$(TPU_REPO_DIR)' TRAIN_ARGS='$(LOCAL_TRAIN_ARGS)' WANDB_API_KEY='$$WANDB_API_KEY' sh /tmp/tpu_vm_train.sh"
|
||||||
|
|
||||||
|
# Full TPU VM training flow: stage the VM, then start the run.
# The two steps are invoked as sequential sub-makes rather than listed as
# sibling prerequisites: prerequisites carry no ordering guarantee under
# `make -j`, and the run step must not start before the repo and launcher
# script have been staged. $(MAKE) forwards command-line variables
# (TPU_NAME, LOCAL_TRAIN_ARGS, ...) and the jobserver to the child makes.
.PHONY: train.tpu.vm
train.tpu.vm:
	$(MAKE) train.tpu.vm.prepare
	$(MAKE) train.tpu.vm.run
|
||||||
|
|
||||||
|
# Run scripts/tpu_vm_sweep_agent.py locally with python3, pointing it at a
# W&B sweep and a TPU VM.
# Requires TPU_NAME, SWEEP_ID and a WANDB_API_KEY (from the environment or
# from $(SWEEP_ENV_FILE), which SWEEP_ENV_LOAD sources when it exists).
# --count is passed only when AGENT_COUNT is non-empty and not 0.
.PHONY: train.tpu.vm.sweep
train.tpu.vm.sweep:
	@[ -n "$(TPU_NAME)" ] || { echo "TPU_NAME required, e.g. TPU_NAME=TPUlong"; exit 1; }
	@[ -n "$(SWEEP_ID)" ] || { echo "SWEEP_ID required, e.g. SWEEP_ID=lusiana/phantom-pricing/abc123"; exit 1; }
	@$(SWEEP_ENV_LOAD); [ -n "$$WANDB_API_KEY" ] || { echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)"; exit 1; }
	@$(SWEEP_ENV_LOAD); \
	WANDB_API_KEY="$$WANDB_API_KEY" python3 scripts/tpu_vm_sweep_agent.py \
		--sweep-id "$(SWEEP_ID)" \
		--tpu-name "$(TPU_NAME)" \
		--tpu-zone "$(TPU_ZONE)" \
		--tpu-project "$(TPU_PROJECT)" \
		--tpu-repo-dir "$(TPU_REPO_DIR)" \
		$(if $(filter-out 0,$(AGENT_COUNT)),--count $(AGENT_COUNT),)
|
||||||
|
|
||||||
.PHONY: pdf clean watch run.webapp test count-lines all
|
.PHONY: pdf clean watch run.webapp test count-lines all
|
||||||
pdf: pdf.build
|
pdf: pdf.build
|
||||||
clean: pdf.clean
|
clean: pdf.clean
|
||||||
|
|||||||
@@ -727,6 +727,8 @@ def _train_actor_critic(
|
|||||||
) -> tuple[dict[str, Any], dict[str, float]]:
|
) -> tuple[dict[str, Any], dict[str, float]]:
|
||||||
num_devices = jax.local_device_count()
|
num_devices = jax.local_device_count()
|
||||||
use_pmap = num_devices > 1
|
use_pmap = num_devices > 1
|
||||||
|
global_devices = max(1, int(jax.device_count()))
|
||||||
|
process_idx = int(jax.process_index())
|
||||||
|
|
||||||
init_runner_state, run_updates_raw, network, env, run_cfg = (
|
init_runner_state, run_updates_raw, network, env, run_cfg = (
|
||||||
_make_actor_critic_train(cfg, algo=algo, use_pmap=use_pmap)
|
_make_actor_critic_train(cfg, algo=algo, use_pmap=use_pmap)
|
||||||
@@ -743,18 +745,26 @@ def _train_actor_critic(
|
|||||||
run_fn = jax.jit(run_updates_raw, static_argnames=("num_updates",))
|
run_fn = jax.jit(run_updates_raw, static_argnames=("num_updates",))
|
||||||
|
|
||||||
rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"])
|
rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"])
|
||||||
|
rollout_steps_global = rollout_steps * (global_devices if use_pmap else 1)
|
||||||
total_updates = int(run_cfg["num_updates"])
|
total_updates = int(run_cfg["num_updates"])
|
||||||
checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000)))
|
checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000)))
|
||||||
segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1))
|
segment_updates = max(1, checkpoint_interval // max(rollout_steps_global, 1))
|
||||||
|
|
||||||
rng = jax.random.PRNGKey(run_cfg["seed"])
|
base_rng = jax.random.PRNGKey(run_cfg["seed"])
|
||||||
# single-device state used as template for serialization and eval
|
base_rng = jax.random.fold_in(base_rng, process_idx)
|
||||||
single_runner_state = init_runner_state(rng)
|
if use_pmap:
|
||||||
|
init_keys = jax.random.split(base_rng, num_devices)
|
||||||
|
runner_state = jax.vmap(init_runner_state)(init_keys)
|
||||||
|
single_runner_state = jax.tree_util.tree_map(lambda x: x[0], runner_state)
|
||||||
|
else:
|
||||||
|
single_runner_state = init_runner_state(base_rng)
|
||||||
|
runner_state = single_runner_state
|
||||||
updates_done = 0
|
updates_done = 0
|
||||||
|
restored_train_state = None
|
||||||
|
|
||||||
is_primary = jax.process_index() == 0
|
is_primary = process_idx == 0
|
||||||
artifact_name = None
|
artifact_name = None
|
||||||
if is_primary and HAS_WANDB and wandb.run is not None:
|
if HAS_WANDB and wandb.run is not None:
|
||||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
sweep_id = getattr(wandb.run, "sweep_id", None)
|
||||||
artifact_name = checkpoint_artifact_name(
|
artifact_name = checkpoint_artifact_name(
|
||||||
run_cfg,
|
run_cfg,
|
||||||
@@ -770,16 +780,20 @@ def _train_actor_critic(
|
|||||||
template = {"runner_state": single_runner_state, "updates_done": 0}
|
template = {"runner_state": single_runner_state, "updates_done": 0}
|
||||||
payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
|
payload = serialization.from_bytes(template, checkpoint_path.read_bytes())
|
||||||
single_runner_state = payload["runner_state"]
|
single_runner_state = payload["runner_state"]
|
||||||
|
restored_train_state = payload["runner_state"][0]
|
||||||
updates_done = int(payload.get("updates_done", 0))
|
updates_done = int(payload.get("updates_done", 0))
|
||||||
if updates_done <= 0:
|
if updates_done <= 0:
|
||||||
updates_done = int(metadata.get("updates_done", 0))
|
updates_done = int(metadata.get("updates_done", 0))
|
||||||
updates_done = max(0, min(updates_done, total_updates))
|
updates_done = max(0, min(updates_done, total_updates))
|
||||||
|
|
||||||
if use_pmap:
|
if use_pmap and restored_train_state is not None:
|
||||||
runner_state = jax.device_put_replicated(
|
runner_state = (
|
||||||
single_runner_state, jax.local_devices()
|
jax.device_put_replicated(restored_train_state, jax.local_devices()),
|
||||||
|
runner_state[1],
|
||||||
|
runner_state[2],
|
||||||
|
runner_state[3],
|
||||||
)
|
)
|
||||||
else:
|
elif not use_pmap:
|
||||||
runner_state = single_runner_state
|
runner_state = single_runner_state
|
||||||
|
|
||||||
metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"]
|
metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"]
|
||||||
@@ -796,13 +810,14 @@ def _train_actor_critic(
|
|||||||
metric = out["metrics"]
|
metric = out["metrics"]
|
||||||
|
|
||||||
if use_pmap:
|
if use_pmap:
|
||||||
# take device-0 slice; shape is (n_devices, segment_updates)
|
|
||||||
segment_values = {
|
segment_values = {
|
||||||
key: np.asarray(metric[key][0], dtype=np.float64) for key in metric_keys
|
key: np.asarray(metric[key], dtype=np.float64).reshape(-1)
|
||||||
|
for key in metric_keys
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
segment_values = {
|
segment_values = {
|
||||||
key: np.asarray(metric[key], dtype=np.float64) for key in metric_keys
|
key: np.asarray(metric[key], dtype=np.float64).reshape(-1)
|
||||||
|
for key in metric_keys
|
||||||
}
|
}
|
||||||
|
|
||||||
segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0
|
segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0
|
||||||
@@ -811,7 +826,7 @@ def _train_actor_critic(
|
|||||||
metric_sums[key] += float(segment_values[key].sum())
|
metric_sums[key] += float(segment_values[key].sum())
|
||||||
|
|
||||||
updates_done += int(updates_this_segment)
|
updates_done += int(updates_this_segment)
|
||||||
global_step = int(updates_done * rollout_steps)
|
global_step = int(updates_done * rollout_steps_global)
|
||||||
|
|
||||||
if is_primary and HAS_WANDB and wandb.run is not None:
|
if is_primary and HAS_WANDB and wandb.run is not None:
|
||||||
wandb.log(
|
wandb.log(
|
||||||
@@ -842,7 +857,7 @@ def _train_actor_critic(
|
|||||||
metadata={
|
metadata={
|
||||||
"step": global_step,
|
"step": global_step,
|
||||||
"updates_done": updates_done,
|
"updates_done": updates_done,
|
||||||
"rollout_steps": rollout_steps,
|
"rollout_steps": rollout_steps_global,
|
||||||
"algo": algo,
|
"algo": algo,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -863,7 +878,7 @@ def _train_actor_critic(
|
|||||||
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
|
"train/agent_prob": float(metric_sums["agent_prob"] / denom),
|
||||||
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
|
"train/alpha_adv": float(metric_sums["alpha_adv"] / denom),
|
||||||
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
|
"train/coi_leakage": float(metric_sums["coi_leakage"] / denom),
|
||||||
"train/global_step": int(updates_done * rollout_steps),
|
"train/global_step": int(updates_done * rollout_steps_global),
|
||||||
}
|
}
|
||||||
|
|
||||||
eval_metrics = evaluate_policy(
|
eval_metrics = evaluate_policy(
|
||||||
|
|||||||
@@ -9,10 +9,16 @@ import numpy as np
|
|||||||
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import wandb
|
import wandb as _wandb
|
||||||
|
|
||||||
HAS_WANDB = True
|
if hasattr(_wandb, "init") and callable(_wandb.init):
|
||||||
|
wandb = _wandb
|
||||||
|
HAS_WANDB = True
|
||||||
|
else:
|
||||||
|
wandb = None
|
||||||
|
HAS_WANDB = False
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
wandb = None
|
||||||
HAS_WANDB = False
|
HAS_WANDB = False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -80,7 +86,7 @@ DEFAULT_CFG = {
|
|||||||
"jax_num_minibatches": 4,
|
"jax_num_minibatches": 4,
|
||||||
"jax_update_epochs": 4,
|
"jax_update_epochs": 4,
|
||||||
"jax_anneal_lr": True,
|
"jax_anneal_lr": True,
|
||||||
"checkpoint_interval": 10_000,
|
"checkpoint_interval": 200_000,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -404,6 +410,16 @@ def run_wandb(
|
|||||||
) -> dict:
|
) -> dict:
|
||||||
if not HAS_WANDB:
|
if not HAS_WANDB:
|
||||||
raise ImportError("wandb is required for sweep runs")
|
raise ImportError("wandb is required for sweep runs")
|
||||||
|
if not sweep_mode:
|
||||||
|
pre_cfg = _cfg(overrides)
|
||||||
|
if pre_cfg.get("use_jax"):
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
|
||||||
|
if jax.process_count() > 1 and jax.process_index() != 0:
|
||||||
|
return train_once(pre_cfg)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
init_kwargs = {"mode": mode}
|
init_kwargs = {"mode": mode}
|
||||||
if sweep_mode:
|
if sweep_mode:
|
||||||
run = wandb.init(**init_kwargs)
|
run = wandb.init(**init_kwargs)
|
||||||
@@ -431,7 +447,16 @@ def run_wandb(
|
|||||||
def run_local(overrides: dict) -> dict:
    """Train once without W&B and print the resulting metrics as JSON.

    The config is built from ``overrides`` via ``_cfg`` and passed to
    ``train_once``. When the config has ``use_jax`` set, only the process
    whose ``jax.process_index()`` is 0 prints; if JAX cannot be imported or
    queried, printing falls back to enabled.

    Returns the metrics dict produced by ``train_once`` in every case.
    """
    cfg = _cfg(overrides)
    metrics = train_once(cfg)

    is_reporting_process = True
    if cfg.get("use_jax"):
        try:
            import jax

            is_reporting_process = jax.process_index() == 0
        except Exception:
            # Best effort: if the process index is unavailable, still print.
            is_reporting_process = True

    if not is_reporting_process:
        return metrics

    print(json.dumps(metrics, indent=2))
    return metrics
|
||||||
|
|
||||||
|
|
||||||
@@ -439,15 +464,26 @@ def main():
|
|||||||
p = argparse.ArgumentParser(description="PHANTOM training and W&B sweeps")
|
p = argparse.ArgumentParser(description="PHANTOM training and W&B sweeps")
|
||||||
p.add_argument("--project", default=DEFAULT_CFG["project"])
|
p.add_argument("--project", default=DEFAULT_CFG["project"])
|
||||||
p.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable"])
|
p.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable"])
|
||||||
|
p.add_argument("--seed", type=int)
|
||||||
p.add_argument("--total-timesteps", type=int)
|
p.add_argument("--total-timesteps", type=int)
|
||||||
p.add_argument("--alpha", type=float)
|
p.add_argument("--alpha", type=float)
|
||||||
|
p.add_argument("--N", type=int)
|
||||||
p.add_argument("--n-products", type=int)
|
p.add_argument("--n-products", type=int)
|
||||||
p.add_argument("--lambda-coi", type=float)
|
p.add_argument("--lambda-coi", type=float)
|
||||||
|
p.add_argument("--info-value", type=float)
|
||||||
p.add_argument("--robust-radius", type=float)
|
p.add_argument("--robust-radius", type=float)
|
||||||
p.add_argument("--robust-points", type=int)
|
p.add_argument("--robust-points", type=int)
|
||||||
p.add_argument("--learning-rate", type=float)
|
p.add_argument("--learning-rate", type=float)
|
||||||
p.add_argument("--gamma", type=float)
|
p.add_argument("--gamma", type=float)
|
||||||
|
p.add_argument("--gae-lambda", type=float)
|
||||||
|
p.add_argument("--clip-range", type=float)
|
||||||
|
p.add_argument("--ent-coef", type=float)
|
||||||
p.add_argument("--revenue-weight", type=float)
|
p.add_argument("--revenue-weight", type=float)
|
||||||
|
p.add_argument("--price-low", type=float)
|
||||||
|
p.add_argument("--price-high", type=float)
|
||||||
|
p.add_argument("--action-levels", type=int)
|
||||||
|
p.add_argument("--action-scale-low", type=float)
|
||||||
|
p.add_argument("--action-scale-high", type=float)
|
||||||
p.add_argument("--max-steps", type=int)
|
p.add_argument("--max-steps", type=int)
|
||||||
p.add_argument("--margin-floor", type=float)
|
p.add_argument("--margin-floor", type=float)
|
||||||
p.add_argument("--margin-floor-patience", type=int)
|
p.add_argument("--margin-floor-patience", type=int)
|
||||||
@@ -469,15 +505,26 @@ def main():
|
|||||||
|
|
||||||
overrides = {
|
overrides = {
|
||||||
"algo": args.algo,
|
"algo": args.algo,
|
||||||
|
"seed": args.seed,
|
||||||
"total_timesteps": args.total_timesteps,
|
"total_timesteps": args.total_timesteps,
|
||||||
"alpha": args.alpha,
|
"alpha": args.alpha,
|
||||||
|
"N": args.N,
|
||||||
"n_products": args.n_products,
|
"n_products": args.n_products,
|
||||||
"lambda_coi": args.lambda_coi,
|
"lambda_coi": args.lambda_coi,
|
||||||
|
"info_value": args.info_value,
|
||||||
"robust_radius": args.robust_radius,
|
"robust_radius": args.robust_radius,
|
||||||
"robust_points": args.robust_points,
|
"robust_points": args.robust_points,
|
||||||
"learning_rate": args.learning_rate,
|
"learning_rate": args.learning_rate,
|
||||||
"gamma": args.gamma,
|
"gamma": args.gamma,
|
||||||
|
"gae_lambda": args.gae_lambda,
|
||||||
|
"clip_range": args.clip_range,
|
||||||
|
"ent_coef": args.ent_coef,
|
||||||
"revenue_weight": args.revenue_weight,
|
"revenue_weight": args.revenue_weight,
|
||||||
|
"price_low": args.price_low,
|
||||||
|
"price_high": args.price_high,
|
||||||
|
"action_levels": args.action_levels,
|
||||||
|
"action_scale_low": args.action_scale_low,
|
||||||
|
"action_scale_high": args.action_scale_high,
|
||||||
"max_steps": args.max_steps,
|
"max_steps": args.max_steps,
|
||||||
"margin_floor": args.margin_floor,
|
"margin_floor": args.margin_floor,
|
||||||
"margin_floor_patience": args.margin_floor_patience,
|
"margin_floor_patience": args.margin_floor_patience,
|
||||||
|
|||||||
Reference in New Issue
Block a user