catchup: rogue scripts

This commit is contained in:
2026-02-27 12:45:46 +01:00
parent e8a9716f69
commit 5444a4ea13
27 changed files with 6908 additions and 2 deletions

View File

@@ -0,0 +1,93 @@
method: bayes
metric:
name: sweep/score
goal: maximize
command:
- ${env}
- python
- -m
- engine.train
parameters:
# fixed: always use JAX backend so TPU chips are actually exercised
use_jax:
value: true
# all four algos have JAX implementations
algo:
values: [ppo, a2c, dqn, qtable]
total_timesteps:
values: [50000, 80000, 120000]
checkpoint_interval:
value: 200000
seed:
values: [13, 42, 77]
n_products:
values: [8, 10, 12]
# COI framework parameters -- primary research variables
alpha:
distribution: uniform
min: 0.1
max: 0.6
lambda_coi:
distribution: uniform
min: 0.05
max: 0.6
robust_radius:
distribution: uniform
min: 0.0
max: 0.3
robust_points:
values: [3, 5, 7]
info_value:
distribution: uniform
min: 0.5
max: 2.0
revenue_weight:
values: [0.005, 0.01, 0.02]
# shared hyperparameters
learning_rate:
distribution: log_uniform_values
min: 1.0e-5
max: 1.0e-3
gamma:
values: [0.97, 0.99, 0.995]
# JAX parallelism -- key lever for TPU throughput
jax_num_envs:
values: [8, 16, 32]
jax_num_steps:
values: [64, 128, 256]
jax_num_minibatches:
values: [2, 4, 8]
jax_update_epochs:
values: [2, 4, 8]
# PPO/A2C specific
gae_lambda:
values: [0.9, 0.95, 0.98]
clip_range:
values: [0.1, 0.2, 0.3]
ent_coef:
values: [0.0, 0.005, 0.01]
# DQN specific
buffer_size:
values: [20000, 50000, 100000]
batch_size:
values: [128, 256, 512]
learning_starts:
values: [500, 1000, 3000]
exploration_fraction:
values: [0.1, 0.2, 0.3]
exploration_final_eps:
values: [0.01, 0.03, 0.05]
# QTable specific
q_lr:
values: [0.03, 0.05, 0.1, 0.2]
eps_end:
values: [0.02, 0.05, 0.1]
eps_decay:
values: [0.999, 0.9995, 0.9999]
# action space
action_levels:
values: [7, 9, 11]
action_scale_low:
values: [0.75, 0.8, 0.85]
action_scale_high:
values: [1.15, 1.2, 1.25]

View File

@@ -0,0 +1,64 @@
method: bayes
metric:
name: sweep/score
goal: maximize
command:
- ${env}
- python
- -m
- engine.train
parameters:
use_jax:
value: true
# pmap requires all workers to compile the same computation graph shape,
# so structural params are fixed -- only research/scalar params are swept
algo:
values: [ppo, a2c]
jax_num_envs:
value: 32
jax_num_steps:
value: 128
jax_num_minibatches:
value: 4
jax_update_epochs:
value: 4
total_timesteps:
value: 100000
checkpoint_interval:
value: 200000
n_products:
value: 10
action_levels:
value: 9
# research parameters -- primary sweep targets
alpha:
distribution: uniform
min: 0.1
max: 0.6
lambda_coi:
distribution: uniform
min: 0.05
max: 0.6
robust_radius:
distribution: uniform
min: 0.0
max: 0.3
info_value:
distribution: uniform
min: 0.5
max: 2.0
revenue_weight:
values: [0.005, 0.01, 0.02]
# training hyperparameters
learning_rate:
distribution: log_uniform_values
min: 1.0e-5
max: 1.0e-3
gamma:
values: [0.97, 0.99, 0.995]
gae_lambda:
values: [0.9, 0.95, 0.98]
clip_range:
values: [0.1, 0.2, 0.3]
ent_coef:
values: [0.0, 0.005, 0.01]

130
engine/wandb_checkpoint.py Normal file
View File

@@ -0,0 +1,130 @@
from __future__ import annotations
import hashlib
import json
import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Mapping
try:
import wandb
from wandb.errors import CommError
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
wandb = None # type: ignore[assignment]
CommError = RuntimeError # type: ignore[assignment]
def _safe_value(value: Any) -> Any:
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, (list, tuple)):
return [_safe_value(v) for v in value]
if isinstance(value, dict):
return {str(k): _safe_value(value[k]) for k in sorted(value)}
return str(value)
def _safe_scope(scope: str | None) -> str:
raw = "manual" if scope in (None, "") else str(scope)
cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "-", raw).strip("-")
return cleaned or "manual"
def checkpoint_artifact_name(
cfg: Mapping[str, Any], *, backend: str, sweep_id: str | None = None
) -> str:
payload = {k: _safe_value(cfg[k]) for k in sorted(cfg)}
scope = _safe_scope(sweep_id)
canonical = json.dumps(
{"backend": backend, "scope": scope, "cfg": payload},
sort_keys=True,
separators=(",", ":"),
)
digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:14]
return f"phantom-{backend}-ckpt-{scope}-{digest}"[:128]
def _is_missing_artifact_error(exc: Exception) -> bool:
if isinstance(exc, CommError):
msg = str(exc).lower()
return "not found" in msg or "does not exist" in msg
return False
def download_latest_checkpoint(
artifact_name: str, *, file_name: str
) -> tuple[Path, dict[str, Any]] | None:
if not HAS_WANDB or wandb.run is None:
return None
try:
artifact = wandb.run.use_artifact(f"{artifact_name}:latest")
except Exception as exc:
if _is_missing_artifact_error(exc):
return None
raise
directory = Path(artifact.download())
checkpoint_path = directory / file_name
if not checkpoint_path.exists():
return None
metadata = dict(getattr(artifact, "metadata", {}) or {})
return checkpoint_path, metadata
def _aliases_from_metadata(metadata: dict[str, Any] | None) -> list[str]:
aliases = ["latest"]
if metadata is None:
return aliases
if "step" in metadata:
try:
aliases.append(f"step-{int(metadata['step'])}")
except (TypeError, ValueError):
pass
return aliases
def log_checkpoint_bytes(
artifact_name: str,
*,
file_name: str,
payload: bytes,
metadata: dict[str, Any] | None = None,
) -> bool:
if not HAS_WANDB or wandb.run is None:
return False
with TemporaryDirectory(prefix="phantom-ckpt-") as tmpdir:
path = Path(tmpdir) / file_name
path.write_bytes(payload)
artifact = wandb.Artifact(
name=artifact_name,
type="checkpoint",
metadata=metadata or {},
)
artifact.add_file(path.as_posix(), name=file_name)
wandb.log_artifact(artifact, aliases=_aliases_from_metadata(metadata))
return True
def log_checkpoint_file(
artifact_name: str,
*,
file_path: str | Path,
artifact_file_name: str,
metadata: dict[str, Any] | None = None,
) -> bool:
if not HAS_WANDB or wandb.run is None:
return False
src = Path(file_path)
if not src.exists():
return False
artifact = wandb.Artifact(
name=artifact_name,
type="checkpoint",
metadata=metadata or {},
)
artifact.add_file(src.as_posix(), name=artifact_file_name)
wandb.log_artifact(artifact, aliases=_aliases_from_metadata(metadata))
return True