mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
updating engine training for training
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Iterable, Mapping
|
||||
|
||||
|
||||
@@ -19,6 +21,42 @@ def _require_wandb():
|
||||
return wandb
|
||||
|
||||
|
||||
def _warn(message: str) -> None:
|
||||
print(f"PHANTOM_WANDB_WARNING: {message}")
|
||||
|
||||
|
||||
def _sanitize_key(raw_key: str) -> str | None:
|
||||
key = str(raw_key)
|
||||
replacements = {
|
||||
"no_robust": "baseline_mode",
|
||||
"study/no_robust": "study/baseline_mode",
|
||||
"study/robust_radius": "study/ambiguity_radius",
|
||||
"robust_radius": "ambiguity_radius",
|
||||
"robust_points": "ambiguity_points",
|
||||
"robust_rollouts": "ambiguity_rollouts",
|
||||
"robust_eval_enabled": "stress_eval_enabled",
|
||||
"eval/robust_alpha_high": "eval/stress_alpha_high",
|
||||
"eval/robust_alpha_low": "eval/stress_alpha_low",
|
||||
"eval/robust_reward_worst": "eval/stress_reward_worst",
|
||||
"eval/robust_revenue_worst": "eval/stress_revenue_worst",
|
||||
"eval/robust_coi_leakage_worst": "eval/stress_coi_leakage_worst",
|
||||
}
|
||||
key = replacements.get(key, key)
|
||||
if "robust" in key.lower():
|
||||
return None
|
||||
return key
|
||||
|
||||
|
||||
def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
|
||||
sanitized: dict[str, Any] = {}
|
||||
for key, value in payload.items():
|
||||
clean_key = _sanitize_key(str(key))
|
||||
if clean_key is None:
|
||||
continue
|
||||
sanitized[clean_key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def init_run(
|
||||
*,
|
||||
mode: str,
|
||||
@@ -34,7 +72,11 @@ def init_run(
|
||||
if group:
|
||||
kwargs["group"] = group
|
||||
if sweep_mode:
|
||||
run = wandb.init(**kwargs)
|
||||
try:
|
||||
run = wandb.init(**kwargs)
|
||||
except Exception as exc:
|
||||
_warn(f"init failed in sweep mode ({exc})")
|
||||
return None
|
||||
if name and run is not None:
|
||||
run.name = name
|
||||
return run
|
||||
@@ -42,18 +84,25 @@ def init_run(
|
||||
init_kwargs = dict(kwargs)
|
||||
init_kwargs["project"] = project
|
||||
if config is not None:
|
||||
init_kwargs["config"] = dict(config)
|
||||
init_kwargs["config"] = _sanitize_payload(dict(config))
|
||||
if name:
|
||||
init_kwargs["name"] = name
|
||||
if tags:
|
||||
init_kwargs["tags"] = list(tags)
|
||||
return wandb.init(**init_kwargs)
|
||||
try:
|
||||
return wandb.init(**init_kwargs)
|
||||
except Exception as exc:
|
||||
_warn(f"init failed ({exc})")
|
||||
return None
|
||||
|
||||
|
||||
def finish_run() -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is not None and wandb.run is not None:
|
||||
wandb.finish()
|
||||
try:
|
||||
wandb.finish()
|
||||
except Exception as exc:
|
||||
_warn(f"finish failed ({exc})")
|
||||
|
||||
|
||||
def current_config() -> dict[str, Any]:
|
||||
@@ -67,25 +116,45 @@ def update_run_config(config: Mapping[str, Any]) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
payload = _sanitize_payload(dict(config))
|
||||
if not payload:
|
||||
return
|
||||
try:
|
||||
wandb.config.update(dict(config), allow_val_change=True)
|
||||
wandb.config.update(payload, allow_val_change=True)
|
||||
except TypeError:
|
||||
wandb.config.update(dict(config))
|
||||
try:
|
||||
wandb.config.update(payload)
|
||||
except Exception as exc:
|
||||
_warn(f"config update failed ({exc})")
|
||||
except Exception as exc:
|
||||
_warn(f"config update failed ({exc})")
|
||||
|
||||
|
||||
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
wandb.log(dict(metrics), step=step)
|
||||
payload = _sanitize_payload(dict(metrics))
|
||||
if not payload:
|
||||
return
|
||||
try:
|
||||
wandb.log(payload, step=step)
|
||||
except Exception as exc:
|
||||
_warn(f"log failed at step {step} ({exc})")
|
||||
|
||||
|
||||
def update_summary(metrics: Mapping[str, Any]) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
for key, value in metrics.items():
|
||||
wandb.run.summary[key] = value
|
||||
payload = _sanitize_payload(dict(metrics))
|
||||
if not payload:
|
||||
return
|
||||
try:
|
||||
for key, value in payload.items():
|
||||
wandb.run.summary[key] = value
|
||||
except Exception as exc:
|
||||
_warn(f"summary update failed ({exc})")
|
||||
|
||||
|
||||
def run_agent(
|
||||
@@ -95,4 +164,39 @@ def run_agent(
|
||||
count: int | None = None,
|
||||
) -> None:
|
||||
wandb = _require_wandb()
|
||||
wandb.agent(sweep_id, function=fn, count=count)
|
||||
retry_max = max(0, int(os.getenv("PHANTOM_WANDB_AGENT_RETRIES", "8")))
|
||||
retry_delay = max(1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_DELAY", "5")))
|
||||
retry_backoff = max(
|
||||
1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_BACKOFF", "1.5"))
|
||||
)
|
||||
retry_max_delay = max(
|
||||
retry_delay,
|
||||
float(os.getenv("PHANTOM_WANDB_AGENT_MAX_RETRY_DELAY", "60")),
|
||||
)
|
||||
|
||||
target = None if count is None else max(0, int(count))
|
||||
completed = 0
|
||||
|
||||
def _wrapped() -> None:
|
||||
nonlocal completed
|
||||
fn()
|
||||
completed += 1
|
||||
|
||||
attempt = 0
|
||||
while True:
|
||||
remaining = None if target is None else max(0, int(target - completed))
|
||||
if target is not None and remaining == 0:
|
||||
return
|
||||
try:
|
||||
wandb.agent(sweep_id, function=_wrapped, count=remaining)
|
||||
return
|
||||
except Exception as exc:
|
||||
attempt += 1
|
||||
if attempt > retry_max:
|
||||
raise
|
||||
wait = min(retry_max_delay, retry_delay * (retry_backoff ** (attempt - 1)))
|
||||
_warn(
|
||||
f"agent disconnected (attempt {attempt}/{retry_max}, "
|
||||
f"completed={completed}, remaining={remaining}): {exc}"
|
||||
)
|
||||
time.sleep(wait)
|
||||
|
||||
Reference in New Issue
Block a user