cleaning up jax bs

This commit is contained in:
2026-03-08 19:15:58 +01:00
parent 73246d7dd8
commit 4c658a93a7
27 changed files with 173 additions and 3146 deletions

View File

@@ -108,49 +108,6 @@ PY
image_ref="${TRAIN_IMAGE_REF:-us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer}"
docker build -f docker/Trainer.dockerfile --target gpu -t "$image_ref:gpu-latest" .
docker push "$image_ref:gpu-latest"
docker build -f docker/Trainer.dockerfile --target tpu -t "$image_ref:tpu-latest" .
docker push "$image_ref:tpu-latest"
;;
train-tpu-pod)
load_sweep_env
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh "$TPU_NAME":/tmp/tpu_pod_run.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="WANDB_API_KEY='$WANDB_API_KEY' SWEEP_ID='$SWEEP_ID' AGENT_COUNT='${AGENT_COUNT:-0}' sh /tmp/tpu_pod_run.sh"
;;
train-tpu-vm-prepare)
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
TPU_NAME="$TPU_NAME" \
TPU_ZONE="${TPU_ZONE:-us-central2-b}" \
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}" \
LOCAL_REPO_DIR="$PWD" \
REMOTE_REPO_DIR="${TPU_REPO_DIR:-/tmp/PHANTOM}" \
sh scripts/tpu_sync_repo.sh
gcloud compute tpus tpu-vm scp scripts/tpu_vm_train.sh "$TPU_NAME":/tmp/tpu_vm_train.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all
;;
train-tpu-vm-run)
load_sweep_env
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
require_var LOCAL_TRAIN_ARGS "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000"
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="REPO_DIR='${TPU_REPO_DIR:-/tmp/PHANTOM}' TRAIN_ARGS='${LOCAL_TRAIN_ARGS}' WANDB_API_KEY='${WANDB_API_KEY:-}' sh /tmp/tpu_vm_train.sh"
;;
train-tpu-vm-sweep)
load_sweep_env
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123"
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
args=(
--sweep-id "$SWEEP_ID"
--tpu-name "$TPU_NAME"
--tpu-zone "${TPU_ZONE:-us-central2-b}"
--tpu-project "${TPU_PROJECT:-phantom-trc}"
--tpu-repo-dir "${TPU_REPO_DIR:-/tmp/PHANTOM}"
)
if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then
args+=(--count "$AGENT_COUNT")
fi
WANDB_API_KEY="$WANDB_API_KEY" python3 scripts/tpu_vm_sweep_agent.py "${args[@]}"
;;
*)
printf '%s\n' "Unknown research command: $cmd" >&2

View File

