diff --git a/Makefile b/Makefile index 6a24577..c9203a4 100644 --- a/Makefile +++ b/Makefile @@ -26,6 +26,8 @@ RETRY_SECONDS ?= 20 TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer TPU_NAME ?= TPU_ZONE ?= us-central2-b +TPU_PROJECT ?= phantom-trc +TPU_REPO_DIR ?= /tmp/PHANTOM SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a @@ -33,7 +35,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || .PHONY: help help: - @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | stats.lines" + @echo "pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | train | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines" @echo "docker.train.publish" @echo "" @echo "Local wandb run:" @@ -165,12 +167,47 @@ train.tpu.pod: @test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id" && exit 1) @$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1) gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh $(TPU_NAME):/tmp/tpu_pod_run.sh \ - --zone=$(TPU_ZONE) --project=phantom-trc --worker=all + --zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all @$(SWEEP_ENV_LOAD); \ gcloud compute tpus tpu-vm ssh $(TPU_NAME) \ - --zone=$(TPU_ZONE) --project=phantom-trc --worker=all \ + --zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all \ --command="WANDB_API_KEY='$$WANDB_API_KEY' SWEEP_ID='$(SWEEP_ID)' AGENT_COUNT='$(AGENT_COUNT)' sh /tmp/tpu_pod_run.sh" +.PHONY: train.tpu.vm.prepare +train.tpu.vm.prepare: + @test -n "$(TPU_NAME)" || (echo "TPU_NAME required, e.g. TPU_NAME=TPUlong" && exit 1) + TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" \ + LOCAL_REPO_DIR="$(CURDIR)" REMOTE_REPO_DIR="$(TPU_REPO_DIR)" \ + sh scripts/tpu_sync_repo.sh + gcloud compute tpus tpu-vm scp scripts/tpu_vm_train.sh $(TPU_NAME):/tmp/tpu_vm_train.sh \ + --zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all + +.PHONY: train.tpu.vm.run +train.tpu.vm.run: + @test -n "$(TPU_NAME)" || (echo "TPU_NAME required, e.g. TPU_NAME=TPUlong" && exit 1) + @test -n "$(LOCAL_TRAIN_ARGS)" || (echo "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000" && exit 1) + @$(SWEEP_ENV_LOAD); \ + gcloud compute tpus tpu-vm ssh $(TPU_NAME) \ + --zone=$(TPU_ZONE) --project=$(TPU_PROJECT) --worker=all \ + --command="REPO_DIR='$(TPU_REPO_DIR)' TRAIN_ARGS='$(LOCAL_TRAIN_ARGS)' WANDB_API_KEY='$$WANDB_API_KEY' sh /tmp/tpu_vm_train.sh" + +.PHONY: train.tpu.vm +train.tpu.vm: train.tpu.vm.prepare train.tpu.vm.run + +.PHONY: train.tpu.vm.sweep +train.tpu.vm.sweep: + @test -n "$(TPU_NAME)" || (echo "TPU_NAME required, e.g. TPU_NAME=TPUlong" && exit 1) + @test -n "$(SWEEP_ID)" || (echo "SWEEP_ID required, e.g. SWEEP_ID=lusiana/phantom-pricing/abc123" && exit 1) + @$(SWEEP_ENV_LOAD); test -n "$$WANDB_API_KEY" || (echo "WANDB_API_KEY required — set it in $(SWEEP_ENV_FILE)" && exit 1) + @$(SWEEP_ENV_LOAD); WANDB_API_KEY="$$WANDB_API_KEY" \ + python3 scripts/tpu_vm_sweep_agent.py \ + --sweep-id "$(SWEEP_ID)" \ + --tpu-name "$(TPU_NAME)" \ + --tpu-zone "$(TPU_ZONE)" \ + --tpu-project "$(TPU_PROJECT)" \ + --tpu-repo-dir "$(TPU_REPO_DIR)" \ + $(if $(filter-out 0,$(AGENT_COUNT)),--count $(AGENT_COUNT),) + .PHONY: pdf clean watch run.webapp test count-lines all pdf: pdf.build clean: pdf.clean diff --git a/engine/jax/train.py b/engine/jax/train.py index 408f9b3..e5c4c03 100644 --- a/engine/jax/train.py +++ b/engine/jax/train.py @@ -727,6 +727,8 @@ def _train_actor_critic( ) -> tuple[dict[str, Any], dict[str, float]]: num_devices = jax.local_device_count() use_pmap = num_devices > 1 + global_devices = max(1, int(jax.device_count())) + process_idx = int(jax.process_index()) init_runner_state, run_updates_raw, network, env, run_cfg = ( _make_actor_critic_train(cfg, algo=algo, use_pmap=use_pmap) @@ -743,18 +745,26 @@ def _train_actor_critic( run_fn = jax.jit(run_updates_raw, static_argnames=("num_updates",)) rollout_steps = int(run_cfg["num_steps"] * run_cfg["num_envs"]) + rollout_steps_global = rollout_steps * (global_devices if use_pmap else 1) total_updates = int(run_cfg["num_updates"]) checkpoint_interval = max(1, int(run_cfg.get("checkpoint_interval", 10_000))) - segment_updates = max(1, checkpoint_interval // max(rollout_steps, 1)) + segment_updates = max(1, checkpoint_interval // max(rollout_steps_global, 1)) - rng = jax.random.PRNGKey(run_cfg["seed"]) - # single-device state used as template for serialization and eval - single_runner_state = init_runner_state(rng) + base_rng = jax.random.PRNGKey(run_cfg["seed"]) + base_rng = jax.random.fold_in(base_rng, process_idx) + if use_pmap: + init_keys = jax.random.split(base_rng, num_devices) + runner_state = jax.vmap(init_runner_state)(init_keys) + single_runner_state = jax.tree_util.tree_map(lambda x: x[0], runner_state) + else: + single_runner_state = init_runner_state(base_rng) + runner_state = single_runner_state updates_done = 0 + restored_train_state = None - is_primary = jax.process_index() == 0 + is_primary = process_idx == 0 artifact_name = None - if is_primary and HAS_WANDB and wandb.run is not None: + if HAS_WANDB and wandb.run is not None: sweep_id = getattr(wandb.run, "sweep_id", None) artifact_name = checkpoint_artifact_name( run_cfg, @@ -770,16 +780,20 @@ def _train_actor_critic( template = {"runner_state": single_runner_state, "updates_done": 0} payload = serialization.from_bytes(template, checkpoint_path.read_bytes()) single_runner_state = payload["runner_state"] + restored_train_state = payload["runner_state"][0] updates_done = int(payload.get("updates_done", 0)) if updates_done <= 0: updates_done = int(metadata.get("updates_done", 0)) updates_done = max(0, min(updates_done, total_updates)) - if use_pmap: - runner_state = jax.device_put_replicated( - single_runner_state, jax.local_devices() + if use_pmap and restored_train_state is not None: + runner_state = ( + jax.device_put_replicated(restored_train_state, jax.local_devices()), + runner_state[1], + runner_state[2], + runner_state[3], ) - else: + elif not use_pmap: runner_state = single_runner_state metric_keys = ["reward", "revenue", "agent_prob", "alpha_adv", "coi_leakage"] @@ -796,13 +810,14 @@ def _train_actor_critic( metric = out["metrics"] if use_pmap: - # take device-0 slice; shape is (n_devices, segment_updates) segment_values = { - key: np.asarray(metric[key][0], dtype=np.float64) for key in metric_keys + key: np.asarray(metric[key], dtype=np.float64).reshape(-1) + for key in metric_keys } else: segment_values = { - key: np.asarray(metric[key], dtype=np.float64) for key in metric_keys + key: np.asarray(metric[key], dtype=np.float64).reshape(-1) + for key in metric_keys } segment_count = int(segment_values["reward"].shape[0]) if segment_values else 0 @@ -811,7 +826,7 @@ def _train_actor_critic( metric_sums[key] += float(segment_values[key].sum()) updates_done += int(updates_this_segment) - global_step = int(updates_done * rollout_steps) + global_step = int(updates_done * rollout_steps_global) if is_primary and HAS_WANDB and wandb.run is not None: wandb.log( @@ -842,7 +857,7 @@ def _train_actor_critic( metadata={ "step": global_step, "updates_done": updates_done, - "rollout_steps": rollout_steps, + "rollout_steps": rollout_steps_global, "algo": algo, }, ) @@ -863,7 +878,7 @@ def _train_actor_critic( "train/agent_prob": float(metric_sums["agent_prob"] / denom), "train/alpha_adv": float(metric_sums["alpha_adv"] / denom), "train/coi_leakage": float(metric_sums["coi_leakage"] / denom), - "train/global_step": int(updates_done * rollout_steps), + "train/global_step": int(updates_done * rollout_steps_global), } eval_metrics = evaluate_policy( diff --git a/engine/train.py b/engine/train.py index f6b256d..3a7dd2c 100644 --- a/engine/train.py +++ b/engine/train.py @@ -9,10 +9,16 @@ import numpy as np from .wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint try: - import wandb + import wandb as _wandb - HAS_WANDB = True + if hasattr(_wandb, "init") and callable(_wandb.init): + wandb = _wandb + HAS_WANDB = True + else: + wandb = None + HAS_WANDB = False except ImportError: + wandb = None HAS_WANDB = False try: @@ -80,7 +86,7 @@ DEFAULT_CFG = { "jax_num_minibatches": 4, "jax_update_epochs": 4, "jax_anneal_lr": True, - "checkpoint_interval": 10_000, + "checkpoint_interval": 200_000, } @@ -404,6 +410,16 @@ def run_wandb( ) -> dict: if not HAS_WANDB: raise ImportError("wandb is required for sweep runs") + if not sweep_mode: + pre_cfg = _cfg(overrides) + if pre_cfg.get("use_jax"): + try: + import jax + + if jax.process_count() > 1 and jax.process_index() != 0: + return train_once(pre_cfg) + except Exception: + pass init_kwargs = {"mode": mode} if sweep_mode: run = wandb.init(**init_kwargs) @@ -431,7 +447,16 @@ def run_wandb( def run_local(overrides: dict) -> dict: cfg = _cfg(overrides) metrics = train_once(cfg) - print(json.dumps(metrics, indent=2)) + should_print = True + if cfg.get("use_jax"): + try: + import jax + + should_print = jax.process_index() == 0 + except Exception: + should_print = True + if should_print: + print(json.dumps(metrics, indent=2)) return metrics @@ -439,15 +464,26 @@ def main(): p = argparse.ArgumentParser(description="PHANTOM training and W&B sweeps") p.add_argument("--project", default=DEFAULT_CFG["project"]) p.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable"]) + p.add_argument("--seed", type=int) p.add_argument("--total-timesteps", type=int) p.add_argument("--alpha", type=float) + p.add_argument("--N", type=int) p.add_argument("--n-products", type=int) p.add_argument("--lambda-coi", type=float) + p.add_argument("--info-value", type=float) p.add_argument("--robust-radius", type=float) p.add_argument("--robust-points", type=int) p.add_argument("--learning-rate", type=float) p.add_argument("--gamma", type=float) + p.add_argument("--gae-lambda", type=float) + p.add_argument("--clip-range", type=float) + p.add_argument("--ent-coef", type=float) p.add_argument("--revenue-weight", type=float) + p.add_argument("--price-low", type=float) + p.add_argument("--price-high", type=float) + p.add_argument("--action-levels", type=int) + p.add_argument("--action-scale-low", type=float) + p.add_argument("--action-scale-high", type=float) p.add_argument("--max-steps", type=int) p.add_argument("--margin-floor", type=float) p.add_argument("--margin-floor-patience", type=int) @@ -469,15 +505,26 @@ def main(): overrides = { "algo": args.algo, + "seed": args.seed, "total_timesteps": args.total_timesteps, "alpha": args.alpha, + "N": args.N, "n_products": args.n_products, "lambda_coi": args.lambda_coi, + "info_value": args.info_value, "robust_radius": args.robust_radius, "robust_points": args.robust_points, "learning_rate": args.learning_rate, "gamma": args.gamma, + "gae_lambda": args.gae_lambda, + "clip_range": args.clip_range, + "ent_coef": args.ent_coef, "revenue_weight": args.revenue_weight, + "price_low": args.price_low, + "price_high": args.price_high, + "action_levels": args.action_levels, + "action_scale_low": args.action_scale_low, + "action_scale_high": args.action_scale_high, "max_steps": args.max_steps, "margin_floor": args.margin_floor, "margin_floor_patience": args.margin_floor_patience,