mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
updating engine training for training
This commit is contained in:
@@ -36,7 +36,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
|
||||
|
||||
eval_reward = (
|
||||
_as_float(
|
||||
metrics.get("eval/robust_reward_worst", metrics.get("eval/reward_mean")),
|
||||
metrics.get(
|
||||
"eval/stress_reward_worst",
|
||||
metrics.get(
|
||||
"eval/robust_reward_worst", metrics.get("eval/reward_mean")
|
||||
),
|
||||
),
|
||||
0.0,
|
||||
)
|
||||
or 0.0
|
||||
@@ -51,9 +56,12 @@ def canonicalize_metrics(raw: Mapping[str, Any], spec: TrainSpec) -> dict[str, A
|
||||
metrics["objective/coi_preserved"] = 0.0 if coi_level is None else coi_level
|
||||
|
||||
metrics["study/alpha"] = spec.study.alpha
|
||||
metrics["study/mode"] = "baseline" if bool(spec.study.no_robust) else "defended"
|
||||
metrics["study/baseline_mode"] = float(bool(spec.study.no_robust))
|
||||
metrics["study/lambda_coi"] = spec.study.lambda_coi
|
||||
metrics["study/robust_radius"] = spec.study.robust_radius
|
||||
metrics["study/ambiguity_radius"] = spec.study.robust_radius
|
||||
metrics["study/info_value"] = spec.study.info_value
|
||||
metrics["tiers"] = spec.algorithm.name
|
||||
|
||||
metrics["runtime/backend"] = spec.runtime.backend
|
||||
metrics["runtime/device"] = spec.runtime.device
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Callable, Iterable, Mapping
|
||||
|
||||
|
||||
@@ -19,6 +21,42 @@ def _require_wandb():
|
||||
return wandb
|
||||
|
||||
|
||||
def _warn(message: str) -> None:
|
||||
print(f"PHANTOM_WANDB_WARNING: {message}")
|
||||
|
||||
|
||||
def _sanitize_key(raw_key: str) -> str | None:
|
||||
key = str(raw_key)
|
||||
replacements = {
|
||||
"no_robust": "baseline_mode",
|
||||
"study/no_robust": "study/baseline_mode",
|
||||
"study/robust_radius": "study/ambiguity_radius",
|
||||
"robust_radius": "ambiguity_radius",
|
||||
"robust_points": "ambiguity_points",
|
||||
"robust_rollouts": "ambiguity_rollouts",
|
||||
"robust_eval_enabled": "stress_eval_enabled",
|
||||
"eval/robust_alpha_high": "eval/stress_alpha_high",
|
||||
"eval/robust_alpha_low": "eval/stress_alpha_low",
|
||||
"eval/robust_reward_worst": "eval/stress_reward_worst",
|
||||
"eval/robust_revenue_worst": "eval/stress_revenue_worst",
|
||||
"eval/robust_coi_leakage_worst": "eval/stress_coi_leakage_worst",
|
||||
}
|
||||
key = replacements.get(key, key)
|
||||
if "robust" in key.lower():
|
||||
return None
|
||||
return key
|
||||
|
||||
|
||||
def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
|
||||
sanitized: dict[str, Any] = {}
|
||||
for key, value in payload.items():
|
||||
clean_key = _sanitize_key(str(key))
|
||||
if clean_key is None:
|
||||
continue
|
||||
sanitized[clean_key] = value
|
||||
return sanitized
|
||||
|
||||
|
||||
def init_run(
|
||||
*,
|
||||
mode: str,
|
||||
@@ -34,7 +72,11 @@ def init_run(
|
||||
if group:
|
||||
kwargs["group"] = group
|
||||
if sweep_mode:
|
||||
run = wandb.init(**kwargs)
|
||||
try:
|
||||
run = wandb.init(**kwargs)
|
||||
except Exception as exc:
|
||||
_warn(f"init failed in sweep mode ({exc})")
|
||||
return None
|
||||
if name and run is not None:
|
||||
run.name = name
|
||||
return run
|
||||
@@ -42,18 +84,25 @@ def init_run(
|
||||
init_kwargs = dict(kwargs)
|
||||
init_kwargs["project"] = project
|
||||
if config is not None:
|
||||
init_kwargs["config"] = dict(config)
|
||||
init_kwargs["config"] = _sanitize_payload(dict(config))
|
||||
if name:
|
||||
init_kwargs["name"] = name
|
||||
if tags:
|
||||
init_kwargs["tags"] = list(tags)
|
||||
return wandb.init(**init_kwargs)
|
||||
try:
|
||||
return wandb.init(**init_kwargs)
|
||||
except Exception as exc:
|
||||
_warn(f"init failed ({exc})")
|
||||
return None
|
||||
|
||||
|
||||
def finish_run() -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is not None and wandb.run is not None:
|
||||
wandb.finish()
|
||||
try:
|
||||
wandb.finish()
|
||||
except Exception as exc:
|
||||
_warn(f"finish failed ({exc})")
|
||||
|
||||
|
||||
def current_config() -> dict[str, Any]:
|
||||
@@ -67,25 +116,45 @@ def update_run_config(config: Mapping[str, Any]) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
payload = _sanitize_payload(dict(config))
|
||||
if not payload:
|
||||
return
|
||||
try:
|
||||
wandb.config.update(dict(config), allow_val_change=True)
|
||||
wandb.config.update(payload, allow_val_change=True)
|
||||
except TypeError:
|
||||
wandb.config.update(dict(config))
|
||||
try:
|
||||
wandb.config.update(payload)
|
||||
except Exception as exc:
|
||||
_warn(f"config update failed ({exc})")
|
||||
except Exception as exc:
|
||||
_warn(f"config update failed ({exc})")
|
||||
|
||||
|
||||
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
wandb.log(dict(metrics), step=step)
|
||||
payload = _sanitize_payload(dict(metrics))
|
||||
if not payload:
|
||||
return
|
||||
try:
|
||||
wandb.log(payload, step=step)
|
||||
except Exception as exc:
|
||||
_warn(f"log failed at step {step} ({exc})")
|
||||
|
||||
|
||||
def update_summary(metrics: Mapping[str, Any]) -> None:
|
||||
wandb = get_wandb_module()
|
||||
if wandb is None or wandb.run is None:
|
||||
return
|
||||
for key, value in metrics.items():
|
||||
wandb.run.summary[key] = value
|
||||
payload = _sanitize_payload(dict(metrics))
|
||||
if not payload:
|
||||
return
|
||||
try:
|
||||
for key, value in payload.items():
|
||||
wandb.run.summary[key] = value
|
||||
except Exception as exc:
|
||||
_warn(f"summary update failed ({exc})")
|
||||
|
||||
|
||||
def run_agent(
|
||||
@@ -95,4 +164,39 @@ def run_agent(
|
||||
count: int | None = None,
|
||||
) -> None:
|
||||
wandb = _require_wandb()
|
||||
wandb.agent(sweep_id, function=fn, count=count)
|
||||
retry_max = max(0, int(os.getenv("PHANTOM_WANDB_AGENT_RETRIES", "8")))
|
||||
retry_delay = max(1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_DELAY", "5")))
|
||||
retry_backoff = max(
|
||||
1.0, float(os.getenv("PHANTOM_WANDB_AGENT_RETRY_BACKOFF", "1.5"))
|
||||
)
|
||||
retry_max_delay = max(
|
||||
retry_delay,
|
||||
float(os.getenv("PHANTOM_WANDB_AGENT_MAX_RETRY_DELAY", "60")),
|
||||
)
|
||||
|
||||
target = None if count is None else max(0, int(count))
|
||||
completed = 0
|
||||
|
||||
def _wrapped() -> None:
|
||||
nonlocal completed
|
||||
fn()
|
||||
completed += 1
|
||||
|
||||
attempt = 0
|
||||
while True:
|
||||
remaining = None if target is None else max(0, int(target - completed))
|
||||
if target is not None and remaining == 0:
|
||||
return
|
||||
try:
|
||||
wandb.agent(sweep_id, function=_wrapped, count=remaining)
|
||||
return
|
||||
except Exception as exc:
|
||||
attempt += 1
|
||||
if attempt > retry_max:
|
||||
raise
|
||||
wait = min(retry_max_delay, retry_delay * (retry_backoff ** (attempt - 1)))
|
||||
_warn(
|
||||
f"agent disconnected (attempt {attempt}/{retry_max}, "
|
||||
f"completed={completed}, remaining={remaining}): {exc}"
|
||||
)
|
||||
time.sleep(wait)
|
||||
|
||||
Reference in New Issue
Block a user