@@ -1,32 +0,0 @@
#!/usr/bin/env sh
# Executed on each TPU pod worker via `gcloud tpu-vm scp` + `gcloud tpu-vm ssh --worker=all`.
# Authenticates with Artifact Registry using the VM's service account metadata token,
# pulls the TPU trainer image, then runs the W&B sweep agent inside Docker.
# TPU chip devices (/dev/accel*) are exposed via --privileged + /dev volume mount.
# Required env vars: WANDB_API_KEY, SWEEP_ID
# Optional: AGENT_COUNT (default 1, 0 = run until sweep ends)
set -eu
IMAGE="us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer:tpu-latest"
AGENT_COUNT="${AGENT_COUNT:-1}"
# use VM service account — no manual key needed on the pod
TOKEN=$(curl -sf -H "Metadata-Flavor: Google" \
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" \
| python3 -c 'import sys, json; print(json.load(sys.stdin)["access_token"])')
echo "$TOKEN" | sudo docker login -u oauth2accesstoken \
--password-stdin https://us-central1-docker.pkg.dev
sudo docker pull "$IMAGE"
# --privileged + /dev mount gives the container access to /dev/accel* (TPU chips)
# --network host lets JAX reach the other pod workers for distributed init
sudo docker run --rm \
--privileged \
--network host \
--volume /dev:/dev \
-e WANDB_API_KEY="$WANDB_API_KEY" \
-e SWEEP_ID="$SWEEP_ID" \
-e AGENT_COUNT="$AGENT_COUNT" \
"$IMAGE"

View File

@@ -1,83 +0,0 @@
#!/usr/bin/env sh
set -eu
TPU_NAME="${TPU_NAME:?TPU_NAME is required}"
TPU_ZONE="${TPU_ZONE:-us-central2-b}"
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}"
LOCAL_REPO_DIR="${LOCAL_REPO_DIR:-$(pwd)}"
REMOTE_REPO_DIR="${REMOTE_REPO_DIR:-/tmp/PHANTOM}"
ARCHIVE_PATH="${ARCHIVE_PATH:-/tmp/phantom-sync.tgz}"
FILE_LIST="$(mktemp /tmp/phantom-sync-files.XXXXXX)"
CLEANUP_LIST=true
cleanup() {
if [ "$CLEANUP_LIST" = "true" ]; then
rm -f "$FILE_LIST"
fi
}
trap cleanup EXIT
if [ ! -d "$LOCAL_REPO_DIR" ]; then
echo "local repo directory not found: $LOCAL_REPO_DIR"
exit 1
fi
if git -C "$LOCAL_REPO_DIR" rev-parse --is-inside-work-tree >/dev/null 2>&1; then
git -C "$LOCAL_REPO_DIR" ls-files -co --exclude-standard > "$FILE_LIST"
python3 - "$FILE_LIST" <<'PY'
import sys
from pathlib import Path
file_list = Path(sys.argv[1])
skip_prefixes = (
"wandb/",
".venv/",
"venv/",
"node_modules/",
".next/",
".turbo/",
"__pycache__/",
".mypy_cache/",
".pytest_cache/",
".ruff_cache/",
"paper/build/",
"tests/e2e/test-results/",
)
rows = file_list.read_text().splitlines()
kept = [
row
for row in rows
if row and not any(row == p.rstrip("/") or row.startswith(p) for p in skip_prefixes)
]
file_list.write_text("\n".join(kept) + ("\n" if kept else ""))
PY
tar -czf "$ARCHIVE_PATH" -C "$LOCAL_REPO_DIR" -T "$FILE_LIST"
else
tar \
--exclude-vcs \
--exclude=".venv" --exclude="*/.venv" \
--exclude="venv" --exclude="*/venv" \
--exclude="node_modules" --exclude="*/node_modules" \
--exclude=".next" --exclude="*/.next" \
--exclude=".turbo" --exclude="*/.turbo" \
--exclude="__pycache__" --exclude="*/__pycache__" \
--exclude=".mypy_cache" --exclude="*/.mypy_cache" \
--exclude=".pytest_cache" --exclude="*/.pytest_cache" \
--exclude=".ruff_cache" --exclude="*/.ruff_cache" \
--exclude="wandb" --exclude="*/wandb" \
--exclude="paper/build" \
--exclude="tests/e2e/test-results" \
-czf "$ARCHIVE_PATH" \
-C "$LOCAL_REPO_DIR" .
fi
gcloud compute tpus tpu-vm scp "$ARCHIVE_PATH" "$TPU_NAME:/tmp/phantom-sync.tgz" \
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all
gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all \
--command="rm -rf '$REMOTE_REPO_DIR' && mkdir -p '$REMOTE_REPO_DIR' && tar -xzf /tmp/phantom-sync.tgz -C '$REMOTE_REPO_DIR' && rm -f /tmp/phantom-sync.tgz"
rm -f "$ARCHIVE_PATH"

View File

