mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
nightly benchmark run configureation
This commit is contained in:
@@ -15,6 +15,15 @@ def _has_flag(tokens: list[str], name: str) -> bool:
|
||||
return any(tok == name or tok.startswith(f"{name}=") for tok in tokens)
|
||||
|
||||
|
||||
def _entry_tokens(run_kind: str, entry_args: str) -> list[str]:
|
||||
tokens = shlex.split(entry_args)
|
||||
if run_kind == "benchmark" and not (
|
||||
_has_flag(tokens, "--run-kind") or _has_flag(tokens, "--run-mode")
|
||||
):
|
||||
return ["--run-kind", "benchmark", *tokens]
|
||||
return tokens
|
||||
|
||||
|
||||
def _alive_node_ips() -> list[str]:
|
||||
seen: set[str] = set()
|
||||
ips: list[str] = []
|
||||
@@ -33,13 +42,18 @@ def _alive_node_ips() -> list[str]:
|
||||
def _train_on_node(
|
||||
*,
|
||||
root: str,
|
||||
train_args: str,
|
||||
run_kind: str,
|
||||
entry_args: str,
|
||||
rank: int,
|
||||
world_size: int,
|
||||
coordinator_ip: str,
|
||||
coordinator_port: int,
|
||||
base_seed: int,
|
||||
run_group: str,
|
||||
compare_robust: bool,
|
||||
output_root: str,
|
||||
wandb_entity: str,
|
||||
wandb_project: str,
|
||||
sync_jax: bool,
|
||||
) -> int:
|
||||
env = dict(os.environ)
|
||||
@@ -53,6 +67,12 @@ def _train_on_node(
|
||||
env["JAX_PLATFORMS"] = requested_platform
|
||||
# Keep each train process in single-host mode to avoid accidental global stalls.
|
||||
env["CLOUD_TPU_TASK_ID"] = "0"
|
||||
if run_kind == "benchmark":
|
||||
env["PHANTOM_BENCHMARK_COMPARE_ROBUST"] = "1" if compare_robust else "0"
|
||||
if wandb_entity:
|
||||
env["WANDB_ENTITY"] = wandb_entity
|
||||
if wandb_project:
|
||||
env["WANDB_PROJECT"] = wandb_project
|
||||
|
||||
cwd = str(Path(root))
|
||||
|
||||
@@ -75,30 +95,62 @@ def _train_on_node(
|
||||
[sys.executable, "-c", probe], cwd=cwd, env=env_probe, check=True
|
||||
)
|
||||
|
||||
tokens = shlex.split(train_args)
|
||||
tokens = _entry_tokens(run_kind, entry_args)
|
||||
seed = int(base_seed + rank)
|
||||
if not _has_flag(tokens, "--seed"):
|
||||
tokens.extend(["--seed", str(base_seed + rank)])
|
||||
if not _has_flag(tokens, "--group"):
|
||||
tokens.extend(["--seed", str(seed)])
|
||||
|
||||
if run_kind == "train" and not _has_flag(tokens, "--group"):
|
||||
tokens.extend(["--group", run_group])
|
||||
|
||||
if (
|
||||
run_kind == "benchmark"
|
||||
and output_root
|
||||
and not _has_flag(tokens, "--output-dir")
|
||||
):
|
||||
out_dir = Path(output_root) / f"rank_{rank}" / f"seed_{seed}"
|
||||
out_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
tokens.extend(["--output-dir", str(out_dir)])
|
||||
|
||||
cmd = [sys.executable, "-m", "engine.train", *tokens]
|
||||
print(
|
||||
{
|
||||
"rank": int(rank),
|
||||
"run_kind": run_kind,
|
||||
"seed": int(seed),
|
||||
"compare_robust": bool(compare_robust),
|
||||
"wandb_entity": str(env.get("WANDB_ENTITY", "")),
|
||||
"wandb_project": str(env.get("WANDB_PROJECT", "")),
|
||||
"command": " ".join(cmd),
|
||||
}
|
||||
)
|
||||
proc = subprocess.run(cmd, cwd=cwd, env=env)
|
||||
return int(proc.returncode)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Launch one train run per Ray TPU node"
|
||||
description="Launch one train/benchmark run per Ray TPU node"
|
||||
)
|
||||
parser.add_argument("--train-args", type=str, required=True)
|
||||
parser.add_argument("--run-kind", choices=["train", "benchmark"], default="train")
|
||||
parser.add_argument("--entry-args", type=str, default="")
|
||||
parser.add_argument("--train-args", type=str, default="")
|
||||
parser.add_argument("--num-nodes", type=int, default=0)
|
||||
parser.add_argument("--tpu-per-task", type=float, default=8.0)
|
||||
parser.add_argument("--base-seed", type=int, default=42)
|
||||
parser.add_argument("--sync-jax", action="store_true")
|
||||
parser.add_argument("--coordinator-port", type=int, default=12355)
|
||||
parser.add_argument("--run-group", type=str, default="")
|
||||
parser.add_argument("--compare-robust", action="store_true")
|
||||
parser.add_argument("--output-root", type=str, default="")
|
||||
parser.add_argument("--wandb-entity", type=str, default="")
|
||||
parser.add_argument("--wandb-project", type=str, default="")
|
||||
args = parser.parse_args()
|
||||
|
||||
entry_args = str(args.entry_args or args.train_args).strip()
|
||||
if not entry_args:
|
||||
raise ValueError("--entry-args (or legacy --train-args) is required")
|
||||
|
||||
ray.init(address="auto")
|
||||
|
||||
node_ips = _alive_node_ips()
|
||||
@@ -118,8 +170,11 @@ def main() -> None:
|
||||
"nodes": node_ips,
|
||||
"world_size": world_size,
|
||||
"coordinator": f"{coordinator_ip}:{int(args.coordinator_port)}",
|
||||
"train_args": args.train_args,
|
||||
"run_kind": str(args.run_kind),
|
||||
"entry_args": entry_args,
|
||||
"run_group": run_group,
|
||||
"compare_robust": bool(args.compare_robust),
|
||||
"output_root": str(args.output_root),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -130,14 +185,19 @@ def main() -> None:
|
||||
futures.append(
|
||||
_train_on_node.options(resources=resources).remote(
|
||||
root=root,
|
||||
train_args=args.train_args,
|
||||
run_kind=str(args.run_kind),
|
||||
entry_args=entry_args,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
coordinator_ip=coordinator_ip,
|
||||
coordinator_port=int(args.coordinator_port),
|
||||
base_seed=int(args.base_seed),
|
||||
run_group=run_group,
|
||||
sync_jax=bool(args.sync_jax),
|
||||
compare_robust=bool(args.compare_robust),
|
||||
output_root=str(args.output_root),
|
||||
wandb_entity=str(args.wandb_entity),
|
||||
wandb_project=str(args.wandb_project),
|
||||
sync_jax=bool(args.sync_jax and str(args.run_kind) == "train"),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user