mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
class separability significance
This commit is contained in:
@@ -2,12 +2,15 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shlex
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
import resource
|
||||
from pathlib import Path
|
||||
|
||||
import wandb
|
||||
@@ -23,6 +26,7 @@ CLI_MAP: dict[str, str] = {
|
||||
"info_value": "--info-value",
|
||||
"robust_radius": "--robust-radius",
|
||||
"robust_points": "--robust-points",
|
||||
"no_robust": "--no-robust",
|
||||
"learning_rate": "--learning-rate",
|
||||
"gamma": "--gamma",
|
||||
"gae_lambda": "--gae-lambda",
|
||||
@@ -67,6 +71,16 @@ def _to_cli_args(cfg: dict) -> str:
|
||||
_SENTINEL = "PHANTOM_METRICS:"
|
||||
|
||||
|
||||
def _raise_nofile_limit(min_soft: int = 8192) -> None:
    """Best-effort raise of the soft ``RLIMIT_NOFILE`` to at least ``min_soft``.

    The soft limit is only ever increased, never lowered, and the hard limit
    is left untouched.  The requested value is clamped to the hard limit
    unless the hard limit is unlimited.

    Args:
        min_soft: Desired minimum number of open file descriptors.

    Returns:
        None.  Failures to query or change the limit are ignored because the
        adjustment is an optimisation, not a requirement.
    """
    try:
        soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    except (OSError, ValueError):
        return

    unlimited = resource.RLIM_INFINITY

    # Nothing to do when the soft limit is already unlimited or high enough.
    if soft == unlimited or soft >= min_soft:
        return

    # RLIM_INFINITY must be compared explicitly: on some platforms / Python
    # versions it is represented as -1, so a plain ``min(hard, ...)`` would
    # pick the "unlimited" marker as the smallest value and skip the raise.
    target = min_soft if hard == unlimited else min(hard, min_soft)
    if target <= soft:
        return

    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    except (OSError, ValueError):
        # The OS refused the new limit; keep running with the current one.
        return
|
||||
|
||||
|
||||
def _extract_metrics(output: str) -> dict:
|
||||
# fast path: look for the dedicated sentinel line emitted by run_local
|
||||
for line in output.splitlines():
|
||||
@@ -88,6 +102,7 @@ def _extract_metrics(output: str) -> dict:
|
||||
|
||||
|
||||
def main() -> None:
|
||||
_raise_nofile_limit()
|
||||
p = argparse.ArgumentParser(
|
||||
description="Run W&B sweep where each trial uses full TPU pod"
|
||||
)
|
||||
@@ -102,6 +117,8 @@ def main() -> None:
|
||||
|
||||
workdir = Path(args.workdir).resolve()
|
||||
env = os.environ.copy()
|
||||
wandb_root = workdir / ".wandb-agent"
|
||||
wandb_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
prepare_cmd = [
|
||||
"make",
|
||||
@@ -124,12 +141,17 @@ def main() -> None:
|
||||
|
||||
def run_trial() -> None:
|
||||
run = None
|
||||
trial_wandb_dir = wandb_root / f"trial-{time.time_ns()}"
|
||||
trial_wandb_dir.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
run = wandb.init()
|
||||
run = wandb.init(dir=str(trial_wandb_dir))
|
||||
cfg = dict(wandb.config)
|
||||
cli_args = _to_cli_args(cfg)
|
||||
env_trial = dict(env)
|
||||
env_trial["LOCAL_TRAIN_ARGS"] = cli_args
|
||||
env_trial["WANDB_DIR"] = str(trial_wandb_dir)
|
||||
env_trial["WANDB_CACHE_DIR"] = str(trial_wandb_dir / "cache")
|
||||
env_trial["WANDB_DATA_DIR"] = str(trial_wandb_dir / "data")
|
||||
|
||||
cmd = [
|
||||
"make",
|
||||
@@ -171,6 +193,8 @@ def main() -> None:
|
||||
finally:
|
||||
if run is not None and wandb.run is not None:
|
||||
wandb.finish()
|
||||
shutil.rmtree(trial_wandb_dir, ignore_errors=True)
|
||||
gc.collect()
|
||||
|
||||
wandb.agent(
|
||||
args.sweep_id,
|
||||
|
||||
Reference in New Issue
Block a user