@@ -1,211 +0,0 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import gc
import json
import os
import re
import shlex
import shutil
import subprocess
import time
import resource
from pathlib import Path
import wandb
CLI_MAP: dict[str, str] = {
"algo": "--algo",
"total_timesteps": "--total-timesteps",
"alpha": "--alpha",
"N": "--N",
"n_products": "--n-products",
"lambda_coi": "--lambda-coi",
"info_value": "--info-value",
"robust_radius": "--robust-radius",
"robust_points": "--robust-points",
"no_robust": "--no-robust",
"learning_rate": "--learning-rate",
"gamma": "--gamma",
"gae_lambda": "--gae-lambda",
"clip_range": "--clip-range",
"ent_coef": "--ent-coef",
"revenue_weight": "--revenue-weight",
"max_steps": "--max-steps",
"margin_floor": "--margin-floor",
"margin_floor_patience": "--margin-floor-patience",
"arch": "--arch",
"activation": "--activation",
"jax_num_envs": "--jax-num-envs",
"jax_num_steps": "--jax-num-steps",
"jax_num_minibatches": "--jax-num-minibatches",
"jax_update_epochs": "--jax-update-epochs",
"jax_anneal_lr": "--jax-anneal-lr",
"checkpoint_interval": "--checkpoint-interval",
"action_levels": "--action-levels",
"action_scale_low": "--action-scale-low",
"action_scale_high": "--action-scale-high",
}
def _to_cli_args(cfg: dict) -> str:
parts: list[str] = ["--jax", "--no-wandb"]
for key, flag in CLI_MAP.items():
if key not in cfg:
continue
value = cfg[key]
if value is None:
continue
if isinstance(value, bool):
if key == "jax_anneal_lr":
parts.extend([flag, "true" if value else "false"])
elif value:
parts.append(flag)
continue
parts.extend([flag, str(value)])
return " ".join(shlex.quote(p) for p in parts)
_SENTINEL = "PHANTOM_METRICS:"
def _raise_nofile_limit(min_soft: int = 8192) -> None:
try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
target = min(hard, max(soft, min_soft))
if target > soft:
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
except Exception:
return
def _extract_metrics(output: str) -> dict:
# fast path: look for the dedicated sentinel line emitted by run_local
for line in output.splitlines():
if line.startswith(_SENTINEL):
try:
return json.loads(line[len(_SENTINEL) :])
except Exception:
break
# fallback: scan for any JSON block containing eval/sweep keys;
# use greedy match to capture the largest possible block first
for block in re.findall(r"\{[^{}]*\}", output):
try:
obj = json.loads(block)
except Exception:
continue
if isinstance(obj, dict) and (
"objective/score" in obj
or "eval/reward_mean" in obj
or "sweep/score" in obj
):
return obj
return {}
def main() -> None:
_raise_nofile_limit()
p = argparse.ArgumentParser(
description="Run W&B sweep where each trial uses full TPU pod"
)
p.add_argument("--sweep-id", required=True)
p.add_argument("--tpu-name", required=True)
p.add_argument("--tpu-zone", default="us-central2-b")
p.add_argument("--tpu-project", default="phantom-trc")
p.add_argument("--tpu-repo-dir", default="/tmp/PHANTOM")
p.add_argument("--count", type=int, default=0)
p.add_argument("--workdir", default=str(Path(__file__).resolve().parents[1]))
args = p.parse_args()
workdir = Path(args.workdir).resolve()
env = os.environ.copy()
wandb_root = workdir / ".wandb-agent"
wandb_root.mkdir(parents=True, exist_ok=True)
prepare_cmd = [
"make",
"train.tpu.vm.prepare",
f"TPU_NAME={args.tpu_name}",
f"TPU_ZONE={args.tpu_zone}",
f"TPU_PROJECT={args.tpu_project}",
f"TPU_REPO_DIR={args.tpu_repo_dir}",
]
prepare = subprocess.run(
prepare_cmd,
cwd=workdir,
env=env,
text=True,
capture_output=False,
check=False,
)
if prepare.returncode != 0:
raise RuntimeError("Failed to prepare TPU workers for sweep")
def run_trial() -> None:
run = None
trial_wandb_dir = wandb_root / f"trial-{time.time_ns()}"
trial_wandb_dir.mkdir(parents=True, exist_ok=True)
try:
run = wandb.init(dir=str(trial_wandb_dir))
cfg = dict(wandb.config)
cli_args = _to_cli_args(cfg)
env_trial = dict(env)
env_trial["LOCAL_TRAIN_ARGS"] = cli_args
env_trial["WANDB_DIR"] = str(trial_wandb_dir)
env_trial["WANDB_CACHE_DIR"] = str(trial_wandb_dir / "cache")
env_trial["WANDB_DATA_DIR"] = str(trial_wandb_dir / "data")
cmd = [
"make",
"train.tpu.vm.run",
f"TPU_NAME={args.tpu_name}",
f"TPU_ZONE={args.tpu_zone}",
f"TPU_PROJECT={args.tpu_project}",
f"TPU_REPO_DIR={args.tpu_repo_dir}",
]
proc = subprocess.run(
cmd,
cwd=workdir,
env=env_trial,
text=True,
capture_output=True,
check=False,
)
if proc.stdout:
print(proc.stdout)
if proc.stderr:
print(proc.stderr)
if proc.returncode != 0:
if run is not None:
run.summary["runner/exit_code"] = proc.returncode
raise RuntimeError(f"TPU trial failed with exit code {proc.returncode}")
metrics = _extract_metrics(proc.stdout)
if metrics:
wandb.log(metrics)
for k, v in metrics.items():
run.summary[k] = v
run.summary["runner/exit_code"] = 0
except Exception:
time.sleep(2)
raise
finally:
if run is not None and wandb.run is not None:
wandb.finish()
shutil.rmtree(trial_wandb_dir, ignore_errors=True)
gc.collect()
wandb.agent(
args.sweep_id,
function=run_trial,
count=args.count if args.count > 0 else None,
)
if __name__ == "__main__":
main()

