mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
catchup: rogue scripts
This commit is contained in:
93
engine/sweeps/tpu_jax.yaml
Normal file
93
engine/sweeps/tpu_jax.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
# Sweep configuration: Bayesian search that maximizes `sweep/score`.
# Each run is launched as `${env} python -m engine.train`.
method: bayes
metric:
  name: sweep/score
  goal: maximize
command:
  - ${env}
  - python
  - -m
  - engine.train
parameters:
  # fixed: always use the JAX backend so TPU chips are actually exercised
  use_jax:
    value: true
  # all four algos have JAX implementations
  algo:
    values: [ppo, a2c, dqn, qtable]
  total_timesteps:
    values: [50000, 80000, 120000]
  checkpoint_interval:
    value: 200000
  seed:
    values: [13, 42, 77]
  n_products:
    values: [8, 10, 12]

  # COI framework parameters -- primary research variables
  alpha:
    distribution: uniform
    min: 0.1
    max: 0.6
  lambda_coi:
    distribution: uniform
    min: 0.05
    max: 0.6
  robust_radius:
    distribution: uniform
    min: 0.0
    max: 0.3
  robust_points:
    values: [3, 5, 7]
  info_value:
    distribution: uniform
    min: 0.5
    max: 2.0
  revenue_weight:
    values: [0.005, 0.01, 0.02]

  # shared hyperparameters
  learning_rate:
    distribution: log_uniform_values
    min: 1.0e-5
    max: 1.0e-3
  gamma:
    values: [0.97, 0.99, 0.995]

  # JAX parallelism -- key lever for TPU throughput
  jax_num_envs:
    values: [8, 16, 32]
  jax_num_steps:
    values: [64, 128, 256]
  jax_num_minibatches:
    values: [2, 4, 8]
  jax_update_epochs:
    values: [2, 4, 8]

  # PPO/A2C specific
  gae_lambda:
    values: [0.9, 0.95, 0.98]
  clip_range:
    values: [0.1, 0.2, 0.3]
  ent_coef:
    values: [0.0, 0.005, 0.01]

  # DQN specific
  buffer_size:
    values: [20000, 50000, 100000]
  batch_size:
    values: [128, 256, 512]
  learning_starts:
    values: [500, 1000, 3000]
  exploration_fraction:
    values: [0.1, 0.2, 0.3]
  exploration_final_eps:
    values: [0.01, 0.03, 0.05]

  # QTable specific
  q_lr:
    values: [0.03, 0.05, 0.1, 0.2]
  eps_end:
    values: [0.02, 0.05, 0.1]
  eps_decay:
    values: [0.999, 0.9995, 0.9999]

  # action space
  action_levels:
    values: [7, 9, 11]
  action_scale_low:
    values: [0.75, 0.8, 0.85]
  action_scale_high:
    values: [1.15, 1.2, 1.25]
|
||||
64
engine/sweeps/tpu_pod.yaml
Normal file
64
engine/sweeps/tpu_pod.yaml
Normal file
@@ -0,0 +1,64 @@
|
||||
# Sweep configuration: Bayesian search that maximizes `sweep/score`.
# Each run is launched as `${env} python -m engine.train`.
method: bayes
metric:
  name: sweep/score
  goal: maximize
command:
  - ${env}
  - python
  - -m
  - engine.train
