mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
131 lines
3.7 KiB
Python
131 lines
3.7 KiB
Python
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import re
|
|
from pathlib import Path
|
|
from tempfile import TemporaryDirectory
|
|
from typing import Any, Mapping
|
|
|
|
# Weights & Biases is an optional dependency. When it cannot be imported the
# module still loads: HAS_WANDB is False, `wandb` is None, and CommError is
# aliased to RuntimeError so `isinstance(exc, CommError)` checks stay valid.
try:
    import wandb
    from wandb.errors import CommError
except ImportError:
    wandb = None  # type: ignore[assignment]
    CommError = RuntimeError  # type: ignore[assignment]
    HAS_WANDB = False
else:
    HAS_WANDB = True
|
|
|
|
|
|
def _safe_value(value: Any) -> Any:
|
|
if isinstance(value, (str, int, float, bool)) or value is None:
|
|
return value
|
|
if isinstance(value, (list, tuple)):
|
|
return [_safe_value(v) for v in value]
|
|
if isinstance(value, dict):
|
|
return {str(k): _safe_value(value[k]) for k in sorted(value)}
|
|
return str(value)
|
|
|
|
|
|
def _safe_scope(scope: str | None) -> str:
|
|
raw = "manual" if scope in (None, "") else str(scope)
|
|
cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "-", raw).strip("-")
|
|
return cleaned or "manual"
|
|
|
|
|
|
def checkpoint_artifact_name(
    cfg: Mapping[str, Any], *, backend: str, sweep_id: str | None = None
) -> str:
    """Build a deterministic artifact name for a config/backend/scope triple.

    The name has the form ``phantom-<backend>-ckpt-<scope>-<digest>`` where
    ``<scope>`` is the sanitised ``sweep_id`` and ``<digest>`` is the first 14
    hex characters of the SHA-1 of a canonical JSON dump of the backend,
    scope and sanitised config.

    The result is at most 128 characters long. When the name would be longer,
    the readable prefix is shortened and the digest is always kept in full,
    so that two different configs never end up sharing a name just because
    the hash was cut off.

    Args:
        cfg: Configuration mapping; its values are normalised with
            ``_safe_value`` before hashing.
        backend: Backend identifier embedded in the name and in the hash.
        sweep_id: Optional scope; ``None``/empty is treated as ``"manual"``.

    Returns:
        The artifact name.
    """
    max_length = 128

    payload = {k: _safe_value(cfg[k]) for k in sorted(cfg)}
    scope = _safe_scope(sweep_id)
    canonical = json.dumps(
        {"backend": backend, "scope": scope, "cfg": payload},
        sort_keys=True,
        separators=(",", ":"),
    )
    # SHA-1 is used only as a stable fingerprint here, not for security.
    digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:14]

    suffix = f"-{digest}"
    prefix = f"phantom-{backend}-ckpt-{scope}"
    # Truncate the human-readable part only; the digest is what makes the
    # name unique and must survive intact.
    return prefix[: max_length - len(suffix)] + suffix
|
|
|
|
|
|
def _is_missing_artifact_error(exc: Exception) -> bool:
    """Tell whether *exc* looks like an "artifact does not exist" failure.

    Only ``CommError`` instances qualify, and only when their lower-cased
    message contains ``"not found"`` or ``"does not exist"``.
    """
    if not isinstance(exc, CommError):
        return False
    text = str(exc).lower()
    return any(marker in text for marker in ("not found", "does not exist"))