View File

@@ -1,43 +0,0 @@
#!/usr/bin/env sh
set -eu
REPO_DIR="${REPO_DIR:-$HOME/PHANTOM}"
PYTHON_BIN="${PYTHON_BIN:-python3}"
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --jax --total-timesteps 200000 --jax-num-envs 32 --jax-num-steps 128 --jax-num-minibatches 4 --jax-update-epochs 4}"
EXTRA_PIP="${EXTRA_PIP:-flax optax distrax}"
INSTALL_FULL_REQUIREMENTS="${INSTALL_FULL_REQUIREMENTS:-0}"
if [ ! -d "$REPO_DIR" ]; then
echo "repo directory not found: $REPO_DIR"
exit 1
fi
cd "$REPO_DIR"
if [ -d "wandb" ]; then
rm -rf wandb
fi
# keep install idempotent and avoid re-installing jax/libtpu each run
if [ "$INSTALL_FULL_REQUIREMENTS" = "1" ] && [ -f "requirements.txt" ]; then
$PYTHON_BIN -m pip install -r requirements.txt
fi
if ! $PYTHON_BIN -c 'import flax, optax, distrax' >/dev/null 2>&1; then
if [ -f "engine/jax/requirements.txt" ]; then
$PYTHON_BIN -m pip install -r engine/jax/requirements.txt
fi
$PYTHON_BIN -m pip install -U $EXTRA_PIP
fi
if [ -n "${WANDB_API_KEY:-}" ]; then
if ! $PYTHON_BIN -c 'import wandb; import inspect; assert hasattr(wandb, "init") and callable(wandb.init)' >/dev/null 2>&1; then
$PYTHON_BIN -m pip install -U wandb
fi
fi
if [ -n "${WANDB_API_KEY:-}" ]; then
export WANDB_API_KEY
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS
fi
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS --no-wandb