parameters:
  use_jax:
    value: true
  # pmap requires all workers to compile the same computation graph shape,
  # so structural params are pinned to a single value below; only `algo`
  # and the research/scalar params are swept
  algo:
    values: [ppo, a2c]
  jax_num_envs:
    value: 32
  jax_num_steps:
    value: 128
  jax_num_minibatches:
    value: 4
  jax_update_epochs:
    value: 4
  total_timesteps:
    value: 100000
  checkpoint_interval:
    value: 200000
  n_products:
    value: 10
  action_levels:
    value: 9

  # research parameters -- primary sweep targets
  alpha:
    distribution: uniform
    min: 0.1
    max: 0.6
  lambda_coi:
    distribution: uniform
    min: 0.05
    max: 0.6
  robust_radius:
    distribution: uniform
    min: 0.0
    max: 0.3
  info_value:
    distribution: uniform
    min: 0.5
    max: 2.0
  revenue_weight:
    values: [0.005, 0.01, 0.02]

  # training hyperparameters
  learning_rate:
    distribution: log_uniform_values
    min: 1.0e-5
    max: 1.0e-3
  gamma:
    values: [0.97, 0.99, 0.995]
  gae_lambda:
    values: [0.9, 0.95, 0.98]
  clip_range:
    values: [0.1, 0.2, 0.3]
  ent_coef:
    values: [0.0, 0.005, 0.01]
|
||||
130
engine/wandb_checkpoint.py
Normal file
130
engine/wandb_checkpoint.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, Mapping
|
||||
|
||||
try:
|
||||
import wandb
|
||||
from wandb.errors import CommError
|
||||
|
||||
HAS_WANDB = True
|
||||
except ImportError:
|
||||
HAS_WANDB = False
|
||||
wandb = None # type: ignore[assignment]
|
||||
CommError = RuntimeError # type: ignore[assignment]
|
||||
|
||||
|
||||
def _safe_value(value: Any) -> Any:
|
||||
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||
return value
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_safe_value(v) for v in value]
|
||||
if isinstance(value, dict):
|
||||
return {str(k): _safe_value(value[k]) for k in sorted(value)}
|
||||
return str(value)
|
||||
|
||||
|
||||
def _safe_scope(scope: str | None) -> str:
|
||||
raw = "manual" if scope in (None, "") else str(scope)
|
||||
cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "-", raw).strip("-")
|
||||
return cleaned or "manual"
|
||||
|
||||
|
||||
def checkpoint_artifact_name(
    cfg: Mapping[str, Any], *, backend: str, sweep_id: str | None = None
) -> str:
    """Build a deterministic artifact name for a config/backend/scope triple.

    The name has the form ``phantom-<backend>-ckpt-<scope>-<digest>``, where
    ``<scope>`` is the sanitised ``sweep_id`` and ``<digest>`` is the first
    14 hex characters of the SHA-1 of a canonical JSON dump of the backend,
    the scope and the config.

    Args:
        cfg: Run configuration; its values go through ``_safe_value``.
        backend: Backend label embedded in the name and the digest.
        sweep_id: Optional sweep identifier; passed through ``_safe_scope``.

    Returns:
        A name of at most 128 characters that always ends with the digest.
    """
    max_len = 128
    payload = {key: _safe_value(cfg[key]) for key in sorted(cfg)}
    scope = _safe_scope(sweep_id)
    canonical = json.dumps(
        {"backend": backend, "scope": scope, "cfg": payload},
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:14]
    name = f"phantom-{backend}-ckpt-{scope}-{digest}"
    if len(name) <= max_len:
        return name
    # Slicing the full name would cut into the digest, which is the only
    # part that distinguishes configs. Shorten the readable prefix instead.
    head = f"phantom-{backend}-ckpt-{scope}"[: max_len - len(digest) - 1]
    return f"{head.rstrip('-')}-{digest}"
|
||||
|
||||
|
||||
def _is_missing_artifact_error(exc: Exception) -> bool:
    """Tell whether *exc* is a ``CommError`` reporting an absent artifact.

    Only ``CommError`` instances qualify, and only when their lower-cased
    message contains ``"not found"`` or ``"does not exist"``.
    """
    if not isinstance(exc, CommError):
        return False
    text = str(exc).lower()
    return any(marker in text for marker in ("not found", "does not exist"))