|
|
|
|
|
|
def download_latest_checkpoint(
    artifact_name: str, *, file_name: str
) -> tuple[Path, dict[str, Any]] | None:
    """Fetch ``<artifact_name>:latest`` through the active W&B run.

    Args:
        artifact_name: Artifact name without alias.
        file_name: Path of the checkpoint file inside the downloaded artifact
            directory.

    Returns:
        ``(checkpoint_path, metadata)`` on success, or ``None`` when wandb is
        unavailable, no run is active, the artifact lookup fails with an
        error recognised by ``_is_missing_artifact_error``, or the downloaded
        directory does not contain ``file_name``.

    Raises:
        Exception: Any other error from ``use_artifact`` is re-raised as-is.
    """
    if not (HAS_WANDB and wandb.run is not None):
        return None

    try:
        artifact = wandb.run.use_artifact(f"{artifact_name}:latest")
    except Exception as exc:
        # A missing artifact just means "no checkpoint yet"; everything else
        # is a real failure and must surface.
        if not _is_missing_artifact_error(exc):
            raise
        return None

    checkpoint_path = Path(artifact.download()) / file_name
    if checkpoint_path.exists():
        raw_metadata = getattr(artifact, "metadata", {})
        return checkpoint_path, dict(raw_metadata or {})
    return None
|
|
|
|
|
|
def _aliases_from_metadata(metadata: dict[str, Any] | None) -> list[str]:
|
|
aliases = ["latest"]
|
|
if metadata is None:
|
|
return aliases
|
|
if "step" in metadata:
|
|
try:
|
|
aliases.append(f"step-{int(metadata['step'])}")
|
|
except (TypeError, ValueError):
|
|
pass
|
|
return aliases
|
|
|
|
|
|
def log_checkpoint_bytes(
    artifact_name: str,
    *,
    file_name: str,
    payload: bytes,
    metadata: dict[str, Any] | None = None,
) -> bool:
    """Log *payload* as a ``checkpoint`` artifact through wandb.

    The bytes are written to ``file_name`` inside a temporary directory,
    added to a new artifact under the same name, and logged with the aliases
    derived from *metadata* by ``_aliases_from_metadata``.

    Args:
        artifact_name: Name of the artifact to create.
        file_name: Name of the file inside the artifact; may contain
            sub-directories (e.g. ``"ckpt/model.pt"``).
        payload: Raw checkpoint bytes.
        metadata: Optional artifact metadata; ``None`` is stored as ``{}``.

    Returns:
        ``True`` once the artifact has been handed to ``wandb.log_artifact``;
        ``False`` when wandb is unavailable or no run is active.
    """
    if not HAS_WANDB or wandb.run is None:
        return False

    with TemporaryDirectory(prefix="phantom-ckpt-") as tmpdir:
        path = Path(tmpdir) / file_name
        # A nested file_name needs its parent directories created first,
        # otherwise write_bytes fails with FileNotFoundError.
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(payload)

        artifact = wandb.Artifact(
            name=artifact_name,
            type="checkpoint",
            metadata=metadata or {},
        )
        artifact.add_file(path.as_posix(), name=file_name)
        # Logged while the temporary directory still exists, so the file is
        # present for the whole add/log sequence.
        wandb.log_artifact(artifact, aliases=_aliases_from_metadata(metadata))
        return True
|
|
|
|
|
|
def log_checkpoint_file(
    artifact_name: str,
    *,
    file_path: str | Path,
    artifact_file_name: str,
    metadata: dict[str, Any] | None = None,
) -> bool:
    """Log an existing file as a ``checkpoint`` artifact through wandb.

    Args:
        artifact_name: Name of the artifact to create.
        file_path: Location of the checkpoint on disk.
        artifact_file_name: Name the file gets inside the artifact.
        metadata: Optional artifact metadata; ``None`` is stored as ``{}``.

    Returns:
        ``True`` once the artifact has been handed to ``wandb.log_artifact``;
        ``False`` when wandb is unavailable, no run is active, or
        ``file_path`` does not exist.
    """
    if not (HAS_WANDB and wandb.run is not None):
        return False

    source = Path(file_path)
    if not source.exists():
        return False

    checkpoint = wandb.Artifact(
        name=artifact_name, type="checkpoint", metadata=metadata or {}
    )
    checkpoint.add_file(source.as_posix(), name=artifact_file_name)
    alias_list = _aliases_from_metadata(metadata)
    wandb.log_artifact(checkpoint, aliases=alias_list)
    return True
|