mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
cleaning up jax bs
This commit is contained in:
@@ -108,49 +108,6 @@ PY
|
||||
image_ref="${TRAIN_IMAGE_REF:-us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer}"
|
||||
docker build -f docker/Trainer.dockerfile --target gpu -t "$image_ref:gpu-latest" .
|
||||
docker push "$image_ref:gpu-latest"
|
||||
docker build -f docker/Trainer.dockerfile --target tpu -t "$image_ref:tpu-latest" .
|
||||
docker push "$image_ref:tpu-latest"
|
||||
;;
|
||||
train-tpu-pod)
|
||||
load_sweep_env
|
||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh "$TPU_NAME":/tmp/tpu_pod_run.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all
|
||||
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="WANDB_API_KEY='$WANDB_API_KEY' SWEEP_ID='$SWEEP_ID' AGENT_COUNT='${AGENT_COUNT:-0}' sh /tmp/tpu_pod_run.sh"
|
||||
;;
|
||||
train-tpu-vm-prepare)
|
||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
||||
TPU_NAME="$TPU_NAME" \
|
||||
TPU_ZONE="${TPU_ZONE:-us-central2-b}" \
|
||||
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}" \
|
||||
LOCAL_REPO_DIR="$PWD" \
|
||||
REMOTE_REPO_DIR="${TPU_REPO_DIR:-/tmp/PHANTOM}" \
|
||||
sh scripts/tpu_sync_repo.sh
|
||||
gcloud compute tpus tpu-vm scp scripts/tpu_vm_train.sh "$TPU_NAME":/tmp/tpu_vm_train.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all
|
||||
;;
|
||||
train-tpu-vm-run)
|
||||
load_sweep_env
|
||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
||||
require_var LOCAL_TRAIN_ARGS "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000"
|
||||
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="REPO_DIR='${TPU_REPO_DIR:-/tmp/PHANTOM}' TRAIN_ARGS='${LOCAL_TRAIN_ARGS}' WANDB_API_KEY='${WANDB_API_KEY:-}' sh /tmp/tpu_vm_train.sh"
|
||||
;;
|
||||
train-tpu-vm-sweep)
|
||||
load_sweep_env
|
||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123"
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
args=(
|
||||
--sweep-id "$SWEEP_ID"
|
||||
--tpu-name "$TPU_NAME"
|
||||
--tpu-zone "${TPU_ZONE:-us-central2-b}"
|
||||
--tpu-project "${TPU_PROJECT:-phantom-trc}"
|
||||
--tpu-repo-dir "${TPU_REPO_DIR:-/tmp/PHANTOM}"
|
||||
)
|
||||
if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then
|
||||
args+=(--count "$AGENT_COUNT")
|
||||
fi
|
||||
WANDB_API_KEY="$WANDB_API_KEY" python3 scripts/tpu_vm_sweep_agent.py "${args[@]}"
|
||||
;;
|
||||
*)
|
||||
printf '%s\n' "Unknown research command: $cmd" >&2
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
# Executed on each TPU pod worker via `gcloud tpu-vm scp` + `gcloud tpu-vm ssh --worker=all`.
|
||||
# Authenticates with Artifact Registry using the VM's service account metadata token,
|
||||
# pulls the TPU trainer image, then runs the W&B sweep agent inside Docker.
|
||||
# TPU chip devices (/dev/accel*) are exposed via --privileged + /dev volume mount.
|
||||
# Required env vars: WANDB_API_KEY, SWEEP_ID
|
||||
# Optional: AGENT_COUNT (default 1, 0 = run until sweep ends)
|
||||
set -eu
|
||||
|
||||
IMAGE="us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer:tpu-latest"
|
||||
AGENT_COUNT="${AGENT_COUNT:-1}"
|
||||
|
||||
# use VM service account — no manual key needed on the pod
|
||||
TOKEN=$(curl -sf -H "Metadata-Flavor: Google" \
|
||||
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" \
|
||||
| python3 -c 'import sys, json; print(json.load(sys.stdin)["access_token"])')
|
||||
|
||||
echo "$TOKEN" | sudo docker login -u oauth2accesstoken \
|
||||
--password-stdin https://us-central1-docker.pkg.dev
|
||||
|
||||
sudo docker pull "$IMAGE"
|
||||
|
||||
# --privileged + /dev mount gives the container access to /dev/accel* (TPU chips)
|
||||
# --network host lets JAX reach the other pod workers for distributed init
|
||||
sudo docker run --rm \
|
||||
--privileged \
|
||||
--network host \
|
||||
--volume /dev:/dev \
|
||||
-e WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
-e SWEEP_ID="$SWEEP_ID" \
|
||||
-e AGENT_COUNT="$AGENT_COUNT" \
|
||||
"$IMAGE"
|
||||
@@ -1,83 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
set -eu
|
||||
|
||||
TPU_NAME="${TPU_NAME:?TPU_NAME is required}"
|
||||
TPU_ZONE="${TPU_ZONE:-us-central2-b}"
|
||||
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}"
|
||||
LOCAL_REPO_DIR="${LOCAL_REPO_DIR:-$(pwd)}"
|
||||
REMOTE_REPO_DIR="${REMOTE_REPO_DIR:-/tmp/PHANTOM}"
|
||||
ARCHIVE_PATH="${ARCHIVE_PATH:-/tmp/phantom-sync.tgz}"
|
||||
|
||||
FILE_LIST="$(mktemp /tmp/phantom-sync-files.XXXXXX)"
|
||||
CLEANUP_LIST=true
|
||||
|
||||
cleanup() {
|
||||
if [ "$CLEANUP_LIST" = "true" ]; then
|
||||
rm -f "$FILE_LIST"
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
if [ ! -d "$LOCAL_REPO_DIR" ]; then
|
||||
echo "local repo directory not found: $LOCAL_REPO_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if git -C "$LOCAL_REPO_DIR" rev-parse --is-inside-work-tree >/dev/null 2>&1; then
|
||||
git -C "$LOCAL_REPO_DIR" ls-files -co --exclude-standard > "$FILE_LIST"
|
||||
python3 - "$FILE_LIST" <<'PY'
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
file_list = Path(sys.argv[1])
|
||||
skip_prefixes = (
|
||||
"wandb/",
|
||||
".venv/",
|
||||
"venv/",
|
||||
"node_modules/",
|
||||
".next/",
|
||||
".turbo/",
|
||||
"__pycache__/",
|
||||
".mypy_cache/",
|
||||
".pytest_cache/",
|
||||
".ruff_cache/",
|
||||
"paper/build/",
|
||||
"tests/e2e/test-results/",
|
||||
)
|
||||
|
||||
rows = file_list.read_text().splitlines()
|
||||
kept = [
|
||||
row
|
||||
for row in rows
|
||||
if row and not any(row == p.rstrip("/") or row.startswith(p) for p in skip_prefixes)
|
||||
]
|
||||
file_list.write_text("\n".join(kept) + ("\n" if kept else ""))
|
||||
PY
|
||||
tar -czf "$ARCHIVE_PATH" -C "$LOCAL_REPO_DIR" -T "$FILE_LIST"
|
||||
else
|
||||
tar \
|
||||
--exclude-vcs \
|
||||
--exclude=".venv" --exclude="*/.venv" \
|
||||
--exclude="venv" --exclude="*/venv" \
|
||||
--exclude="node_modules" --exclude="*/node_modules" \
|
||||
--exclude=".next" --exclude="*/.next" \
|
||||
--exclude=".turbo" --exclude="*/.turbo" \
|
||||
--exclude="__pycache__" --exclude="*/__pycache__" \
|
||||
--exclude=".mypy_cache" --exclude="*/.mypy_cache" \
|
||||
--exclude=".pytest_cache" --exclude="*/.pytest_cache" \
|
||||
--exclude=".ruff_cache" --exclude="*/.ruff_cache" \
|
||||
--exclude="wandb" --exclude="*/wandb" \
|
||||
--exclude="paper/build" \
|
||||
--exclude="tests/e2e/test-results" \
|
||||
-czf "$ARCHIVE_PATH" \
|
||||
-C "$LOCAL_REPO_DIR" .
|
||||
fi
|
||||
|
||||
gcloud compute tpus tpu-vm scp "$ARCHIVE_PATH" "$TPU_NAME:/tmp/phantom-sync.tgz" \
|
||||
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all
|
||||
|
||||
gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
|
||||
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all \
|
||||
--command="rm -rf '$REMOTE_REPO_DIR' && mkdir -p '$REMOTE_REPO_DIR' && tar -xzf /tmp/phantom-sync.tgz -C '$REMOTE_REPO_DIR' && rm -f /tmp/phantom-sync.tgz"
|
||||
|
||||
rm -f "$ARCHIVE_PATH"
|
||||
@@ -1,211 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import resource
|
||||
from pathlib import Path
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
CLI_MAP: dict[str, str] = {
|
||||
"algo": "--algo",
|
||||
"total_timesteps": "--total-timesteps",
|
||||
"alpha": "--alpha",
|
||||
"N": "--N",
|
||||
"n_products": "--n-products",
|
||||
"lambda_coi": "--lambda-coi",
|
||||
"info_value": "--info-value",
|
||||
"robust_radius": "--robust-radius",
|
||||
"robust_points": "--robust-points",
|
||||
"no_robust": "--no-robust",
|
||||
"learning_rate": "--learning-rate",
|
||||
"gamma": "--gamma",
|
||||
"gae_lambda": "--gae-lambda",
|
||||
"clip_range": "--clip-range",
|
||||
"ent_coef": "--ent-coef",
|
||||
"revenue_weight": "--revenue-weight",
|
||||
"max_steps": "--max-steps",
|
||||
"margin_floor": "--margin-floor",
|
||||
"margin_floor_patience": "--margin-floor-patience",
|
||||
"arch": "--arch",
|
||||
"activation": "--activation",
|
||||
"jax_num_envs": "--jax-num-envs",
|
||||
"jax_num_steps": "--jax-num-steps",
|
||||
"jax_num_minibatches": "--jax-num-minibatches",
|
||||
"jax_update_epochs": "--jax-update-epochs",
|
||||
"jax_anneal_lr": "--jax-anneal-lr",
|
||||
"checkpoint_interval": "--checkpoint-interval",
|
||||
"action_levels": "--action-levels",
|
||||
"action_scale_low": "--action-scale-low",
|
||||
"action_scale_high": "--action-scale-high",
|
||||
}
|
||||
|
||||
|
||||
def _to_cli_args(cfg: dict) -> str:
|
||||
parts: list[str] = ["--jax", "--no-wandb"]
|
||||
for key, flag in CLI_MAP.items():
|
||||
if key not in cfg:
|
||||
continue
|
||||
value = cfg[key]
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, bool):
|
||||
if key == "jax_anneal_lr":
|
||||
parts.extend([flag, "true" if value else "false"])
|
||||
elif value:
|
||||
parts.append(flag)
|
||||
continue
|
||||
parts.extend([flag, str(value)])
|
||||
return " ".join(shlex.quote(p) for p in parts)
|
||||
|
||||
|
||||
_SENTINEL = "PHANTOM_METRICS:"
|
||||
|
||||
|
||||
def _raise_nofile_limit(min_soft: int = 8192) -> None:
|
||||
try:
|
||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
target = min(hard, max(soft, min_soft))
|
||||
if target > soft:
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
def _extract_metrics(output: str) -> dict:
|
||||
# fast path: look for the dedicated sentinel line emitted by run_local
|
||||
for line in output.splitlines():
|
||||
if line.startswith(_SENTINEL):
|
||||
try:
|
||||
return json.loads(line[len(_SENTINEL) :])
|
||||
except Exception:
|
||||
break
|
||||
# fallback: scan for any JSON block containing eval/sweep keys;
|
||||
# use greedy match to capture the largest possible block first
|
||||
for block in re.findall(r"\{[^{}]*\}", output):
|
||||
try:
|
||||
obj = json.loads(block)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(obj, dict) and (
|
||||
"objective/score" in obj
|
||||
or "eval/reward_mean" in obj
|
||||
or "sweep/score" in obj
|
||||
):
|
||||
return obj
|
||||
return {}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
_raise_nofile_limit()
|
||||
p = argparse.ArgumentParser(
|
||||
description="Run W&B sweep where each trial uses full TPU pod"
|
||||
)
|
||||
p.add_argument("--sweep-id", required=True)
|
||||
p.add_argument("--tpu-name", required=True)
|
||||
p.add_argument("--tpu-zone", default="us-central2-b")
|
||||
p.add_argument("--tpu-project", default="phantom-trc")
|
||||
p.add_argument("--tpu-repo-dir", default="/tmp/PHANTOM")
|
||||
p.add_argument("--count", type=int, default=0)
|
||||
p.add_argument("--workdir", default=str(Path(__file__).resolve().parents[1]))
|
||||
args = p.parse_args()
|
||||
|
||||
workdir = Path(args.workdir).resolve()
|
||||
env = os.environ.copy()
|
||||
wandb_root = workdir / ".wandb-agent"
|
||||
wandb_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prepare_cmd = [
|
||||
"make",
|
||||
"train.tpu.vm.prepare",
|
||||
f"TPU_NAME={args.tpu_name}",
|
||||
f"TPU_ZONE={args.tpu_zone}",
|
||||
f"TPU_PROJECT={args.tpu_project}",
|
||||
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
||||
]
|
||||
prepare = subprocess.run(
|
||||
prepare_cmd,
|
||||
cwd=workdir,
|
||||
env=env,
|
||||
text=True,
|
||||
capture_output=False,
|
||||
check=False,
|
||||
)
|
||||
if prepare.returncode != 0:
|
||||
raise RuntimeError("Failed to prepare TPU workers for sweep")
|
||||
|
||||
def run_trial() -> None:
|
||||
run = None
|
||||
trial_wandb_dir = wandb_root / f"trial-{time.time_ns()}"
|
||||
trial_wandb_dir.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
run = wandb.init(dir=str(trial_wandb_dir))
|
||||
cfg = dict(wandb.config)
|
||||
cli_args = _to_cli_args(cfg)
|
||||
env_trial = dict(env)
|
||||
env_trial["LOCAL_TRAIN_ARGS"] = cli_args
|
||||
env_trial["WANDB_DIR"] = str(trial_wandb_dir)
|
||||
env_trial["WANDB_CACHE_DIR"] = str(trial_wandb_dir / "cache")
|
||||
env_trial["WANDB_DATA_DIR"] = str(trial_wandb_dir / "data")
|
||||
|
||||
cmd = [
|
||||
"make",
|
||||
"train.tpu.vm.run",
|
||||
f"TPU_NAME={args.tpu_name}",
|
||||
f"TPU_ZONE={args.tpu_zone}",
|
||||
f"TPU_PROJECT={args.tpu_project}",
|
||||
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
||||
]
|
||||
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
cwd=workdir,
|
||||
env=env_trial,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
if proc.stdout:
|
||||
print(proc.stdout)
|
||||
if proc.stderr:
|
||||
print(proc.stderr)
|
||||
|
||||
if proc.returncode != 0:
|
||||
if run is not None:
|
||||
run.summary["runner/exit_code"] = proc.returncode
|
||||
raise RuntimeError(f"TPU trial failed with exit code {proc.returncode}")
|
||||
|
||||
metrics = _extract_metrics(proc.stdout)
|
||||
if metrics:
|
||||
wandb.log(metrics)
|
||||
for k, v in metrics.items():
|
||||
run.summary[k] = v
|
||||
run.summary["runner/exit_code"] = 0
|
||||
except Exception:
|
||||
time.sleep(2)
|
||||
raise
|
||||
finally:
|
||||
if run is not None and wandb.run is not None:
|
||||
wandb.finish()
|
||||
shutil.rmtree(trial_wandb_dir, ignore_errors=True)
|
||||
gc.collect()
|
||||
|
||||
wandb.agent(
|
||||
args.sweep_id,
|
||||
function=run_trial,
|
||||
count=args.count if args.count > 0 else None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,43 +0,0 @@
|
||||
#!/usr/bin/env sh
|
||||
set -eu
|
||||
|
||||
REPO_DIR="${REPO_DIR:-$HOME/PHANTOM}"
|
||||
PYTHON_BIN="${PYTHON_BIN:-python3}"
|
||||
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --jax --total-timesteps 200000 --jax-num-envs 32 --jax-num-steps 128 --jax-num-minibatches 4 --jax-update-epochs 4}"
|
||||
EXTRA_PIP="${EXTRA_PIP:-flax optax distrax}"
|
||||
INSTALL_FULL_REQUIREMENTS="${INSTALL_FULL_REQUIREMENTS:-0}"
|
||||
|
||||
if [ ! -d "$REPO_DIR" ]; then
|
||||
echo "repo directory not found: $REPO_DIR"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cd "$REPO_DIR"
|
||||
|
||||
if [ -d "wandb" ]; then
|
||||
rm -rf wandb
|
||||
fi
|
||||
|
||||
# keep install idempotent and avoid re-installing jax/libtpu each run
|
||||
if [ "$INSTALL_FULL_REQUIREMENTS" = "1" ] && [ -f "requirements.txt" ]; then
|
||||
$PYTHON_BIN -m pip install -r requirements.txt
|
||||
fi
|
||||
if ! $PYTHON_BIN -c 'import flax, optax, distrax' >/dev/null 2>&1; then
|
||||
if [ -f "engine/jax/requirements.txt" ]; then
|
||||
$PYTHON_BIN -m pip install -r engine/jax/requirements.txt
|
||||
fi
|
||||
$PYTHON_BIN -m pip install -U $EXTRA_PIP
|
||||
fi
|
||||
|
||||
if [ -n "${WANDB_API_KEY:-}" ]; then
|
||||
if ! $PYTHON_BIN -c 'import wandb; import inspect; assert hasattr(wandb, "init") and callable(wandb.init)' >/dev/null 2>&1; then
|
||||
$PYTHON_BIN -m pip install -U wandb
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ -n "${WANDB_API_KEY:-}" ]; then
|
||||
export WANDB_API_KEY
|
||||
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS
|
||||
fi
|
||||
|
||||
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS --no-wandb
|
||||
Reference in New Issue
Block a user