|
||||
|
||||
|
||||
def download_latest_checkpoint(
    artifact_name: str, *, file_name: str
) -> tuple[Path, dict[str, Any]] | None:
    """Fetch ``<artifact_name>:latest`` through the active W&B run.

    Args:
        artifact_name: Artifact name without an alias suffix.
        file_name: File expected inside the downloaded artifact directory.

    Returns:
        ``(path_to_file, metadata_dict)`` on success. ``None`` when wandb is
        unavailable, no run is active, the artifact lookup fails with a
        missing-artifact error, or the file is absent from the download.

    Raises:
        Exception: Any lookup failure that ``_is_missing_artifact_error``
            does not classify as "missing" is re-raised unchanged.
    """
    if not (HAS_WANDB and wandb.run is not None):
        return None

    try:
        artifact = wandb.run.use_artifact(f"{artifact_name}:latest")
    except Exception as exc:
        if not _is_missing_artifact_error(exc):
            raise
        return None

    target = Path(artifact.download()) / file_name
    if target.exists():
        # ``metadata`` may be missing or None on the artifact; normalise it
        # to a plain dict copy.
        return target, dict(getattr(artifact, "metadata", {}) or {})
    return None
|
||||
|
||||
|
||||
def _aliases_from_metadata(metadata: dict[str, Any] | None) -> list[str]:
|
||||
aliases = ["latest"]
|
||||
if metadata is None:
|
||||
return aliases
|
||||
if "step" in metadata:
|
||||
try:
|
||||
aliases.append(f"step-{int(metadata['step'])}")
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return aliases
|
||||
|
||||
|
||||
def log_checkpoint_bytes(
    artifact_name: str,
    *,
    file_name: str,
    payload: bytes,
    metadata: dict[str, Any] | None = None,
) -> bool:
    """Log in-memory checkpoint bytes as a W&B ``checkpoint`` artifact.

    The payload is written to ``file_name`` inside a temporary directory,
    added to a new artifact, and logged with the aliases computed by
    ``_aliases_from_metadata``. The temporary directory is removed when the
    ``with`` block exits.

    Args:
        artifact_name: Name given to the artifact.
        file_name: File name used both on disk and inside the artifact.
        payload: Raw checkpoint bytes.
        metadata: Optional artifact metadata; ``None`` is sent as ``{}``.

    Returns:
        ``False`` when wandb is unavailable or no run is active, else
        ``True`` once the artifact has been handed to ``wandb.log_artifact``.
    """
    wandb_ready = HAS_WANDB and wandb.run is not None
    if not wandb_ready:
        return False

    with TemporaryDirectory(prefix="phantom-ckpt-") as scratch:
        staged = Path(scratch) / file_name
        staged.write_bytes(payload)
        ckpt = wandb.Artifact(
            name=artifact_name,
            type="checkpoint",
            metadata=metadata or {},
        )
        ckpt.add_file(staged.as_posix(), name=file_name)
        wandb.log_artifact(ckpt, aliases=_aliases_from_metadata(metadata))
    return True
|
||||
|
||||
|
||||
def log_checkpoint_file(
    artifact_name: str,
    *,
    file_path: str | Path,
    artifact_file_name: str,
    metadata: dict[str, Any] | None = None,
) -> bool:
    """Log an existing checkpoint file as a W&B ``checkpoint`` artifact.

    Args:
        artifact_name: Name given to the artifact.
        file_path: Location of the checkpoint file on disk.
        artifact_file_name: Name the file receives inside the artifact.
        metadata: Optional artifact metadata; ``None`` is sent as ``{}``.

    Returns:
        ``False`` when wandb is unavailable, no run is active, or
        ``file_path`` does not exist; ``True`` once the artifact has been
        handed to ``wandb.log_artifact``.
    """
    wandb_ready = HAS_WANDB and wandb.run is not None
    if not wandb_ready:
        return False

    source = Path(file_path)
    if not source.exists():
        return False

    ckpt = wandb.Artifact(
        name=artifact_name,
        type="checkpoint",
        metadata=metadata or {},
    )
    ckpt.add_file(source.as_posix(), name=artifact_file_name)
    wandb.log_artifact(ckpt, aliases=_aliases_from_metadata(metadata))
    return True
|
||||
Reference in New Issue
Block a user