updating engine training for training

This commit is contained in:
2026-03-15 21:14:11 +01:00
parent 19b47aa699
commit 52b4dcdce3
13 changed files with 544 additions and 160 deletions

View File

@@ -36,7 +36,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
eval_reward = (
_as_float(
metrics.get("eval/robust_reward_worst", metrics.get("eval/reward_mean")),
metrics.get(
"eval/stress_reward_worst",
metrics.get(
"eval/robust_reward_worst", metrics.get("eval/reward_mean")
),
),
0.0,
)
or 0.0
@@ -51,9 +56,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level
metrics["study/alpha"] = spec.study.alpha
metrics["study/mode"] = "baseline" if bool(spec.study.no_robust) else "defended"
metrics["study/baseline_mode"] = float(bool(spec.study.no_robust))
metrics["study/lambda_coi"] = spec.study.lambda_coi
metrics["study/robust_radius"] = spec.study.robust_radius
metrics["study/ambiguity_radius"] = spec.study.robust_radius
metrics["study/info_value"] = spec.study.info_value
metrics["tiers"] = spec.algorithm.name
metrics["runtime/backend"] = spec.runtime.backend
metrics["runtime/device"] = spec.runtime.device

View File

@@ -1,5 +1,7 @@
from __future__ import annotations
import os
import time
from typing import Any, Callable, Iterable, Mapping
@@ -19,6 +21,42 @@ def _require_wandb():
return wandb
def _warn(message: str) -> None:
print(f"PHANTOM_WANDB_WARNING: {message}")
def _sanitize_key(raw_key: str) -> str | None:
key = str(raw_key)
replacements = {
"no_robust": "baseline_mode",
"study/no_robust": "study/baseline_mode",
"study/robust_radius": "study/ambiguity_radius",
"robust_radius": "ambiguity_radius",
"robust_points": "ambiguity_points",
"robust_rollouts": "ambiguity_rollouts",
"robust_eval_enabled": "stress_eval_enabled",
"eval/robust_alpha_high": "eval/stress_alpha_high",
"eval/robust_alpha_low": "eval/stress_alpha_low",
"eval/robust_reward_worst": "eval/stress_reward_worst",
"eval/robust_revenue_worst": "eval/stress_revenue_worst",
"eval/robust_coi_leakage_worst": "eval/stress_coi_leakage_worst",
}
key = replacements.get(key, key)
if "robust" in key.lower():
return None
return key
def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
sanitized: dict[str, Any] = {}
for key, value in payload.items():
clean_key = _sanitize_key(str(key))
if clean_key is None:
continue
sanitized[clean_key] = value
return sanitized
def init_run(
*,
mode: str,
@@ -34,7 +72,11 @@ def init_run(
if group:
kwargs["group"] = group
if sweep_mode:
run = wandb.init(**kwargs)
try:
run = wandb.init(**kwargs)
except Exception as exc:
_warn(f"init failed in sweep mode ({exc})")
return None
if name and run is not None:
run.name = name
return run
@@ -42,18 +84,25 @@ def init_run(
init_kwargs = dict(kwargs)
init_kwargs["project"] = project
if config is not None:
init_kwargs["config"] = dict(config)
init_kwargs["config"] = _sanitize_payload(dict(config))
if name:
init_kwargs["name"] = name
if tags:
init_kwargs["tags"] = list(tags)
return wandb.init(**init_kwargs)
try:
return wandb.init(**init_kwargs)
except Exception as exc:
_warn(f"init failed ({exc})")
return None
def finish_run() -> None:
wandb = get_wandb_module()
if wandb is not None and wandb.run is not None:
wandb.finish()
try:
wandb.finish()
except Exception as exc:
_warn(f"finish failed ({exc})")
def current_config() -> dict[str, Any]:
@@ -67,25 +116,45 @@ def update_run_config(config: Mapping[str, Any]) -> None:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return
payload = _sanitize_payload(dict(config))
if not payload:
return
try:
wandb.config.update(dict(config), allow_val_change=True)
wandb.config.update(payload, allow_val_change=True)
except TypeError:
wandb.config.update(dict(config))
try:
wandb.config.update(payload)
except Exception as exc:
_warn(f"config update failed ({exc})")
except Exception as exc:
_warn(f"config update failed ({exc})")
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return
wandb.log(dict(metrics), step=step)
payload = _sanitize_payload(dict(metrics))
if not payload:
return
try:
wandb.log(payload, step=step)
except Exception as exc:
_warn(f"log failed at step {step} ({exc})")
def update_summary(metrics: Mapping[str, Any]) -> None:
wandb = get_wandb_module()
if wandb is None or wandb.run is None:
return
for key, value in metrics.items():
wandb.run.summary[key] = value
payload = _sanitize_payload(dict(metrics))
if not payload:
return
try:
for key, value in payload.items():
wandb.run.summary[key] = value
except Exception as exc:
_warn(f"summary update failed ({exc})")
def run_agent(
@@ -95,4 +164,39 @@ def run_agent(
count: int | None = None,
) -> None:
wandb = _require_wandb()
wandb.agent(sweep_id, function=fn, count=count)
retry_max = max(0, int(os.getenv("PHANTOM_WANDB_AGENT_RETRIES", "8")))
retry_delay = max(1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_DELAY", "5")))
retry_backoff = max(
1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_BACKOFF", "1.5"))
)
retry_max_delay = max(
retry_delay,
float(os.getenv("PHANTOM_WANDB_AGENT_MAX_RETRY_DELAY", "60")),
)
target = None if count is None else max(0, int(count))
completed = 0
def _wrapped() -> None:
nonlocal completed
fn()
completed += 1
attempt = 0
while True:
remaining = None if target is None else max(0, int(target - completed))
if target is not None and remaining == 0:
return
try:
wandb.agent(sweep_id, function=_wrapped, count=remaining)
return
except Exception as exc:
attempt += 1
if attempt > retry_max:
raise
wait = min(retry_max_delay, retry_delay * (retry_backoff ** (attempt - 1)))
_warn(
f"agent disconnected (attempt {attempt}/{retry_max}, "
f"completed={completed}, remaining={remaining}): {exc}"
)
time.sleep(wait)