mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
catchup: rogue scripts
This commit is contained in:
32
scripts/tpu_pod_run.sh
Executable file
32
scripts/tpu_pod_run.sh
Executable file
@@ -0,0 +1,32 @@
|
||||
#!/usr/bin/env sh
|
||||
# Executed on each TPU pod worker via `gcloud tpu-vm scp` + `gcloud tpu-vm ssh --worker=all`.
|
||||
# Authenticates with Artifact Registry using the VM's service account metadata token,
|
||||
# pulls the TPU trainer image, then runs the W&B sweep agent inside Docker.
|
||||
# TPU chip devices (/dev/accel*) are exposed via --privileged + /dev volume mount.
|
||||
# Required env vars: WANDB_API_KEY, SWEEP_ID
|
||||
# Optional: AGENT_COUNT (default 1, 0 = run until sweep ends)
|
||||
set -eu
|
||||
|
||||
IMAGE="us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer:tpu-latest"
|
||||
AGENT_COUNT="${AGENT_COUNT:-1}"
|
||||
|
||||
# use VM service account — no manual key needed on the pod
|
||||
TOKEN=$(curl -sf -H "Metadata-Flavor: Google" \
|
||||
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" \
|
||||
| python3 -c 'import sys, json; print(json.load(sys.stdin)["access_token"])')
|
||||
|
||||
echo "$TOKEN" | sudo docker login -u oauth2accesstoken \
|
||||
--password-stdin https://us-central1-docker.pkg.dev
|
||||
|
||||
sudo docker pull "$IMAGE"
|
||||
|
||||
# --privileged + /dev mount gives the container access to /dev/accel* (TPU chips)
|
||||
# --network host lets JAX reach the other pod workers for distributed init
|
||||
sudo docker run --rm \
|
||||
--privileged \
|
||||
--network host \
|
||||
--volume /dev:/dev \
|
||||
-e WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
-e SWEEP_ID="$SWEEP_ID" \
|
||||
-e AGENT_COUNT="$AGENT_COUNT" \
|
||||
"$IMAGE"
|
||||
83
scripts/tpu_sync_repo.sh
Normal file
83
scripts/tpu_sync_repo.sh
Normal file
@@ -0,0 +1,83 @@
|
||||
#!/usr/bin/env sh
|
||||
set -eu
|
||||
|
||||
TPU_NAME="${TPU_NAME:?TPU_NAME is required}"
|
||||
TPU_ZONE="${TPU_ZONE:-us-central2-b}"
|
||||
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}"
|
||||
LOCAL_REPO_DIR="${LOCAL_REPO_DIR:-$(pwd)}"
|
||||
REMOTE_REPO_DIR="${REMOTE_REPO_DIR:-/tmp/PHANTOM}"
|
||||
ARCHIVE_PATH="${ARCHIVE_PATH:-/tmp/phantom-sync.tgz}"
|
||||
|
||||
FILE_LIST="$(mktemp /tmp/phantom-sync-files.XXXXXX)"
|
||||
CLEANUP_LIST=true
|
||||
|
||||
cleanup() {
|
||||
if [ "$CLEANUP_LIST" = "true" ]; then
|
||||
rm -f "$FILE_LIST"
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
if [ ! -d "$LOCAL_REPO_DIR" ]; then
|
||||
echo "local repo directory not found: $LOCAL_REPO_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if git -C "$LOCAL_REPO_DIR" rev-parse --is-inside-work-tree >/dev/null 2>&1; then
|
||||
git -C "$LOCAL_REPO_DIR" ls-files -co --exclude-standard > "$FILE_LIST"
|
||||
python3 - "$FILE_LIST" <<'PY'
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
file_list = Path(sys.argv[1])
|
||||
skip_prefixes = (
|
||||
"wandb/",
|
||||
".venv/",
|
||||
"venv/",
|
||||
"node_modules/",
|
||||
".next/",
|
||||
".turbo/",
|
||||
"__pycache__/",
|
||||
".mypy_cache/",
|
||||
".pytest_cache/",
|
||||
".ruff_cache/",
|
||||
"paper/build/",
|
||||
"tests/e2e/test-results/",
|
||||
)
|
||||
|
||||
rows = file_list.read_text().splitlines()
|
||||
kept = [
|
||||
row
|
||||
for row in rows
|
||||
if row and not any(row == p.rstrip("/") or row.startswith(p) for p in skip_prefixes)
|
||||
]
|
||||
file_list.write_text("\n".join(kept) + ("\n" if kept else ""))
|
||||
PY
|
||||
tar -czf "$ARCHIVE_PATH" -C "$LOCAL_REPO_DIR" -T "$FILE_LIST"
|
||||
else
|
||||
tar \
|
||||
--exclude-vcs \
|
||||
--exclude=".venv" --exclude="*/.venv" \
|
||||
--exclude="venv" --exclude="*/venv" \
|
||||
--exclude="node_modules" --exclude="*/node_modules" \
|
||||
--exclude=".next" --exclude="*/.next" \
|
||||
--exclude=".turbo" --exclude="*/.turbo" \
|
||||
--exclude="__pycache__" --exclude="*/__pycache__" \
|
||||
--exclude=".mypy_cache" --exclude="*/.mypy_cache" \
|
||||
--exclude=".pytest_cache" --exclude="*/.pytest_cache" \
|
||||
--exclude=".ruff_cache" --exclude="*/.ruff_cache" \
|
||||
--exclude="wandb" --exclude="*/wandb" \
|
||||
--exclude="paper/build" \
|
||||
--exclude="tests/e2e/test-results" \
|
||||
-czf "$ARCHIVE_PATH" \
|
||||
-C "$LOCAL_REPO_DIR" .
|
||||
fi
|
||||
|
||||
gcloud compute tpus tpu-vm scp "$ARCHIVE_PATH" "$TPU_NAME:/tmp/phantom-sync.tgz" \
|
||||
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all
|
||||
|
||||
gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
|
||||
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all \
|
||||
--command="rm -rf '$REMOTE_REPO_DIR' && mkdir -p '$REMOTE_REPO_DIR' && tar -xzf /tmp/phantom-sync.tgz -C '$REMOTE_REPO_DIR' && rm -f /tmp/phantom-sync.tgz"
|
||||
|
||||
rm -f "$ARCHIVE_PATH"
|
||||
183
scripts/tpu_vm_sweep_agent.py
Normal file
183
scripts/tpu_vm_sweep_agent.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
CLI_MAP: dict[str, str] = {
|
||||
"algo": "--algo",
|
||||
"total_timesteps": "--total-timesteps",
|
||||
"alpha": "--alpha",
|
||||
"N": "--N",
|
||||
"n_products": "--n-products",
|
||||
"lambda_coi": "--lambda-coi",
|
||||
"info_value": "--info-value",
|
||||
"robust_radius": "--robust-radius",
|
||||
"robust_points": "--robust-points",
|
||||
"learning_rate": "--learning-rate",
|
||||
"gamma": "--gamma",
|
||||
"gae_lambda": "--gae-lambda",
|
||||
"clip_range": "--clip-range",
|
||||
"ent_coef": "--ent-coef",
|
||||
"revenue_weight": "--revenue-weight",
|
||||
"max_steps": "--max-steps",
|
||||
"margin_floor": "--margin-floor",
|
||||
"margin_floor_patience": "--margin-floor-patience",
|
||||
"arch": "--arch",
|
||||
"activation": "--activation",
|
||||
"jax_num_envs": "--jax-num-envs",
|
||||
"jax_num_steps": "--jax-num-steps",
|
||||
"jax_num_minibatches": "--jax-num-minibatches",
|
||||
"jax_update_epochs": "--jax-update-epochs",
|
||||
"jax_anneal_lr": "--jax-anneal-lr",
|
||||
"checkpoint_interval": "--checkpoint-interval",
|
||||
"action_levels": "--action-levels",
|
||||
"action_scale_low": "--action-scale-low",
|
||||
"action_scale_high": "--action-scale-high",
|
||||
}
|
||||
|
||||
|
||||
def _to_cli_args(cfg: dict) -> str:
|
||||
parts: list[str] = ["--jax", "--no-wandb"]
|
||||
for key, flag in CLI_MAP.items():
|
||||
if key not in cfg:
|
||||
continue
|
||||
value = cfg[key]
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, bool):
|
||||
if key == "jax_anneal_lr":
|
||||
parts.extend([flag, "true" if value else "false"])
|
||||
elif value:
|
||||
parts.append(flag)
|
||||
continue
|
||||
parts.extend([flag, str(value)])
|
||||
return " ".join(shlex.quote(p) for p in parts)
|
||||
|
||||
|
||||
_SENTINEL = "PHANTOM_METRICS:"
|
||||
|
||||
|
||||
def _extract_metrics(output: str) -> dict:
|
||||
# fast path: look for the dedicated sentinel line emitted by run_local
|
||||
for line in output.splitlines():
|
||||
if line.startswith(_SENTINEL):
|
||||
try:
|
||||
return json.loads(line[len(_SENTINEL) :])
|
||||
except Exception:
|
||||
break
|
||||
# fallback: scan for any JSON block containing eval/sweep keys;
|
||||
# use greedy match to capture the largest possible block first
|
||||
for block in re.findall(r"\{[^{}]*\}", output):
|
||||
try:
|
||||
obj = json.loads(block)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(obj, dict) and ("sweep/score" in obj or "eval/reward" in obj):
|
||||
return obj
|
||||
return {}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(
|
||||
description="Run W&B sweep where each trial uses full TPU pod"
|
||||
)
|
||||
p.add_argument("--sweep-id", required=True)
|
||||
p.add_argument("--tpu-name", required=True)
|
||||
p.add_argument("--tpu-zone", default="us-central2-b")
|
||||
p.add_argument("--tpu-project", default="phantom-trc")
|
||||
p.add_argument("--tpu-repo-dir", default="/tmp/PHANTOM")
|
||||
p.add_argument("--count", type=int, default=0)
|
||||
p.add_argument("--workdir", default=str(Path(__file__).resolve().parents[1]))
|
||||
args = p.parse_args()
|
||||
|
||||
workdir = Path(args.workdir).resolve()
|
||||
env = os.environ.copy()
|
||||
|
||||
prepare_cmd = [
|
||||
"make",
|
||||
"train.tpu.vm.prepare",
|
||||
f"TPU_NAME={args.tpu_name}",
|
||||
f"TPU_ZONE={args.tpu_zone}",
|
||||
f"TPU_PROJECT={args.tpu_project}",
|
||||
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
||||
]
|
||||
prepare = subprocess.run(
|
||||
prepare_cmd,
|
||||
cwd=workdir,
|
||||
env=env,
|
||||
text=True,
|
||||
capture_output=False,
|
||||
check=False,
|
||||
)
|
||||
if prepare.returncode != 0:
|
||||
raise RuntimeError("Failed to prepare TPU workers for sweep")
|
||||
|
||||
def run_trial() -> None:
|
||||
run = None
|
||||
try:
|
||||
run = wandb.init()
|
||||
cfg = dict(wandb.config)
|
||||
cli_args = _to_cli_args(cfg)
|
||||
env_trial = dict(env)
|
||||
env_trial["LOCAL_TRAIN_ARGS"] = cli_args
|
||||
|
||||
cmd = [
|
||||
"make",
|
||||
"train.tpu.vm.run",
|
||||
f"TPU_NAME={args.tpu_name}",
|
||||
f"TPU_ZONE={args.tpu_zone}",
|
||||
f"TPU_PROJECT={args.tpu_project}",
|
||||
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
||||
]
|
||||
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
cwd=workdir,
|
||||
env=env_trial,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
if proc.stdout:
|
||||
print(proc.stdout)
|
||||
if proc.stderr:
|
||||
print(proc.stderr)
|
||||
|
||||
if proc.returncode != 0:
|
||||
if run is not None:
|
||||
run.summary["runner/exit_code"] = proc.returncode
|
||||
raise RuntimeError(f"TPU trial failed with exit code {proc.returncode}")
|
||||
|
||||
metrics = _extract_metrics(proc.stdout)
|
||||
if metrics:
|
||||
wandb.log(metrics)
|
||||
for k, v in metrics.items():
|
||||
run.summary[k] = v
|
||||
run.summary["runner/exit_code"] = 0
|
||||
except Exception:
|
||||
time.sleep(2)
|
||||
raise
|
||||
finally:
|
||||
if run is not None and wandb.run is not None:
|
||||
wandb.finish()
|
||||
|
||||
wandb.agent(
|
||||
args.sweep_id,
|
||||
function=run_trial,
|
||||
count=args.count if args.count > 0 else None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
43
scripts/tpu_vm_train.sh
Normal file
43
scripts/tpu_vm_train.sh
Normal file
@@ -0,0 +1,43 @@
|
||||
#!/usr/bin/env sh
|
||||
set -eu
|
||||
|
||||
REPO_DIR="${REPO_DIR:-$HOME/PHANTOM}"
|
||||
PYTHON_BIN="${PYTHON_BIN:-python3}"
|
||||
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --jax --total-timesteps 200000 --jax-num-envs 32 --jax-num-steps 128 --jax-num-minibatches 4 --jax-update-epochs 4}"
|
||||
EXTRA_PIP="${EXTRA_PIP:-flax optax distrax}"
|
||||
INSTALL_FULL_REQUIREMENTS="${INSTALL_FULL_REQUIREMENTS:-0}"
|
||||
|
||||
if [ ! -d "$REPO_DIR" ]; then
|
||||
echo "repo directory not found: $REPO_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cd "$REPO_DIR"
|
||||
|
||||
if [ -d "wandb" ]; then
|
||||
rm -rf wandb
|
||||
fi
|
||||
|
||||
# keep install idempotent and avoid re-installing jax/libtpu each run
|
||||
if [ "$INSTALL_FULL_REQUIREMENTS" = "1" ] && [ -f "requirements.txt" ]; then
|
||||
$PYTHON_BIN -m pip install -r requirements.txt
|
||||
fi
|
||||
if ! $PYTHON_BIN -c 'import flax, optax, distrax' >/dev/null 2>&1; then
|
||||
if [ -f "engine/jax/requirements.txt" ]; then
|
||||
$PYTHON_BIN -m pip install -r engine/jax/requirements.txt
|
||||
fi
|
||||
$PYTHON_BIN -m pip install -U $EXTRA_PIP
|
||||
fi
|
||||
|
||||
if [ -n "${WANDB_API_KEY:-}" ]; then
|
||||
if ! $PYTHON_BIN -c 'import wandb; import inspect; assert hasattr(wandb, "init") and callable(wandb.init)' >/dev/null 2>&1; then
|
||||
$PYTHON_BIN -m pip install -U wandb
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -n "${WANDB_API_KEY:-}" ]; then
|
||||
export WANDB_API_KEY
|
||||
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS
|
||||
fi
|
||||
|
||||
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS --no-wandb
|
||||
108
scripts/wandb_agent_bootstrap.sh
Executable file
108
scripts/wandb_agent_bootstrap.sh
Executable file
@@ -0,0 +1,108 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
need_env() {
|
||||
local name="$1"
|
||||
if [ -z "${!name:-}" ]; then
|
||||
echo "$name is required"
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
need_cmd() {
|
||||
local c="$1"
|
||||
command -v "$c" >/dev/null 2>&1 || {
|
||||
echo "Missing command: $c"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
need_cmd git
|
||||
need_cmd python3
|
||||
|
||||
need_env WANDB_API_KEY
|
||||
need_env GITHUB_TOKEN
|
||||
need_env REPO_URL
|
||||
need_env SWEEP_ID
|
||||
|
||||
BRANCH="${BRANCH:-main}"
|
||||
WORKDIR="${WORKDIR:-$HOME/PHANTOM-agent}"
|
||||
AGENT_COUNT="${AGENT_COUNT:-0}"
|
||||
AGENT_LOOP="${AGENT_LOOP:-1}"
|
||||
RETRY_SECONDS="${RETRY_SECONDS:-20}"
|
||||
PYTHON_BIN="${PYTHON_BIN:-python3}"
|
||||
|
||||
mkdir -p "$(dirname "$WORKDIR")"
|
||||
|
||||
ASKPASS_FILE="$(mktemp)"
|
||||
cat >"$ASKPASS_FILE" <<'EOF'
|
||||
#!/usr/bin/env sh
|
||||
case "$1" in
|
||||
*Username*) echo "x-access-token" ;;
|
||||
*Password*) echo "$GITHUB_TOKEN" ;;
|
||||
*) echo "" ;;
|
||||
esac
|
||||
EOF
|
||||
chmod 700 "$ASKPASS_FILE"
|
||||
|
||||
cleanup() {
|
||||
rm -f "$ASKPASS_FILE"
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
git_auth() {
|
||||
GIT_TERMINAL_PROMPT=0 GIT_ASKPASS="$ASKPASS_FILE" git "$@"
|
||||
}
|
||||
|
||||
sync_repo() {
|
||||
if [ ! -d "$WORKDIR/.git" ]; then
|
||||
rm -rf "$WORKDIR"
|
||||
git_auth clone --single-branch --branch "$BRANCH" "$REPO_URL" "$WORKDIR"
|
||||
return
|
||||
fi
|
||||
|
||||
git -C "$WORKDIR" remote set-url origin "$REPO_URL"
|
||||
git_auth -C "$WORKDIR" fetch origin "$BRANCH" --prune
|
||||
git -C "$WORKDIR" checkout -B "$BRANCH" "origin/$BRANCH"
|
||||
git -C "$WORKDIR" reset --hard "origin/$BRANCH"
|
||||
}
|
||||
|
||||
install_deps() {
|
||||
"$PYTHON_BIN" -m venv "$WORKDIR/.venv"
|
||||
"$WORKDIR/.venv/bin/pip" install --upgrade pip
|
||||
"$WORKDIR/.venv/bin/pip" install -r "$WORKDIR/requirements.txt"
|
||||
}
|
||||
|
||||
run_agent() {
|
||||
local cmd=("$WORKDIR/.venv/bin/python" -m engine.train --sweep-agent --sweep-id "$SWEEP_ID")
|
||||
if [ "$AGENT_COUNT" != "0" ]; then
|
||||
cmd+=(--count "$AGENT_COUNT")
|
||||
fi
|
||||
|
||||
(
|
||||
cd "$WORKDIR"
|
||||
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-}" \
|
||||
"${cmd[@]}"
|
||||
)
|
||||
}
|
||||
|
||||
while true; do
|
||||
sync_repo
|
||||
install_deps
|
||||
|
||||
if run_agent; then
|
||||
if [ "$AGENT_LOOP" = "1" ] && [ "$AGENT_COUNT" = "0" ]; then
|
||||
sleep "$RETRY_SECONDS"
|
||||
continue
|
||||
fi
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [ "$AGENT_LOOP" != "1" ]; then
|
||||
exit 1
|
||||
fi
|
||||
|
||||
sleep "$RETRY_SECONDS"
|
||||
done
|
||||
Reference in New Issue
Block a user