mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
catchup: rogue scripts
This commit is contained in:
183
scripts/tpu_vm_sweep_agent.py
Normal file
183
scripts/tpu_vm_sweep_agent.py
Normal file
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python3
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import subprocess
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import wandb
|
||||
|
||||
|
||||
CLI_MAP: dict[str, str] = {
|
||||
"algo": "--algo",
|
||||
"total_timesteps": "--total-timesteps",
|
||||
"alpha": "--alpha",
|
||||
"N": "--N",
|
||||
"n_products": "--n-products",
|
||||
"lambda_coi": "--lambda-coi",
|
||||
"info_value": "--info-value",
|
||||
"robust_radius": "--robust-radius",
|
||||
"robust_points": "--robust-points",
|
||||
"learning_rate": "--learning-rate",
|
||||
"gamma": "--gamma",
|
||||
"gae_lambda": "--gae-lambda",
|
||||
"clip_range": "--clip-range",
|
||||
"ent_coef": "--ent-coef",
|
||||
"revenue_weight": "--revenue-weight",
|
||||
"max_steps": "--max-steps",
|
||||
"margin_floor": "--margin-floor",
|
||||
"margin_floor_patience": "--margin-floor-patience",
|
||||
"arch": "--arch",
|
||||
"activation": "--activation",
|
||||
"jax_num_envs": "--jax-num-envs",
|
||||
"jax_num_steps": "--jax-num-steps",
|
||||
"jax_num_minibatches": "--jax-num-minibatches",
|
||||
"jax_update_epochs": "--jax-update-epochs",
|
||||
"jax_anneal_lr": "--jax-anneal-lr",
|
||||
"checkpoint_interval": "--checkpoint-interval",
|
||||
"action_levels": "--action-levels",
|
||||
"action_scale_low": "--action-scale-low",
|
||||
"action_scale_high": "--action-scale-high",
|
||||
}
|
||||
|
||||
|
||||
def _to_cli_args(cfg: dict) -> str:
|
||||
parts: list[str] = ["--jax", "--no-wandb"]
|
||||
for key, flag in CLI_MAP.items():
|
||||
if key not in cfg:
|
||||
continue
|
||||
value = cfg[key]
|
||||
if value is None:
|
||||
continue
|
||||
if isinstance(value, bool):
|
||||
if key == "jax_anneal_lr":
|
||||
parts.extend([flag, "true" if value else "false"])
|
||||
elif value:
|
||||
parts.append(flag)
|
||||
continue
|
||||
parts.extend([flag, str(value)])
|
||||
return " ".join(shlex.quote(p) for p in parts)
|
||||
|
||||
|
||||
_SENTINEL = "PHANTOM_METRICS:"
|
||||
|
||||
|
||||
def _extract_metrics(output: str) -> dict:
|
||||
# fast path: look for the dedicated sentinel line emitted by run_local
|
||||
for line in output.splitlines():
|
||||
if line.startswith(_SENTINEL):
|
||||
try:
|
||||
return json.loads(line[len(_SENTINEL) :])
|
||||
except Exception:
|
||||
break
|
||||
# fallback: scan for any JSON block containing eval/sweep keys;
|
||||
# use greedy match to capture the largest possible block first
|
||||
for block in re.findall(r"\{[^{}]*\}", output):
|
||||
try:
|
||||
obj = json.loads(block)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(obj, dict) and ("sweep/score" in obj or "eval/reward" in obj):
|
||||
return obj
|
||||
return {}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
p = argparse.ArgumentParser(
|
||||
description="Run W&B sweep where each trial uses full TPU pod"
|
||||
)
|
||||
p.add_argument("--sweep-id", required=True)
|
||||
p.add_argument("--tpu-name", required=True)
|
||||
p.add_argument("--tpu-zone", default="us-central2-b")
|
||||
p.add_argument("--tpu-project", default="phantom-trc")
|
||||
p.add_argument("--tpu-repo-dir", default="/tmp/PHANTOM")
|
||||
p.add_argument("--count", type=int, default=0)
|
||||
p.add_argument("--workdir", default=str(Path(__file__).resolve().parents[1]))
|
||||
args = p.parse_args()
|
||||
|
||||
workdir = Path(args.workdir).resolve()
|
||||
env = os.environ.copy()
|
||||
|
||||
prepare_cmd = [
|
||||
"make",
|
||||
"train.tpu.vm.prepare",
|
||||
f"TPU_NAME={args.tpu_name}",
|
||||
f"TPU_ZONE={args.tpu_zone}",
|
||||
f"TPU_PROJECT={args.tpu_project}",
|
||||
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
||||
]
|
||||
prepare = subprocess.run(
|
||||
prepare_cmd,
|
||||
cwd=workdir,
|
||||
env=env,
|
||||
text=True,
|
||||
capture_output=False,
|
||||
check=False,
|
||||
)
|
||||
if prepare.returncode != 0:
|
||||
raise RuntimeError("Failed to prepare TPU workers for sweep")
|
||||
|
||||
def run_trial() -> None:
|
||||
run = None
|
||||
try:
|
||||
run = wandb.init()
|
||||
cfg = dict(wandb.config)
|
||||
cli_args = _to_cli_args(cfg)
|
||||
env_trial = dict(env)
|
||||
env_trial["LOCAL_TRAIN_ARGS"] = cli_args
|
||||
|
||||
cmd = [
|
||||
"make",
|
||||
"train.tpu.vm.run",
|
||||
f"TPU_NAME={args.tpu_name}",
|
||||
f"TPU_ZONE={args.tpu_zone}",
|
||||
f"TPU_PROJECT={args.tpu_project}",
|
||||
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
||||
]
|
||||
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
cwd=workdir,
|
||||
env=env_trial,
|
||||
text=True,
|
||||
capture_output=True,
|
||||
check=False,
|
||||
)
|
||||
|
||||
if proc.stdout:
|
||||
print(proc.stdout)
|
||||
if proc.stderr:
|
||||
print(proc.stderr)
|
||||
|
||||
if proc.returncode != 0:
|
||||
if run is not None:
|
||||
run.summary["runner/exit_code"] = proc.returncode
|
||||
raise RuntimeError(f"TPU trial failed with exit code {proc.returncode}")
|
||||
|
||||
metrics = _extract_metrics(proc.stdout)
|
||||
if metrics:
|
||||
wandb.log(metrics)
|
||||
for k, v in metrics.items():
|
||||
run.summary[k] = v
|
||||
run.summary["runner/exit_code"] = 0
|
||||
except Exception:
|
||||
time.sleep(2)
|
||||
raise
|
||||
finally:
|
||||
if run is not None and wandb.run is not None:
|
||||
wandb.finish()
|
||||
|
||||
wandb.agent(
|
||||
args.sweep_id,
|
||||
function=run_trial,
|
||||
count=args.count if args.count > 0 else None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user