Mirror of https://github.com/velocitatem/PHANTOM.git
Synced 2026-05-31 16:43:36 +00:00
99 lines · 2.3 KiB · Python
from __future__ import annotations
|
|
|
|
from typing import Any, Callable, Iterable, Mapping
|
|
|
|
|
|
def get_wandb_module():
    """Return the ``wandb`` module if it can be imported, else ``None``.

    The import happens lazily on each call, so this module can be imported
    even when wandb is not installed.
    """
    try:
        import wandb
    except ImportError:
        return None
    else:
        return wandb
|
|
|
|
|
|
def _require_wandb():
    """Return the ``wandb`` module, failing loudly when it is unavailable.

    Raises:
        ImportError: if ``get_wandb_module`` could not import wandb.
    """
    module = get_wandb_module()
    if module is not None:
        return module
    raise ImportError("wandb is required for this workflow")
|
|
|
|
|
|
def init_run(
    *,
    mode: str,
    project: str | None = None,
    config: Mapping[str, Any] | None = None,
    name: str | None = None,
    tags: Iterable[str] | None = None,
    group: str | None = None,
    sweep_mode: bool = False,
):
    """Start a wandb run and return whatever ``wandb.init`` returns.

    Args:
        mode: Passed straight through as ``wandb.init(mode=...)``.
        project: Project name; only forwarded when ``sweep_mode`` is False.
        config: Run config, copied into a plain dict; only forwarded when
            ``sweep_mode`` is False and it is not ``None``.
        name: Run name. Outside sweep mode it is passed to ``wandb.init``;
            in sweep mode it is assigned to ``run.name`` after init.
            Falsy values are ignored.
        tags: Run tags, materialized into a list; only forwarded when
            ``sweep_mode`` is False and the value is truthy.
        group: Run group; forwarded in both modes when truthy.
        sweep_mode: When True, call ``wandb.init`` with only ``mode`` and
            ``group``. ``project``, ``config`` and ``tags`` are not passed.
            NOTE(review): presumably the sweep agent supplies project and
            config in this case — confirm against callers.

    Raises:
        ImportError: if wandb is not installed.
    """
    wandb = _require_wandb()

    shared: dict[str, Any] = {"mode": mode}
    if group:
        shared["group"] = group

    if not sweep_mode:
        extra: dict[str, Any] = {"project": project}
        if config is not None:
            extra["config"] = dict(config)
        if name:
            extra["name"] = name
        if tags:
            extra["tags"] = list(tags)
        return wandb.init(**shared, **extra)

    # Sweep path: the name is applied after init rather than passed in.
    run = wandb.init(**shared)
    if name and run is not None:
        run.name = name
    return run
|
|
|
|
|
|
def finish_run() -> None:
    """Call ``wandb.finish()`` if wandb is importable and a run is active.

    Otherwise this does nothing.
    """
    sdk = get_wandb_module()
    if sdk is None or sdk.run is None:
        return
    sdk.finish()
|
|
|
|
|
|
def current_config() -> dict[str, Any]:
    """Return a plain-dict copy of the active run's ``wandb.config``.

    Returns:
        A new dict built from ``wandb.config.keys()`` and item lookups, or
        an empty dict when wandb is not importable or no run is active.
    """
    sdk = get_wandb_module()
    if sdk is None or sdk.run is None:
        return {}

    snapshot: dict[str, Any] = {}
    for key in sdk.config.keys():
        snapshot[key] = sdk.config[key]
    return snapshot
|
|
|
|
|
|
def update_run_config(config: Mapping[str, Any]) -> None:
    """Merge ``config`` into the active run's ``wandb.config``.

    First tries ``update(..., allow_val_change=True)`` so existing keys may
    be overwritten. If that raises ``TypeError``, it retries without the
    keyword. Does nothing when wandb is not importable or no run is active.
    """
    sdk = get_wandb_module()
    if sdk is None or sdk.run is None:
        return

    try:
        sdk.config.update(dict(config), allow_val_change=True)
    except TypeError:
        # NOTE(review): presumably a fallback for wandb versions whose
        # config.update does not accept allow_val_change — confirm.
        sdk.config.update(dict(config))
|
|
|
|
|
|
def log_metrics(metrics: Mapping[str, Any], *, step: int) -> None:
    """Log a copy of ``metrics`` to the active run at the given ``step``.

    Does nothing when wandb is not importable or no run is active.
    """
    sdk = get_wandb_module()
    if sdk is not None and sdk.run is not None:
        sdk.log(dict(metrics), step=step)
|
|
|
|
|
|
def update_summary(metrics: Mapping[str, Any]) -> None:
    """Write each entry of ``metrics`` into the active run's summary.

    Keys are assigned one at a time via ``wandb.run.summary[key] = value``.
    Does nothing when wandb is not importable or no run is active.
    """
    sdk = get_wandb_module()
    if sdk is None or sdk.run is None:
        return

    for metric_name, metric_value in metrics.items():
        sdk.run.summary[metric_name] = metric_value
|
|
|
|
|
|
def run_agent(
    sweep_id: str,
    fn: Callable[[], None],
    *,
    count: int | None = None,
) -> None:
    """Run a wandb sweep agent that calls ``fn`` for each trial.

    Args:
        sweep_id: Sweep identifier, passed positionally to ``wandb.agent``.
        fn: Zero-argument callable, passed as ``function=``.
        count: Forwarded as ``count=``; ``None`` is passed through unchanged.

    Raises:
        ImportError: if wandb is not installed.
    """
    sdk = _require_wandb()
    sdk.agent(sweep_id, function=fn, count=count)
|