mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
cleaning up jax bs
This commit is contained in:
26
Makefile
26
Makefile
@@ -27,10 +27,6 @@ AGENT_LOOP ?= 1
|
|||||||
RETRY_SECONDS ?= 20
|
RETRY_SECONDS ?= 20
|
||||||
|
|
||||||
TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer
|
TRAIN_IMAGE_REF := us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer
|
||||||
TPU_NAME ?=
|
|
||||||
TPU_ZONE ?= us-central2-b
|
|
||||||
TPU_PROJECT ?= phantom-trc
|
|
||||||
TPU_REPO_DIR ?= /tmp/PHANTOM
|
|
||||||
|
|
||||||
SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a
|
SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" || true; set +a
|
||||||
|
|
||||||
@@ -38,7 +34,7 @@ SWEEP_ENV_LOAD = set -a; [ -f "$(SWEEP_ENV_FILE)" ] && . "$(SWEEP_ENV_FILE)" ||
|
|||||||
|
|
||||||
.PHONY: help
|
.PHONY: help
|
||||||
help:
|
help:
|
||||||
@echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.agent | train.agent | train.bootstrap | train.tpu.pod | train.tpu.vm | train.tpu.vm.sweep | stats.lines"
|
@echo "pdf.build pdf.watch pdf.clean pdf.genpop pdf.genpop.watch | test.backend test.e2e test.all | web.dev | install | train | benchmark | benchmark.agent | train.agent | train.bootstrap | stats.lines"
|
||||||
@echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish"
|
@echo "backend.server backend.provider backend.worker | platform.up platform.down platform.logs | docker.train.publish"
|
||||||
@echo ""
|
@echo ""
|
||||||
@echo "Build general public version:"
|
@echo "Build general public version:"
|
||||||
@@ -137,26 +133,6 @@ wordcount:
|
|||||||
docker.train.publish:
|
docker.train.publish:
|
||||||
@TRAIN_IMAGE_REF="$(TRAIN_IMAGE_REF)" $(NX) run research:docker-train-publish
|
@TRAIN_IMAGE_REF="$(TRAIN_IMAGE_REF)" $(NX) run research:docker-train-publish
|
||||||
|
|
||||||
.PHONY: train.tpu.pod
|
|
||||||
train.tpu.pod:
|
|
||||||
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-tpu-pod
|
|
||||||
|
|
||||||
.PHONY: train.tpu.vm.prepare
|
|
||||||
train.tpu.vm.prepare:
|
|
||||||
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" $(NX) run research:train-tpu-vm-prepare
|
|
||||||
|
|
||||||
.PHONY: train.tpu.vm.run
|
|
||||||
train.tpu.vm.run:
|
|
||||||
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train-tpu-vm-run
|
|
||||||
|
|
||||||
.PHONY: train.tpu.vm
|
|
||||||
train.tpu.vm:
|
|
||||||
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" LOCAL_TRAIN_ARGS="$(LOCAL_TRAIN_ARGS)" $(NX) run research:train-tpu-vm
|
|
||||||
|
|
||||||
.PHONY: train.tpu.vm.sweep
|
|
||||||
train.tpu.vm.sweep:
|
|
||||||
@TPU_NAME="$(TPU_NAME)" TPU_ZONE="$(TPU_ZONE)" TPU_PROJECT="$(TPU_PROJECT)" TPU_REPO_DIR="$(TPU_REPO_DIR)" SWEEP_ENV_FILE="$(SWEEP_ENV_FILE)" SWEEP_ID="$(SWEEP_ID)" AGENT_COUNT="$(AGENT_COUNT)" $(NX) run research:train-tpu-vm-sweep
|
|
||||||
|
|
||||||
.PHONY: backend.server backend.provider backend.worker platform.up platform.down platform.logs
|
.PHONY: backend.server backend.provider backend.worker platform.up platform.down platform.logs
|
||||||
backend.server:
|
backend.server:
|
||||||
@$(NX) run backend-server:dev
|
@$(NX) run backend-server:dev
|
||||||
|
|||||||
@@ -7,36 +7,9 @@ WORKDIR /app
|
|||||||
COPY docker/trainer.requirements.txt /tmp/requirements.txt
|
COPY docker/trainer.requirements.txt /tmp/requirements.txt
|
||||||
RUN pip install --no-cache-dir -r /tmp/requirements.txt
|
RUN pip install --no-cache-dir -r /tmp/requirements.txt
|
||||||
|
|
||||||
# Optional for JAX-on-GPU workflows.
|
|
||||||
ARG INSTALL_JAX_GPU=false
|
|
||||||
RUN if [ "${INSTALL_JAX_GPU}" = "true" ]; then \
|
|
||||||
pip install --no-cache-dir "jax[cuda12]==0.4.30" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html; \
|
|
||||||
fi
|
|
||||||
|
|
||||||
COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint
|
COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint
|
||||||
COPY engine /app/engine
|
COPY engine /app/engine
|
||||||
|
|
||||||
ENV PYTHONPATH=/app \
|
ENV PYTHONPATH=/app
|
||||||
XLA_PYTHON_CLIENT_PREALLOCATE=false
|
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
|
|
||||||
|
|
||||||
|
|
||||||
FROM python:3.11-slim AS tpu
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY docker/trainer.requirements.txt /tmp/requirements.txt
|
|
||||||
RUN pip install --no-cache-dir -r /tmp/requirements.txt
|
|
||||||
|
|
||||||
RUN pip install --no-cache-dir "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
|
||||||
|
|
||||||
COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint
|
|
||||||
COPY engine /app/engine
|
|
||||||
|
|
||||||
ENV PYTHONPATH=/app \
|
|
||||||
PHANTOM_USE_JAX=1 \
|
|
||||||
PHANTOM_DEFAULT_AGENT_ARGS="--jax" \
|
|
||||||
XLA_PYTHON_CLIENT_PREALLOCATE=false
|
|
||||||
|
|
||||||
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
|
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
|
||||||
|
|||||||
@@ -5,9 +5,3 @@ gymnasium>=0.29.0
|
|||||||
stable-baselines3>=2.2.0
|
stable-baselines3>=2.2.0
|
||||||
tensorboard>=2.15.0
|
tensorboard>=2.15.0
|
||||||
wandb>=0.17.0
|
wandb>=0.17.0
|
||||||
tensorflow-probability==0.24.0
|
|
||||||
flax==0.10.7
|
|
||||||
optax==0.2.7
|
|
||||||
distrax==0.1.5
|
|
||||||
orbax-checkpoint==0.11.32
|
|
||||||
chex==0.1.90
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__all__ = ["evaluate", "make_env", "train_jax_backend", "train_qtable", "train_sb3"]
|
__all__ = ["evaluate", "make_env", "train_qtable", "train_sb3"]
|
||||||
|
|||||||
@@ -1,18 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Mapping
|
|
||||||
|
|
||||||
from ..jax import JAX_AVAILABLE
|
|
||||||
|
|
||||||
|
|
||||||
def train_jax_backend(
|
|
||||||
cfg: Mapping[str, Any],
|
|
||||||
) -> tuple[dict[str, Any], dict[str, float | int | str]]:
|
|
||||||
if not JAX_AVAILABLE:
|
|
||||||
raise ImportError(
|
|
||||||
"JAX backend requested but JAX is not installed. "
|
|
||||||
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
|
||||||
)
|
|
||||||
from ..jax.train import train_jax
|
|
||||||
|
|
||||||
return train_jax(dict(cfg))
|
|
||||||
@@ -7,7 +7,9 @@ import numpy as np
|
|||||||
from .common import evaluate, make_env
|
from .common import evaluate, make_env
|
||||||
|
|
||||||
|
|
||||||
def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]]:
|
def train_qtable(
|
||||||
|
cfg: Mapping[str, Any],
|
||||||
|
) -> tuple[object, dict[str, Any]]:
|
||||||
from ..lib.discrete import EventQTable
|
from ..lib.discrete import EventQTable
|
||||||
|
|
||||||
np.random.seed(int(cfg["seed"]))
|
np.random.seed(int(cfg["seed"]))
|
||||||
@@ -26,8 +28,19 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]
|
|||||||
total_revenue = 0.0
|
total_revenue = 0.0
|
||||||
steps = 0
|
steps = 0
|
||||||
epsilon = float(cfg["eps_start"])
|
epsilon = float(cfg["eps_start"])
|
||||||
|
log_freq = max(1, int(cfg.get("log_freq", 100)))
|
||||||
obs, _ = env.reset(seed=int(cfg["seed"]))
|
obs, _ = env.reset(seed=int(cfg["seed"]))
|
||||||
|
|
||||||
|
interval_sums = {
|
||||||
|
"reward": 0.0,
|
||||||
|
"revenue": 0.0,
|
||||||
|
"agent_prob": 0.0,
|
||||||
|
"alpha_adv": 0.0,
|
||||||
|
"coi_leakage": 0.0,
|
||||||
|
}
|
||||||
|
interval_count = 0
|
||||||
|
train_events: list[dict[str, float | int]] = []
|
||||||
|
|
||||||
for _ in range(int(cfg["total_timesteps"])):
|
for _ in range(int(cfg["total_timesteps"])):
|
||||||
action, state = agent.act(obs, epsilon)
|
action, state = agent.act(obs, epsilon)
|
||||||
nxt, reward, term, trunc, info = env.step(action)
|
nxt, reward, term, trunc, info = env.step(action)
|
||||||
@@ -35,18 +48,57 @@ def train_qtable(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int]
|
|||||||
agent.update(state, action, float(reward), agent.encode(nxt), done)
|
agent.update(state, action, float(reward), agent.encode(nxt), done)
|
||||||
|
|
||||||
total_reward += float(reward)
|
total_reward += float(reward)
|
||||||
total_revenue += float(info.get("economics", {}).get("revenue", 0.0))
|
revenue = float(info.get("economics", {}).get("revenue", 0.0))
|
||||||
|
total_revenue += revenue
|
||||||
steps += 1
|
steps += 1
|
||||||
|
interval_sums["reward"] += float(reward)
|
||||||
|
interval_sums["revenue"] += revenue
|
||||||
|
interval_sums["agent_prob"] += float(info.get("agent_prob", 0.0))
|
||||||
|
interval_sums["alpha_adv"] += float(info.get("alpha_adv", 0.0))
|
||||||
|
interval_sums["coi_leakage"] += float(info.get("coi_leakage", 0.0))
|
||||||
|
interval_count += 1
|
||||||
|
|
||||||
|
if steps % log_freq == 0 and interval_count > 0:
|
||||||
|
denom = float(interval_count)
|
||||||
|
train_events.append(
|
||||||
|
{
|
||||||
|
"train/reward_mean": interval_sums["reward"] / denom,
|
||||||
|
"train/revenue_mean": interval_sums["revenue"] / denom,
|
||||||
|
"train/agent_prob": interval_sums["agent_prob"] / denom,
|
||||||
|
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
|
||||||
|
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
|
||||||
|
"train/epsilon": float(epsilon),
|
||||||
|
"train/global_step": int(steps),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
interval_sums = {key: 0.0 for key in interval_sums}
|
||||||
|
interval_count = 0
|
||||||
|
|
||||||
epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"]))
|
epsilon = max(float(cfg["eps_end"]), epsilon * float(cfg["eps_decay"]))
|
||||||
obs = env.reset()[0] if done else nxt
|
obs = env.reset()[0] if done else nxt
|
||||||
|
|
||||||
metrics: dict[str, float | int] = {
|
if interval_count > 0:
|
||||||
|
denom = float(interval_count)
|
||||||
|
train_events.append(
|
||||||
|
{
|
||||||
|
"train/reward_mean": interval_sums["reward"] / denom,
|
||||||
|
"train/revenue_mean": interval_sums["revenue"] / denom,
|
||||||
|
"train/agent_prob": interval_sums["agent_prob"] / denom,
|
||||||
|
"train/alpha_adv": interval_sums["alpha_adv"] / denom,
|
||||||
|
"train/coi_leakage": interval_sums["coi_leakage"] / denom,
|
||||||
|
"train/epsilon": float(epsilon),
|
||||||
|
"train/global_step": int(steps),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics: dict[str, Any] = {
|
||||||
"train/reward_mean": total_reward / max(steps, 1),
|
"train/reward_mean": total_reward / max(steps, 1),
|
||||||
"train/revenue_mean": total_revenue / max(steps, 1),
|
"train/revenue_mean": total_revenue / max(steps, 1),
|
||||||
"train/epsilon": float(epsilon),
|
"train/epsilon": float(epsilon),
|
||||||
"train/global_step": int(cfg["total_timesteps"]),
|
"train/global_step": int(cfg["total_timesteps"]),
|
||||||
}
|
}
|
||||||
metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"])))
|
metrics.update(evaluate(agent, eval_env, int(cfg["eval_episodes"])))
|
||||||
|
metrics["_train_events"] = train_events
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
|
|||||||
@@ -4,9 +4,7 @@ import json
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Mapping
|
from typing import Any, Mapping
|
||||||
|
|
||||||
from ..lib.callbacks import CheckpointArtifactCallback, MetricsCallback
|
from ..lib.callbacks import MetricsCallback
|
||||||
from ..telemetry.wandb import get_wandb_module
|
|
||||||
from ..wandb_checkpoint import checkpoint_artifact_name, download_latest_checkpoint
|
|
||||||
from .common import evaluate, make_env
|
from .common import evaluate, make_env
|
||||||
|
|
||||||
|
|
||||||
@@ -52,21 +50,6 @@ def _policy_kwargs(cfg: Mapping[str, Any]) -> dict[str, Any]:
|
|||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
|
|
||||||
def _sb3_model_cls(algo: str):
|
|
||||||
try:
|
|
||||||
from stable_baselines3 import A2C, DQN, PPO
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError("stable-baselines3 is required for SB3 algorithms") from exc
|
|
||||||
|
|
||||||
if algo == "ppo":
|
|
||||||
return PPO
|
|
||||||
if algo == "a2c":
|
|
||||||
return A2C
|
|
||||||
if algo == "dqn":
|
|
||||||
return DQN
|
|
||||||
raise ValueError(f"unsupported algo '{algo}'")
|
|
||||||
|
|
||||||
|
|
||||||
def build_model(cfg: Mapping[str, Any], env: Any):
|
def build_model(cfg: Mapping[str, Any], env: Any):
|
||||||
try:
|
try:
|
||||||
from stable_baselines3 import A2C, DQN, PPO
|
from stable_baselines3 import A2C, DQN, PPO
|
||||||
@@ -132,29 +115,7 @@ def build_model(cfg: Mapping[str, Any], env: Any):
|
|||||||
raise ValueError(f"unsupported algo '{algo}'")
|
raise ValueError(f"unsupported algo '{algo}'")
|
||||||
|
|
||||||
|
|
||||||
def _maybe_resume_model(cfg: Mapping[str, Any], env: Any, model: Any):
|
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, Any]]:
|
||||||
wandb = get_wandb_module()
|
|
||||||
if wandb is None or wandb.run is None:
|
|
||||||
return model
|
|
||||||
|
|
||||||
sweep_id = getattr(wandb.run, "sweep_id", None)
|
|
||||||
artifact_name = checkpoint_artifact_name(cfg, backend="sb3", sweep_id=sweep_id)
|
|
||||||
checkpoint_file = f"phantom_{cfg['algo']}_checkpoint.zip"
|
|
||||||
restored = download_latest_checkpoint(artifact_name, file_name=checkpoint_file)
|
|
||||||
if restored is None:
|
|
||||||
return model
|
|
||||||
|
|
||||||
checkpoint_path, metadata = restored
|
|
||||||
resumed = _sb3_model_cls(str(cfg["algo"]).lower()).load(
|
|
||||||
checkpoint_path.as_posix(),
|
|
||||||
env=env,
|
|
||||||
)
|
|
||||||
resume_step = int(metadata.get("step", getattr(resumed, "num_timesteps", 0)))
|
|
||||||
resumed.num_timesteps = max(int(getattr(resumed, "num_timesteps", 0)), resume_step)
|
|
||||||
return resumed
|
|
||||||
|
|
||||||
|
|
||||||
def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | str]]:
|
|
||||||
try:
|
try:
|
||||||
from stable_baselines3.common.callbacks import EvalCallback
|
from stable_baselines3.common.callbacks import EvalCallback
|
||||||
from stable_baselines3.common.monitor import Monitor
|
from stable_baselines3.common.monitor import Monitor
|
||||||
@@ -182,15 +143,10 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
model = _maybe_resume_model(cfg, env, model)
|
metrics_callback = MetricsCallback(
|
||||||
|
log_histograms=False, log_freq=int(cfg["log_freq"])
|
||||||
callbacks = [MetricsCallback(log_histograms=False, log_freq=int(cfg["log_freq"]))]
|
|
||||||
callbacks.append(
|
|
||||||
CheckpointArtifactCallback(
|
|
||||||
dict(cfg),
|
|
||||||
interval=int(cfg.get("checkpoint_interval", 10_000)),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
callbacks = [metrics_callback]
|
||||||
callbacks.append(
|
callbacks.append(
|
||||||
EvalCallback(
|
EvalCallback(
|
||||||
eval_env,
|
eval_env,
|
||||||
@@ -215,13 +171,14 @@ def train_sb3(cfg: Mapping[str, Any]) -> tuple[object, dict[str, float | int | s
|
|||||||
model_path = model_dir / f"phantom_{cfg['algo']}"
|
model_path = model_dir / f"phantom_{cfg['algo']}"
|
||||||
model.save(str(model_path))
|
model.save(str(model_path))
|
||||||
|
|
||||||
metrics: dict[str, float | int | str] = evaluate(
|
metrics: dict[str, Any] = evaluate(
|
||||||
model,
|
model,
|
||||||
eval_env,
|
eval_env,
|
||||||
int(cfg["eval_episodes"]),
|
int(cfg["eval_episodes"]),
|
||||||
)
|
)
|
||||||
metrics["train/global_step"] = int(model.num_timesteps)
|
metrics["train/global_step"] = int(model.num_timesteps)
|
||||||
metrics["model/path"] = str(model_path.with_suffix(".zip"))
|
metrics["model/path"] = str(model_path.with_suffix(".zip"))
|
||||||
|
metrics["_train_events"] = list(metrics_callback.events)
|
||||||
|
|
||||||
env.close()
|
env.close()
|
||||||
eval_env.close()
|
eval_env.close()
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
"""JAX-compatible training and environment modules for PHANTOM."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
try:
|
|
||||||
import jax # noqa: F401
|
|
||||||
import jax.numpy as jnp # noqa: F401
|
|
||||||
|
|
||||||
JAX_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
JAX_AVAILABLE = False
|
|
||||||
|
|
||||||
__all__ = ["JAX_AVAILABLE"]
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
"""Orbax checkpoint helpers for JAX training runs."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
try:
|
|
||||||
import orbax.checkpoint as ocp
|
|
||||||
|
|
||||||
HAS_ORBAX = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_ORBAX = False
|
|
||||||
|
|
||||||
|
|
||||||
def _require_orbax() -> None:
|
|
||||||
if not HAS_ORBAX:
|
|
||||||
raise ImportError(
|
|
||||||
"orbax-checkpoint is required for checkpoint support. "
|
|
||||||
"Install engine/jax/requirements.txt first."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def create_manager(directory: str | Path, max_to_keep: int = 5):
|
|
||||||
_require_orbax()
|
|
||||||
root = Path(directory)
|
|
||||||
root.mkdir(parents=True, exist_ok=True)
|
|
||||||
options = ocp.CheckpointManagerOptions(
|
|
||||||
max_to_keep=max(1, int(max_to_keep)), create=True
|
|
||||||
)
|
|
||||||
return ocp.CheckpointManager(root.as_posix(), ocp.PyTreeCheckpointer(), options)
|
|
||||||
|
|
||||||
|
|
||||||
def save(manager, *, step: int, payload: Any) -> bool:
|
|
||||||
_require_orbax()
|
|
||||||
return bool(manager.save(int(step), payload))
|
|
||||||
|
|
||||||
|
|
||||||
def latest_step(manager) -> int | None:
|
|
||||||
_require_orbax()
|
|
||||||
return manager.latest_step()
|
|
||||||
|
|
||||||
|
|
||||||
def restore(manager, *, target: Any, step: int | None = None) -> Any:
|
|
||||||
_require_orbax()
|
|
||||||
step_to_restore = manager.latest_step() if step is None else int(step)
|
|
||||||
if step_to_restore is None:
|
|
||||||
return target
|
|
||||||
return manager.restore(step_to_restore, items=target)
|
|
||||||
@@ -1,304 +0,0 @@
|
|||||||
"""JAX-native PHANTOM environment with robust contamination step."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import NamedTuple
|
|
||||||
|
|
||||||
try:
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
except ImportError as exc: # pragma: no cover
|
|
||||||
raise ImportError("engine.jax.env requires JAX") from exc
|
|
||||||
|
|
||||||
from .primitives import (
|
|
||||||
_sample_sessions_jax,
|
|
||||||
agent_probability_from_kl,
|
|
||||||
batch_kl,
|
|
||||||
compute_session_transitions,
|
|
||||||
load_transition_data,
|
|
||||||
purchase_flags,
|
|
||||||
reward_with_coi_penalty,
|
|
||||||
revenue_from_demand,
|
|
||||||
weighted_demand,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class EnvParams(NamedTuple):
|
|
||||||
n_products: int
|
|
||||||
n_sessions: int
|
|
||||||
max_episode_steps: int
|
|
||||||
max_session_steps: int
|
|
||||||
price_low: float
|
|
||||||
price_high: float
|
|
||||||
lambda_coi: float
|
|
||||||
info_value: float
|
|
||||||
eta_ux: float
|
|
||||||
robust_radius: float
|
|
||||||
margin_floor: float
|
|
||||||
margin_floor_patience: int
|
|
||||||
action_scales: jax.Array
|
|
||||||
alpha_nominal: float
|
|
||||||
alpha_candidates: jax.Array
|
|
||||||
human_T: jax.Array
|
|
||||||
agent_T: jax.Array
|
|
||||||
terminal_mask: jax.Array
|
|
||||||
purchase_mask: jax.Array
|
|
||||||
event_weights: jax.Array
|
|
||||||
start_idx: int
|
|
||||||
term_idx: int
|
|
||||||
|
|
||||||
|
|
||||||
class EnvState(NamedTuple):
|
|
||||||
prices: jax.Array
|
|
||||||
demand: jax.Array
|
|
||||||
step_count: jax.Array
|
|
||||||
low_margin_streak: jax.Array
|
|
||||||
last_agent_prob: jax.Array
|
|
||||||
last_alpha_adv: jax.Array
|
|
||||||
|
|
||||||
|
|
||||||
class CandidateEval(NamedTuple):
|
|
||||||
reward: jax.Array
|
|
||||||
revenue: jax.Array
|
|
||||||
demand: jax.Array
|
|
||||||
agent_prob: jax.Array
|
|
||||||
leakage: jax.Array
|
|
||||||
discount: jax.Array
|
|
||||||
ux_penalty: jax.Array
|
|
||||||
n_purchases: jax.Array
|
|
||||||
n_agents: jax.Array
|
|
||||||
|
|
||||||
|
|
||||||
def make_env_params(
|
|
||||||
*,
|
|
||||||
n_products: int,
|
|
||||||
alpha: float,
|
|
||||||
n_sessions: int,
|
|
||||||
lambda_coi: float,
|
|
||||||
robust_radius: float,
|
|
||||||
robust_points: int,
|
|
||||||
info_value: float,
|
|
||||||
eta_ux: float = 0.5,
|
|
||||||
action_levels: int,
|
|
||||||
action_scale_low: float,
|
|
||||||
action_scale_high: float,
|
|
||||||
price_low: float,
|
|
||||||
price_high: float,
|
|
||||||
max_episode_steps: int,
|
|
||||||
max_session_steps: int = 40,
|
|
||||||
margin_floor: float = 0.05,
|
|
||||||
margin_floor_patience: int = 5,
|
|
||||||
prefer_behavior_data: bool = True,
|
|
||||||
) -> EnvParams:
|
|
||||||
transition = load_transition_data(prefer_data=prefer_behavior_data).to_jax()
|
|
||||||
if robust_radius <= 0.0 or robust_points <= 1:
|
|
||||||
alpha_candidates = jnp.asarray([float(alpha)], dtype=jnp.float32)
|
|
||||||
else:
|
|
||||||
lo = max(0.0, float(alpha) - float(robust_radius))
|
|
||||||
hi = min(1.0, float(alpha) + float(robust_radius))
|
|
||||||
alpha_candidates = jnp.linspace(lo, hi, int(robust_points), dtype=jnp.float32)
|
|
||||||
|
|
||||||
action_scales = jnp.linspace(
|
|
||||||
float(action_scale_low),
|
|
||||||
float(action_scale_high),
|
|
||||||
int(action_levels),
|
|
||||||
dtype=jnp.float32,
|
|
||||||
)
|
|
||||||
return EnvParams(
|
|
||||||
n_products=int(n_products),
|
|
||||||
n_sessions=int(n_sessions),
|
|
||||||
max_episode_steps=int(max_episode_steps),
|
|
||||||
max_session_steps=int(max_session_steps),
|
|
||||||
price_low=float(price_low),
|
|
||||||
price_high=float(price_high),
|
|
||||||
lambda_coi=float(lambda_coi),
|
|
||||||
info_value=float(info_value),
|
|
||||||
eta_ux=float(eta_ux),
|
|
||||||
robust_radius=float(robust_radius),
|
|
||||||
margin_floor=float(margin_floor),
|
|
||||||
margin_floor_patience=int(margin_floor_patience),
|
|
||||||
action_scales=action_scales,
|
|
||||||
alpha_nominal=float(alpha),
|
|
||||||
alpha_candidates=alpha_candidates,
|
|
||||||
human_T=jnp.asarray(transition.human_T),
|
|
||||||
agent_T=jnp.asarray(transition.agent_T),
|
|
||||||
terminal_mask=jnp.asarray(transition.terminal_mask),
|
|
||||||
purchase_mask=jnp.asarray(transition.purchase_mask),
|
|
||||||
event_weights=jnp.asarray(transition.event_weights),
|
|
||||||
start_idx=int(transition.start_idx),
|
|
||||||
term_idx=int(transition.term_idx),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _flatten_obs(demand: jax.Array, prices: jax.Array) -> jax.Array:
|
|
||||||
return jnp.concatenate([demand.astype(jnp.float32), prices.astype(jnp.float32)])
|
|
||||||
|
|
||||||
|
|
||||||
def _decode_action(
|
|
||||||
prices: jax.Array, action: jax.Array, params: EnvParams
|
|
||||||
) -> jax.Array:
|
|
||||||
idx = jnp.clip(action.astype(jnp.int32), 0, params.action_scales.shape[0] - 1)
|
|
||||||
scale = params.action_scales[idx]
|
|
||||||
next_prices = prices * scale
|
|
||||||
return jnp.clip(next_prices, params.price_low, params.price_high)
|
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_candidate(
|
|
||||||
key: jax.Array,
|
|
||||||
alpha_candidate: jax.Array,
|
|
||||||
prices: jax.Array,
|
|
||||||
ux_volatility: jax.Array,
|
|
||||||
params: EnvParams,
|
|
||||||
) -> CandidateEval:
|
|
||||||
states, products, actors, lengths = _sample_sessions_jax(
|
|
||||||
key,
|
|
||||||
params.human_T,
|
|
||||||
params.agent_T,
|
|
||||||
params.terminal_mask,
|
|
||||||
params.start_idx,
|
|
||||||
params.term_idx,
|
|
||||||
alpha_candidate,
|
|
||||||
params.n_products,
|
|
||||||
params.n_sessions,
|
|
||||||
params.max_session_steps,
|
|
||||||
int(params.human_T.shape[0]),
|
|
||||||
)
|
|
||||||
session_trans = compute_session_transitions(
|
|
||||||
states, lengths, int(params.human_T.shape[0])
|
|
||||||
)
|
|
||||||
delta_h, delta_a = batch_kl(session_trans, params.human_T, params.agent_T)
|
|
||||||
agent_probs = agent_probability_from_kl(delta_h, delta_a)
|
|
||||||
agent_prob = jnp.mean(agent_probs)
|
|
||||||
|
|
||||||
demand = weighted_demand(states, products, params.n_products, params.event_weights)
|
|
||||||
revenue = revenue_from_demand(prices, demand)
|
|
||||||
reward, leakage, discount, ux_penalty = reward_with_coi_penalty(
|
|
||||||
revenue,
|
|
||||||
agent_prob,
|
|
||||||
params.lambda_coi,
|
|
||||||
params.info_value,
|
|
||||||
params.eta_ux,
|
|
||||||
ux_volatility,
|
|
||||||
)
|
|
||||||
purchases = purchase_flags(states, params.purchase_mask)
|
|
||||||
return CandidateEval(
|
|
||||||
reward=reward,
|
|
||||||
revenue=revenue,
|
|
||||||
demand=demand,
|
|
||||||
agent_prob=agent_prob,
|
|
||||||
leakage=leakage,
|
|
||||||
discount=discount,
|
|
||||||
ux_penalty=ux_penalty,
|
|
||||||
n_purchases=jnp.sum(purchases.astype(jnp.float32)),
|
|
||||||
n_agents=jnp.sum(actors.astype(jnp.float32)),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def reset_env(key: jax.Array, params: EnvParams) -> tuple[jax.Array, EnvState]:
|
|
||||||
prices = jax.random.uniform(
|
|
||||||
key,
|
|
||||||
shape=(params.n_products,),
|
|
||||||
minval=params.price_low,
|
|
||||||
maxval=params.price_high,
|
|
||||||
)
|
|
||||||
demand = jnp.zeros((params.n_products,), dtype=jnp.float32)
|
|
||||||
state = EnvState(
|
|
||||||
prices=prices,
|
|
||||||
demand=demand,
|
|
||||||
step_count=jnp.asarray(0, dtype=jnp.int32),
|
|
||||||
low_margin_streak=jnp.asarray(0, dtype=jnp.int32),
|
|
||||||
last_agent_prob=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
|
|
||||||
last_alpha_adv=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
|
|
||||||
)
|
|
||||||
return _flatten_obs(demand, prices), state
|
|
||||||
|
|
||||||
|
|
||||||
def step_env(
|
|
||||||
key: jax.Array,
|
|
||||||
state: EnvState,
|
|
||||||
action: jax.Array,
|
|
||||||
params: EnvParams,
|
|
||||||
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
|
|
||||||
prices = _decode_action(state.prices, action, params)
|
|
||||||
|
|
||||||
baseline = jnp.maximum(state.prices, 1.0)
|
|
||||||
ux_volatility = jnp.where(
|
|
||||||
state.step_count > 0, jnp.mean(jnp.abs(prices - state.prices) / baseline), 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
n_candidates = params.alpha_candidates.shape[0]
|
|
||||||
cand_keys = jax.random.split(key, n_candidates)
|
|
||||||
evals = jax.vmap(
|
|
||||||
lambda k, a: _evaluate_candidate(k, a, prices, ux_volatility, params),
|
|
||||||
in_axes=(0, 0),
|
|
||||||
)(cand_keys, params.alpha_candidates)
|
|
||||||
idx = jnp.argmin(evals.reward)
|
|
||||||
|
|
||||||
demand = evals.demand[idx]
|
|
||||||
reward = evals.reward[idx]
|
|
||||||
revenue = evals.revenue[idx]
|
|
||||||
agent_prob = evals.agent_prob[idx]
|
|
||||||
leakage = evals.leakage[idx]
|
|
||||||
discount = evals.discount[idx]
|
|
||||||
ux_penalty = evals.ux_penalty[idx]
|
|
||||||
n_purchases = evals.n_purchases[idx]
|
|
||||||
n_agents = evals.n_agents[idx]
|
|
||||||
alpha_adv = params.alpha_candidates[idx]
|
|
||||||
|
|
||||||
step_count = state.step_count + 1
|
|
||||||
avg_price = jnp.maximum(jnp.mean(prices), 1e-6)
|
|
||||||
avg_margin = (avg_price - params.price_low) / avg_price
|
|
||||||
next_streak = jnp.where(
|
|
||||||
avg_margin < params.margin_floor, state.low_margin_streak + 1, 0
|
|
||||||
)
|
|
||||||
|
|
||||||
margin_collapsed = next_streak >= params.margin_floor_patience
|
|
||||||
done = (step_count >= params.max_episode_steps) | margin_collapsed
|
|
||||||
|
|
||||||
next_state = EnvState(
|
|
||||||
prices=prices,
|
|
||||||
demand=demand,
|
|
||||||
step_count=step_count,
|
|
||||||
low_margin_streak=next_streak,
|
|
||||||
last_agent_prob=agent_prob,
|
|
||||||
last_alpha_adv=alpha_adv,
|
|
||||||
)
|
|
||||||
obs = _flatten_obs(demand, prices)
|
|
||||||
info = {
|
|
||||||
"revenue": revenue,
|
|
||||||
"agent_prob": agent_prob,
|
|
||||||
"alpha_adv": alpha_adv,
|
|
||||||
"coi_leakage": leakage,
|
|
||||||
"coi_discount": discount,
|
|
||||||
"ux_penalty": ux_penalty,
|
|
||||||
"volatility": ux_volatility,
|
|
||||||
"n_purchases": n_purchases,
|
|
||||||
"n_agents": n_agents,
|
|
||||||
"avg_margin": avg_margin,
|
|
||||||
}
|
|
||||||
return obs, next_state, reward, done, info
|
|
||||||
|
|
||||||
|
|
||||||
class PHANTOMJAXEnv:
|
|
||||||
def __init__(self, params: EnvParams):
|
|
||||||
self.params = params
|
|
||||||
|
|
||||||
def reset(self, key: jax.Array, params: EnvParams | None = None):
|
|
||||||
return reset_env(key, self.params if params is None else params)
|
|
||||||
|
|
||||||
def step(
|
|
||||||
self,
|
|
||||||
key: jax.Array,
|
|
||||||
state: EnvState,
|
|
||||||
action: jax.Array,
|
|
||||||
params: EnvParams | None = None,
|
|
||||||
):
|
|
||||||
return step_env(key, state, action, self.params if params is None else params)
|
|
||||||
|
|
||||||
def action_space_n(self, params: EnvParams | None = None) -> int:
|
|
||||||
p = self.params if params is None else params
|
|
||||||
return int(p.action_scales.shape[0])
|
|
||||||
|
|
||||||
def observation_dim(self, params: EnvParams | None = None) -> int:
|
|
||||||
p = self.params if params is None else params
|
|
||||||
return int(p.n_products * 2)
|
|
||||||
@@ -1,501 +0,0 @@
|
|||||||
"""JAX-compatible primitives for PHANTOM session simulation and separability."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from functools import partial
|
|
||||||
from typing import Mapping
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
try:
|
|
||||||
import jax
|
|
||||||
import jax.numpy as jnp
|
|
||||||
|
|
||||||
JAX_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
jax = None # type: ignore[assignment]
|
|
||||||
jnp = np # type: ignore[assignment]
|
|
||||||
JAX_AVAILABLE = False
|
|
||||||
|
|
||||||
|
|
||||||
STATE_START_KEYS = ("session_start", "start")
|
|
||||||
TERMINAL_EVENT_TOKENS = (
|
|
||||||
"session_end",
|
|
||||||
"end",
|
|
||||||
"purchase_complete",
|
|
||||||
"checkout_start",
|
|
||||||
"checkout",
|
|
||||||
)
|
|
||||||
PURCHASE_EVENT_TOKENS = (
|
|
||||||
"purchase_complete",
|
|
||||||
"purchase",
|
|
||||||
"checkout_start",
|
|
||||||
"checkout",
|
|
||||||
)
|
|
||||||
|
|
||||||
CATEGORY_WEIGHTS = {"cart": 4.0, "dwell": 2.0, "nav": 1.0, "filter": 0.5}
|
|
||||||
ACTION_CATEGORIES = {
|
|
||||||
"cart": {"add_item", "add_to_cart", "remove", "checkout", "purchase"},
|
|
||||||
"dwell": {
|
|
||||||
"hover_title",
|
|
||||||
"hover_paragraph",
|
|
||||||
"hover_link",
|
|
||||||
"hover_over_title",
|
|
||||||
"hover_over_paragraph",
|
|
||||||
"hover_over_link",
|
|
||||||
"hover_over_button",
|
|
||||||
},
|
|
||||||
"nav": {
|
|
||||||
"page_view",
|
|
||||||
"view_item",
|
|
||||||
"view",
|
|
||||||
"learn_more",
|
|
||||||
"learn_more_about_item",
|
|
||||||
"view_item_page",
|
|
||||||
"session_start",
|
|
||||||
},
|
|
||||||
"filter": {
|
|
||||||
"search",
|
|
||||||
"filter_date",
|
|
||||||
"filter_price",
|
|
||||||
"sort",
|
|
||||||
"filter_for_date",
|
|
||||||
"filter_for_price",
|
|
||||||
"filter_for_amenities",
|
|
||||||
"sort_change",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
DEFAULT_ACTION_WEIGHTS = {
|
|
||||||
action: CATEGORY_WEIGHTS[group]
|
|
||||||
for group, actions in ACTION_CATEGORIES.items()
|
|
||||||
for action in actions
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
class TransitionData:
    """Dense transition kernels and per-state metadata."""

    human_T: np.ndarray  # square transition matrix used for human sessions
    agent_T: np.ndarray  # square transition matrix used for agent sessions
    terminal_mask: np.ndarray  # True for states that end a session
    purchase_mask: np.ndarray  # True for states counted as a purchase signal
    event_weights: np.ndarray  # per-state demand weight
    event_names: tuple[str, ...]  # state index -> event name
    start_idx: int  # index of the initial state
    term_idx: int  # index of the synthetic "__terminal__" state

    def to_jax(self) -> "TransitionData":
        """Return a copy whose array fields are JAX arrays.

        When JAX is not importable the instance itself is returned unchanged.
        """
        if JAX_AVAILABLE:
            converted = {
                field: jnp.asarray(getattr(self, field))
                for field in (
                    "human_T",
                    "agent_T",
                    "terminal_mask",
                    "purchase_mask",
                    "event_weights",
                )
            }
            return TransitionData(
                event_names=self.event_names,
                start_idx=int(self.start_idx),
                term_idx=int(self.term_idx),
                **converted,
            )
        return self
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
class SessionBatch:
    """Result of sample_sessions: one row / entry per simulated session."""

    # (n_sessions, max_steps) emitted state indices; -1 marks steps after the
    # session ended.
    states: np.ndarray
    # (n_sessions,) product index drawn for each session.
    products: np.ndarray
    # (n_sessions,) 1 when the session follows the agent kernel, else 0.
    actors: np.ndarray
    # (n_sessions,) number of emitted steps per session.
    lengths: np.ndarray
|
|
||||||
|
|
||||||
|
|
||||||
def _event_weight(name: str) -> float:
    """Return the demand weight for an event name.

    An exact match in DEFAULT_ACTION_WEIGHTS wins. Otherwise the name is
    classified by prefix / membership into dwell, filter or cart; names that
    contain a terminal token weigh 0.0; everything else counts as navigation.
    """
    explicit = DEFAULT_ACTION_WEIGHTS.get(name)
    if explicit is not None:
        return float(explicit)

    if name.startswith("hover"):
        category = "dwell"
    elif name.startswith("filter") or name in ("search", "sort", "sort_change"):
        category = "filter"
    elif name.startswith("add") or name in (
        "checkout",
        "checkout_start",
        "purchase",
        "remove_item",
        "purchase_complete",
    ):
        category = "cart"
    elif any(token in name for token in TERMINAL_EVENT_TOKENS):
        # Session-ending events carry no demand signal.
        return 0.0
    else:
        category = "nav"
    return float(CATEGORY_WEIGHTS[category])
|
|
||||||
|
|
||||||
|
|
||||||
def _is_terminal(name: str) -> bool:
    """Return True when `name` contains any of TERMINAL_EVENT_TOKENS as a substring."""
    for token in TERMINAL_EVENT_TOKENS:
        if token in name:
            return True
    return False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_purchase(name: str) -> bool:
    """Return True when `name` contains any of PURCHASE_EVENT_TOKENS as a substring."""
    for token in PURCHASE_EVENT_TOKENS:
        if token in name:
            return True
    return False
|
|
||||||
|
|
||||||
|
|
||||||
def _collect_events(*transitions: Mapping[str, Mapping[str, float]]) -> tuple[str, ...]:
|
|
||||||
names: set[str] = set()
|
|
||||||
for trans in transitions:
|
|
||||||
for src, dsts in trans.items():
|
|
||||||
names.add(src)
|
|
||||||
names.update(dsts.keys())
|
|
||||||
names.discard("__terminal__")
|
|
||||||
return tuple(sorted(names))
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_rows(matrix: np.ndarray, term_idx: int) -> np.ndarray:
|
|
||||||
row_sums = matrix.sum(axis=1, keepdims=True)
|
|
||||||
dead_rows = np.isclose(row_sums.squeeze(-1), 0.0)
|
|
||||||
if np.any(dead_rows):
|
|
||||||
matrix[dead_rows] = 0.0
|
|
||||||
matrix[dead_rows, term_idx] = 1.0
|
|
||||||
row_sums = matrix.sum(axis=1, keepdims=True)
|
|
||||||
return matrix / np.maximum(row_sums, 1e-8)
|
|
||||||
|
|
||||||
|
|
||||||
def _dense_from_dict(
    transitions: Mapping[str, Mapping[str, float]],
    event_to_idx: Mapping[str, int],
    term_idx: int,
) -> np.ndarray:
    """Build a row-normalised float32 transition matrix from a nested mapping.

    Sources or destinations that are missing from `event_to_idx` are ignored.
    Repeated (source, destination) pairs accumulate.
    """
    size = len(event_to_idx)
    dense = np.zeros((size, size), dtype=np.float32)
    for source, targets in transitions.items():
        row = event_to_idx.get(source)
        if row is None:
            continue
        for target, weight in targets.items():
            col = event_to_idx.get(target)
            if col is not None:
                dense[row, col] += float(weight)
    return _normalize_rows(dense, term_idx)
|
|
||||||
|
|
||||||
|
|
||||||
def compile_transition_data(
    human_transitions: Mapping[str, Mapping[str, float]],
    agent_transitions: Mapping[str, Mapping[str, float]],
) -> TransitionData:
    """Compile two nested transition mappings into a TransitionData.

    The state space is the sorted union of all event names in both mappings
    plus a synthetic "__terminal__" state appended last. Terminal states are
    made absorbing in both kernels. When neither mapping names any event the
    built-in fallback model is returned instead.

    Args:
        human_transitions: source -> {destination -> weight} for humans.
        agent_transitions: source -> {destination -> weight} for agents.

    Returns:
        TransitionData backed by NumPy arrays.
    """
    event_names = _collect_events(human_transitions, agent_transitions)
    if not event_names:
        return fallback_transition_data()

    event_names = tuple([*event_names, "__terminal__"])
    term_idx = len(event_names) - 1
    event_to_idx = {name: i for i, name in enumerate(event_names)}

    human_T = _dense_from_dict(human_transitions, event_to_idx, term_idx)
    agent_T = _dense_from_dict(agent_transitions, event_to_idx, term_idx)

    terminal_mask = np.array([_is_terminal(name) for name in event_names], dtype=bool)
    purchase_mask = np.array([_is_purchase(name) for name in event_names], dtype=bool)
    event_weights = np.array(
        [_event_weight(name) for name in event_names], dtype=np.float32
    )

    terminal_mask[term_idx] = True
    # "__terminal__" is a bookkeeping sink, not a user action. _event_weight
    # would fall through to the nav weight for it, so force it to 0.0 like
    # the other session-ending events.
    event_weights[term_idx] = 0.0

    # Make every terminal state absorbing in both kernels.
    sinks = np.flatnonzero(terminal_mask)
    for kernel in (human_T, agent_T):
        kernel[sinks] = 0.0
        kernel[sinks, sinks] = 1.0

    # First available start key wins; default to state 0 otherwise.
    start_idx = next(
        (int(event_to_idx[key]) for key in STATE_START_KEYS if key in event_to_idx),
        0,
    )

    return TransitionData(
        human_T=human_T,
        agent_T=agent_T,
        terminal_mask=terminal_mask,
        purchase_mask=purchase_mask,
        event_weights=event_weights,
        event_names=event_names,
        start_idx=start_idx,
        term_idx=term_idx,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def fallback_transition_data() -> TransitionData:
    """Compile the hard-coded human/agent funnel used when no data is available."""
    human_model = {
        "session_start": {"page_view": 0.80, "view_item_page": 0.15, "session_end": 0.05},
        "page_view": {"view_item_page": 0.55, "search": 0.25, "session_end": 0.20},
        "view_item_page": {
            "learn_more_about_item": 0.40,
            "add_item_to_cart": 0.28,
            "session_end": 0.32,
        },
        "learn_more_about_item": {
            "add_item_to_cart": 0.50,
            "view_item_page": 0.30,
            "session_end": 0.20,
        },
        "add_item_to_cart": {
            "checkout_start": 0.58,
            "view_item_page": 0.24,
            "session_end": 0.18,
        },
        "checkout_start": {"purchase_complete": 0.70, "session_end": 0.30},
        "purchase_complete": {"session_end": 1.0},
    }
    agent_model = {
        "session_start": {"page_view": 0.90, "view_item_page": 0.08, "session_end": 0.02},
        "page_view": {"view_item_page": 0.40, "search": 0.35, "session_end": 0.25},
        "view_item_page": {
            "learn_more_about_item": 0.55,
            "add_item_to_cart": 0.15,
            "session_end": 0.30,
        },
        "learn_more_about_item": {
            "view_item_page": 0.45,
            "add_item_to_cart": 0.20,
            "session_end": 0.35,
        },
        "add_item_to_cart": {
            "checkout_start": 0.42,
            "view_item_page": 0.28,
            "session_end": 0.30,
        },
        "checkout_start": {"purchase_complete": 0.52, "session_end": 0.48},
        "purchase_complete": {"session_end": 1.0},
    }
    return compile_transition_data(human_model, agent_model)
|
|
||||||
|
|
||||||
|
|
||||||
def load_transition_data(prefer_data: bool = True) -> TransitionData:
    """Load transition kernels, preferring the project's fitted models.

    Args:
        prefer_data: when False, skip the data-driven models and return the
            built-in fallback directly.

    Returns:
        TransitionData compiled from `get_transition_models()` when that
        succeeds, otherwise the fallback model. This function never raises
        because of a loading failure; the failure is logged instead.
    """
    if not prefer_data:
        return fallback_transition_data()
    try:
        from ..lib.behavior import get_transition_models

        human_trans, agent_trans = get_transition_models()
        return compile_transition_data(human_trans, agent_trans)
    except Exception:
        # Best-effort by design, but leave a trace so a broken data pipeline
        # is not silently replaced by the synthetic model.
        import logging

        logging.getLogger(__name__).warning(
            "Could not load transition models; using fallback transition data",
            exc_info=True,
        )
        return fallback_transition_data()
|
|
||||||
|
|
||||||
|
|
||||||
if JAX_AVAILABLE:
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(8, 9, 10))
|
|
||||||
def _sample_sessions_jax(
|
|
||||||
key: jax.Array,
|
|
||||||
human_T: jax.Array,
|
|
||||||
agent_T: jax.Array,
|
|
||||||
terminal_mask: jax.Array,
|
|
||||||
start_idx: int,
|
|
||||||
term_idx: int,
|
|
||||||
alpha: float,
|
|
||||||
n_products: int,
|
|
||||||
n_sessions: int,
|
|
||||||
max_steps: int,
|
|
||||||
n_states: int,
|
|
||||||
) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
|
|
||||||
k_actor, k_product, k_step = jax.random.split(key, 3)
|
|
||||||
start_idx_i32 = jnp.asarray(start_idx, dtype=jnp.int32)
|
|
||||||
term_idx_i32 = jnp.asarray(term_idx, dtype=jnp.int32)
|
|
||||||
actor_draw = jax.random.uniform(k_actor, (n_sessions,))
|
|
||||||
actors = (actor_draw < alpha).astype(jnp.int32)
|
|
||||||
products = jax.random.randint(
|
|
||||||
k_product, (n_sessions,), 0, n_products, dtype=jnp.int32
|
|
||||||
)
|
|
||||||
|
|
||||||
active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
|
|
||||||
state_init = jnp.full((n_sessions,), start_idx_i32, dtype=jnp.int32)
|
|
||||||
|
|
||||||
def _scan_step(carry, _):
|
|
||||||
states, active, rng = carry
|
|
||||||
rng, k = jax.random.split(rng)
|
|
||||||
probs_h = human_T[states]
|
|
||||||
probs_a = agent_T[states]
|
|
||||||
probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
|
|
||||||
next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
|
|
||||||
next_state = jnp.where(active, next_state, term_idx_i32)
|
|
||||||
emitted = jnp.where(active, next_state, -1)
|
|
||||||
is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
|
|
||||||
next_active = active & (~is_terminal)
|
|
||||||
carry_states = jnp.where(next_active, next_state, term_idx_i32)
|
|
||||||
return (carry_states, next_active, rng), emitted
|
|
||||||
|
|
||||||
_, state_t = jax.lax.scan(
|
|
||||||
_scan_step, (state_init, active_init, k_step), None, length=max_steps
|
|
||||||
)
|
|
||||||
states = state_t.T
|
|
||||||
lengths = jnp.sum(states >= 0, axis=1, dtype=jnp.int32)
|
|
||||||
return states, products, actors, lengths
|
|
||||||
|
|
||||||
|
|
||||||
def sample_sessions(
    key,
    transition_data: TransitionData,
    alpha: float,
    n_products: int,
    n_sessions: int,
    max_steps: int,
) -> SessionBatch:
    """Sample `n_sessions` sessions of at most `max_steps` steps each.

    With JAX available the jitted sampler is used and `key` is passed to it
    as the PRNG key. Otherwise a NumPy generator is seeded from the first
    element of `key` and sessions are simulated one by one.

    Each session is an agent with probability `alpha` (actor 1), otherwise a
    human (actor 0), and is assigned a product index in [0, n_products).
    """
    if JAX_AVAILABLE:
        td = transition_data.to_jax()
        jax_states, jax_products, jax_actors, jax_lengths = _sample_sessions_jax(
            key,
            td.human_T,
            td.agent_T,
            td.terminal_mask,
            int(td.start_idx),
            int(td.term_idx),
            float(alpha),
            int(n_products),
            int(n_sessions),
            int(max_steps),
            int(td.human_T.shape[0]),
        )
        return SessionBatch(
            states=jax_states,
            products=jax_products,
            actors=jax_actors,
            lengths=jax_lengths,
        )

    # NumPy fallback: seed from the first scalar stored in `key`.
    seed = int(np.asarray(key).reshape(-1)[0])
    rng = np.random.default_rng(seed)
    n_states = transition_data.human_T.shape[0]
    products = rng.integers(0, n_products, size=n_sessions, dtype=np.int32)
    actors = (rng.random(size=n_sessions) < alpha).astype(np.int32)
    states = np.full((n_sessions, max_steps), -1, dtype=np.int32)
    lengths = np.zeros((n_sessions,), dtype=np.int32)

    for session in range(n_sessions):
        kernel = (
            transition_data.agent_T if actors[session] == 1 else transition_data.human_T
        )
        position = int(transition_data.start_idx)
        for step in range(max_steps):
            position = int(rng.choice(n_states, p=kernel[position]))
            states[session, step] = position
            if transition_data.terminal_mask[position]:
                # The terminal event itself counts towards the length.
                lengths[session] = step + 1
                break
        else:
            # Never reached a terminal state: the session used every step.
            lengths[session] = max_steps

    return SessionBatch(
        states=states, products=products, actors=actors, lengths=lengths
    )
|
|
||||||
|
|
||||||
|
|
||||||
if JAX_AVAILABLE:
|
|
||||||
|
|
||||||
@partial(jax.jit, static_argnums=(2,))
def compute_session_transitions(states, lengths, n_states: int):
    """Return per-session empirical transition matrices (JAX version).

    Consecutive state pairs are counted only when both entries are
    non-negative and the pair lies inside the session's length. Each row of
    the resulting (n_sessions, n_states, n_states) count tensor is divided by
    its sum plus 1e-10, so unvisited rows stay at zero.
    """
    origin = states[:, :-1]
    target = states[:, 1:]
    step = jnp.arange(origin.shape[1])[None, :]
    counted = (origin >= 0) & (target >= 0) & (step < (lengths[:, None] - 1))

    # Clip so padded (-1) entries index safely; `counted` zeroes them out.
    origin_oh = jax.nn.one_hot(jnp.clip(origin, 0, n_states - 1), n_states)
    target_oh = jax.nn.one_hot(jnp.clip(target, 0, n_states - 1), n_states)

    counts = jnp.einsum(
        "nti,ntj,nt->nij", origin_oh, target_oh, counted.astype(jnp.float32)
    )
    totals = jnp.sum(counts, axis=-1, keepdims=True)
    return counts / (totals + 1e-10)
|
|
||||||
|
|
||||||
|
|
||||||
else:
|
|
||||||
|
|
||||||
def compute_session_transitions(states, lengths, n_states: int):
    """Return per-session empirical transition matrices (NumPy version).

    For each session the first `lengths[i] - 1` consecutive state pairs are
    counted, skipping pairs that contain a negative (padding) entry. Each row
    of the (n_sessions, n_states, n_states) result is divided by its sum plus
    1e-10, so unvisited rows stay at zero.
    """
    n_sessions = states.shape[0]
    counts = np.zeros((n_sessions, n_states, n_states), dtype=np.float32)
    for session in range(n_sessions):
        n_pairs = max(int(lengths[session]) - 1, 0)
        for step in range(n_pairs):
            origin = int(states[session, step])
            target = int(states[session, step + 1])
            if origin < 0 or target < 0:
                continue
            counts[session, origin, target] += 1.0
    totals = counts.sum(axis=-1, keepdims=True)
    return counts / (totals + 1e-10)
|
|
||||||
|
|
||||||
|
|
||||||
def batch_kl(P, Q_human, Q_agent, eps: float = 1e-10):
    """Return KL(P || Q_human) and KL(P || Q_agent) for a batch of matrices.

    `P` is smoothed with `eps` and re-normalised along its last axis. Each
    reference kernel is smoothed with `eps` (not re-normalised) and broadcast
    over the leading batch axis. The divergence is summed over axes 1 and 2,
    giving one value per batch entry.
    """
    smoothed = P + eps
    smoothed = smoothed / jnp.sum(smoothed, axis=-1, keepdims=True)

    def _divergence_from(reference):
        ref = reference[None, ...] + eps
        return jnp.sum(smoothed * jnp.log(smoothed / ref), axis=(1, 2))

    return _divergence_from(Q_human), _divergence_from(Q_agent)
|
|
||||||
|
|
||||||
|
|
||||||
if JAX_AVAILABLE:
|
|
||||||
batch_kl = jax.jit(batch_kl)
|
|
||||||
|
|
||||||
|
|
||||||
def agent_probability_from_kl(delta_h, delta_a, temperature: float = 1.0):
    """Turn two divergences into the probability that a session is an agent.

    Computes exp(-delta_a / t) / (exp(-delta_h / t) + exp(-delta_a / t)),
    i.e. a two-way softmax over the negated divergences.

    Args:
        delta_h: divergence from the human kernel (scalar or array).
        delta_a: divergence from the agent kernel (same shape).
        temperature: softmax temperature, floored at 1e-6.

    Returns:
        Probability in [0, 1] with the broadcast shape of the inputs.
    """
    t = jnp.maximum(float(temperature), 1e-6)
    # Shift by the smaller divergence before exponentiating. Without the shift
    # both exponentials underflow to 0 for large divergences and the ratio
    # collapses to 0 no matter which model fits better.
    floor = jnp.minimum(delta_h, delta_a)
    exp_h = jnp.exp(-(delta_h - floor) / t)
    exp_a = jnp.exp(-(delta_a - floor) / t)
    # One of the two terms is exactly 1, so the denominator is never zero.
    return exp_a / (exp_h + exp_a)
|
|
||||||
|
|
||||||
|
|
||||||
def estimate_alpha_from_kl(delta_h, delta_a, beta: float = 2.0):
    """Return the logistic of beta * (delta_h - delta_a).

    The result exceeds 0.5 when the divergence from the human kernel is the
    larger of the two.
    """
    margin = delta_h - delta_a
    scaled = beta * margin
    return 1.0 / (1.0 + jnp.exp(-scaled))
|
|
||||||
|
|
||||||
|
|
||||||
def weighted_demand(states, products, n_products: int, event_weights):
    """Aggregate weighted session activity into a per-product demand share.

    Every non-negative state contributes its entry of `event_weights`; the
    per-session totals are added to the session's product bucket. When the
    overall total is positive the result is scaled to sum to 100, otherwise
    the all-zero vector is returned.

    Args:
        states: (n_sessions, steps) state indices, negative values are padding.
        products: (n_sessions,) product index of each session.
        n_products: number of demand buckets.
        event_weights: per-state weight vector.

    Returns:
        float32 vector of length `n_products`.
    """
    valid = states >= 0
    # Clip so padding indexes safely; `valid` zeroes its contribution.
    state_clip = jnp.clip(states, 0, event_weights.shape[0] - 1)
    per_session = jnp.sum(event_weights[state_clip] * valid, axis=1)

    demand = jnp.zeros((n_products,), dtype=jnp.float32)
    if hasattr(demand, "at"):
        # JAX arrays (and tracers) expose functional scatter-add via `.at`.
        demand = demand.at[products].add(per_session)
    else:
        # NumPy fallback (`jnp` is numpy): unbuffered in-place scatter-add so
        # repeated product indices accumulate.
        np.add.at(
            demand,
            np.asarray(products),
            np.asarray(per_session, dtype=demand.dtype),
        )

    total = jnp.sum(demand)
    return jnp.where(total > 0.0, (demand / total) * 100.0, demand)
|
|
||||||
|
|
||||||
|
|
||||||
if JAX_AVAILABLE:
|
|
||||||
weighted_demand = jax.jit(weighted_demand, static_argnums=(2,))
|
|
||||||
|
|
||||||
|
|
||||||
def purchase_flags(states, purchase_mask):
    """Return, per session row, whether any emitted state is a purchase state.

    Negative entries of `states` are padding and never count as a hit.
    """
    last_index = purchase_mask.shape[0] - 1
    lookup = jnp.clip(states, 0, last_index)
    is_purchase_step = purchase_mask[lookup] & (states >= 0)
    return jnp.any(is_purchase_step, axis=1)
|
|
||||||
|
|
||||||
|
|
||||||
if JAX_AVAILABLE:
|
|
||||||
purchase_flags = jax.jit(purchase_flags)
|
|
||||||
|
|
||||||
|
|
||||||
def revenue_from_demand(prices, demand):
    """Return the dot product of `prices` and `demand`."""
    revenue = jnp.dot(prices, demand)
    return revenue
|
|
||||||
|
|
||||||
|
|
||||||
if JAX_AVAILABLE:
|
|
||||||
revenue_from_demand = jax.jit(revenue_from_demand)
|
|
||||||
|
|
||||||
|
|
||||||
def reward_with_coi_penalty(
    revenue,
    agent_prob: float,
    lambda_coi: float,
    info_value: float,
    eta_ux: float = 0.0,
    ux_volatility: float = 0.0,
):
    """Discount revenue by information leakage and subtract a UX penalty.

    leakage    = agent_prob * info_value
    discount   = clip(1 - lambda_coi * leakage, 0, 1)
    ux_penalty = eta_ux * revenue * ux_volatility

    Returns:
        (revenue * discount - ux_penalty, leakage, discount, ux_penalty)
    """
    leakage = agent_prob * info_value
    retained_fraction = 1.0 - lambda_coi * leakage
    discount = jnp.clip(retained_fraction, 0.0, 1.0)
    ux_penalty = eta_ux * revenue * ux_volatility
    reward = revenue * discount - ux_penalty
    return reward, leakage, discount, ux_penalty
|
|
||||||
|
|
||||||
|
|
||||||
if JAX_AVAILABLE:
|
|
||||||
reward_with_coi_penalty = jax.jit(reward_with_coi_penalty)
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
flax==0.10.7
|
|
||||||
optax==0.2.7
|
|
||||||
distrax==0.1.5
|
|
||||||
orbax-checkpoint==0.11.32
|
|
||||||
chex==0.1.90
|
|
||||||
1319
engine/jax/train.py
1319
engine/jax/train.py
File diff suppressed because it is too large
Load Diff
@@ -14,7 +14,6 @@ _EXPORTS: dict[str, tuple[str, str]] = {
|
|||||||
"EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"),
|
"EconomicMetricsWrapper": (".wrappers", "EconomicMetricsWrapper"),
|
||||||
"MetricsCallback": (".callbacks", "MetricsCallback"),
|
"MetricsCallback": (".callbacks", "MetricsCallback"),
|
||||||
"EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"),
|
"EvalMetricsCallback": (".callbacks", "EvalMetricsCallback"),
|
||||||
"CheckpointArtifactCallback": (".callbacks", "CheckpointArtifactCallback"),
|
|
||||||
"ProviderBenchmark": (".providers", "ProviderBenchmark"),
|
"ProviderBenchmark": (".providers", "ProviderBenchmark"),
|
||||||
"ProviderResult": (".providers", "ProviderResult"),
|
"ProviderResult": (".providers", "ProviderResult"),
|
||||||
"BenchmarkConfig": (".providers", "BenchmarkConfig"),
|
"BenchmarkConfig": (".providers", "BenchmarkConfig"),
|
||||||
|
|||||||
@@ -1,150 +1,96 @@
|
|||||||
"""Training callbacks for W&B/TensorBoard logging - reads from info dict."""
|
"""Training callbacks with algorithm-agnostic metric extraction."""
|
||||||
|
|
||||||
from pathlib import Path
|
from typing import Any
|
||||||
|
|
||||||
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from ..wandb_checkpoint import checkpoint_artifact_name, log_checkpoint_file
|
|
||||||
|
|
||||||
try:
|
|
||||||
import wandb
|
|
||||||
|
|
||||||
HAS_WANDB = True
|
|
||||||
except ImportError:
|
|
||||||
HAS_WANDB = False
|
|
||||||
|
|
||||||
|
|
||||||
class MetricsCallback(BaseCallback):
|
class MetricsCallback(BaseCallback):
|
||||||
"""Training metrics logger - reads info['economics'], logs to W&B."""
|
"""Collects interval train metrics from env info dictionaries."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, log_histograms: bool = True, log_freq: int = 100, verbose: int = 0
|
self,
|
||||||
|
log_histograms: bool = False,
|
||||||
|
log_freq: int = 100,
|
||||||
|
verbose: int = 0,
|
||||||
):
|
):
|
||||||
super().__init__(verbose)
|
super().__init__(verbose)
|
||||||
self.log_histograms = log_histograms
|
self.log_histograms = log_histograms
|
||||||
self.log_freq = log_freq
|
self.log_freq = max(1, int(log_freq))
|
||||||
self._episode_revenues: list[float] = []
|
self._window_sums = {
|
||||||
|
"train/revenue_mean": 0.0,
|
||||||
def _on_step(self) -> bool:
|
"train/margin_mean": 0.0,
|
||||||
if not HAS_WANDB or wandb.run is None:
|
"train/coi_level_mean": 0.0,
|
||||||
return True
|
"train/regret_mean": 0.0,
|
||||||
|
"train/coi_mix": 0.0,
|
||||||
for info in self.locals.get("infos", []):
|
"train/coi_base": 0.0,
|
||||||
if "economics" not in info:
|
"train/coi_leakage": 0.0,
|
||||||
continue
|
"train/coi_penalty": 0.0,
|
||||||
|
|
||||||
econ = info["economics"]
|
|
||||||
t = self.num_timesteps
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"train/revenue_step": econ["revenue"],
|
|
||||||
"train/margin_step": econ["margin"],
|
|
||||||
"train/coi_level": econ["coi_level"],
|
|
||||||
"train/regret_step": econ["regret"],
|
|
||||||
}
|
}
|
||||||
|
self._window_count = 0
|
||||||
|
self.events: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
def _accumulate(self, info: dict[str, Any]) -> None:
|
||||||
|
econ = info.get("economics")
|
||||||
|
if not isinstance(econ, dict):
|
||||||
|
return
|
||||||
|
self._window_sums["train/revenue_mean"] += float(econ.get("revenue", 0.0))
|
||||||
|
self._window_sums["train/margin_mean"] += float(econ.get("margin", 0.0))
|
||||||
|
self._window_sums["train/coi_level_mean"] += float(econ.get("coi_level", 0.0))
|
||||||
|
self._window_sums["train/regret_mean"] += float(econ.get("regret", 0.0))
|
||||||
if "coi_mix" in econ:
|
if "coi_mix" in econ:
|
||||||
payload["train/coi_mix"] = econ["coi_mix"]
|
self._window_sums["train/coi_mix"] += float(econ.get("coi_mix", 0.0))
|
||||||
if "coi_base" in econ:
|
if "coi_base" in econ:
|
||||||
payload["train/coi_base"] = econ["coi_base"]
|
self._window_sums["train/coi_base"] += float(econ.get("coi_base", 0.0))
|
||||||
if "coi_leakage" in econ:
|
if "coi_leakage" in econ:
|
||||||
payload["train/coi_leakage"] = econ["coi_leakage"]
|
self._window_sums["train/coi_leakage"] += float(
|
||||||
|
econ.get("coi_leakage", 0.0)
|
||||||
|
)
|
||||||
if "coi_penalty" in econ:
|
if "coi_penalty" in econ:
|
||||||
payload["train/coi_penalty"] = econ["coi_penalty"]
|
self._window_sums["train/coi_penalty"] += float(
|
||||||
wandb.log(payload, step=t)
|
econ.get("coi_penalty", 0.0)
|
||||||
|
|
||||||
self._episode_revenues.append(econ["revenue"])
|
|
||||||
|
|
||||||
# histograms at log_freq intervals
|
|
||||||
if self.log_histograms and self.num_timesteps % self.log_freq == 0:
|
|
||||||
for info in self.locals.get("infos", []):
|
|
||||||
if "prices" in info:
|
|
||||||
wandb.log(
|
|
||||||
{"distributions/prices": wandb.Histogram(info["prices"])},
|
|
||||||
step=self.num_timesteps,
|
|
||||||
)
|
|
||||||
if "demand" in info:
|
|
||||||
wandb.log(
|
|
||||||
{"distributions/demand": wandb.Histogram(info["demand"])},
|
|
||||||
step=self.num_timesteps,
|
|
||||||
)
|
)
|
||||||
|
self._window_count += 1
|
||||||
|
|
||||||
return True
|
def _flush(self, step: int) -> None:
|
||||||
|
if self._window_count <= 0:
|
||||||
def _on_rollout_end(self) -> None:
|
|
||||||
if not HAS_WANDB or wandb.run is None or not self._episode_revenues:
|
|
||||||
return
|
return
|
||||||
wandb.log(
|
denom = float(self._window_count)
|
||||||
{
|
payload = {
|
||||||
"train/revenue_rollout_mean": np.mean(self._episode_revenues),
|
key: (value / denom)
|
||||||
"train/revenue_rollout_total": np.sum(self._episode_revenues),
|
for key, value in self._window_sums.items()
|
||||||
},
|
if value != 0.0
|
||||||
step=self.num_timesteps,
|
or key
|
||||||
)
|
in {
|
||||||
self._episode_revenues = []
|
"train/revenue_mean",
|
||||||
|
"train/margin_mean",
|
||||||
|
"train/coi_level_mean",
|
||||||
class CheckpointArtifactCallback(BaseCallback):
|
"train/regret_mean",
|
||||||
"""Periodic SB3 checkpoint uploader backed by W&B artifacts."""
|
|
||||||
|
|
||||||
def __init__(self, cfg: dict, interval: int = 10_000, verbose: int = 0):
|
|
||||||
super().__init__(verbose)
|
|
||||||
self.cfg = dict(cfg)
|
|
||||||
self.interval = max(1, int(interval))
|
|
||||||
self.model_dir = Path(str(self.cfg.get("model_dir", "engine/models")))
|
|
||||||
self.model_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
self._next_checkpoint = self.interval
|
|
||||||
self._last_saved_step = -1
|
|
||||||
|
|
||||||
def _artifact_name(self) -> str:
|
|
||||||
sweep_id = (
|
|
||||||
getattr(wandb.run, "sweep_id", None)
|
|
||||||
if HAS_WANDB and wandb.run is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
return checkpoint_artifact_name(self.cfg, backend="sb3", sweep_id=sweep_id)
|
|
||||||
|
|
||||||
def _checkpoint_file(self) -> Path:
|
|
||||||
algo = str(self.cfg.get("algo", "model"))
|
|
||||||
base = self.model_dir / f"phantom_{algo}_checkpoint"
|
|
||||||
self.model.save(str(base))
|
|
||||||
return base.with_suffix(".zip")
|
|
||||||
|
|
||||||
def _save_checkpoint(self) -> None:
|
|
||||||
if not HAS_WANDB or wandb.run is None:
|
|
||||||
return
|
|
||||||
step = int(self.num_timesteps)
|
|
||||||
if step <= self._last_saved_step:
|
|
||||||
return
|
|
||||||
checkpoint_path = self._checkpoint_file()
|
|
||||||
metadata = {
|
|
||||||
"step": step,
|
|
||||||
"algo": str(self.cfg.get("algo", "unknown")),
|
|
||||||
"sweep_id": getattr(wandb.run, "sweep_id", None),
|
|
||||||
}
|
}
|
||||||
saved = log_checkpoint_file(
|
}
|
||||||
self._artifact_name(),
|
payload["train/global_step"] = int(step)
|
||||||
file_path=checkpoint_path,
|
self.events.append(payload)
|
||||||
artifact_file_name=checkpoint_path.name,
|
for key in self._window_sums:
|
||||||
metadata=metadata,
|
self._window_sums[key] = 0.0
|
||||||
)
|
self._window_count = 0
|
||||||
if saved:
|
|
||||||
self._last_saved_step = step
|
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
if self.num_timesteps < self._next_checkpoint:
|
for info in self.locals.get("infos", []):
|
||||||
return True
|
if isinstance(info, dict):
|
||||||
self._save_checkpoint()
|
self._accumulate(info)
|
||||||
while self._next_checkpoint <= self.num_timesteps:
|
|
||||||
self._next_checkpoint += self.interval
|
if self.num_timesteps % self.log_freq == 0:
|
||||||
|
self._flush(step=self.num_timesteps)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def _on_training_end(self) -> None:
|
def _on_training_end(self) -> None:
|
||||||
self._save_checkpoint()
|
self._flush(step=self.num_timesteps)
|
||||||
|
|
||||||
|
|
||||||
class EvalMetricsCallback(EvalCallback):
|
class EvalMetricsCallback(EvalCallback):
|
||||||
"""Deterministic evaluation - true performance without exploration noise."""
|
"""Deterministic evaluation collector detached from logging backends."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
|
self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
|
||||||
@@ -153,23 +99,19 @@ class EvalMetricsCallback(EvalCallback):
|
|||||||
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
|
eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
|
||||||
)
|
)
|
||||||
self._eval_revenues: list[float] = []
|
self._eval_revenues: list[float] = []
|
||||||
|
self.events: list[dict[str, float | int]] = []
|
||||||
|
|
||||||
def _on_step(self) -> bool:
|
def _on_step(self) -> bool:
|
||||||
result = super()._on_step()
|
result = super()._on_step()
|
||||||
|
|
||||||
if not HAS_WANDB or wandb.run is None:
|
|
||||||
return result
|
|
||||||
|
|
||||||
# log eval metrics after evaluation runs
|
|
||||||
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
|
if self.n_calls % self.eval_freq == 0 and hasattr(self, "last_mean_reward"):
|
||||||
wandb.log(
|
self.events.append(
|
||||||
{
|
{
|
||||||
"eval/reward_mean": self.last_mean_reward,
|
"eval/reward_mean": float(self.last_mean_reward),
|
||||||
"eval/revenue_mean": np.mean(self._eval_revenues)
|
"eval/revenue_mean": float(np.mean(self._eval_revenues))
|
||||||
if self._eval_revenues
|
if self._eval_revenues
|
||||||
else 0,
|
else 0.0,
|
||||||
},
|
"train/global_step": int(self.num_timesteps),
|
||||||
step=self.num_timesteps,
|
}
|
||||||
)
|
)
|
||||||
self._eval_revenues = []
|
self._eval_revenues = []
|
||||||
|
|
||||||
|
|||||||
@@ -31,26 +31,20 @@ def _print_local_metrics(metrics: dict[str, Any]) -> None:
|
|||||||
print("PHANTOM_METRICS:" + json.dumps(metrics))
|
print("PHANTOM_METRICS:" + json.dumps(metrics))
|
||||||
|
|
||||||
|
|
||||||
def _should_print_local(spec: TrainSpec) -> bool:
|
def _log_train_events(events: list[dict[str, Any]], log_freq: int) -> None:
|
||||||
if not spec.runtime.use_jax:
|
if not events:
|
||||||
return True
|
return
|
||||||
try:
|
period = max(1, int(log_freq))
|
||||||
import jax
|
last_logged_step = -period
|
||||||
|
for event in sorted(
|
||||||
return int(jax.process_index()) == 0
|
[evt for evt in events if isinstance(evt, dict)],
|
||||||
except Exception:
|
key=lambda evt: int(evt.get("train/global_step", 0)),
|
||||||
return True
|
):
|
||||||
|
step = int(event.get("train/global_step", 0))
|
||||||
|
if step <= 0 or (step - last_logged_step) < period:
|
||||||
def _is_non_primary_jax_worker(spec: TrainSpec) -> bool:
|
continue
|
||||||
if not spec.runtime.use_jax:
|
log_metrics(event, step=step)
|
||||||
return False
|
last_logged_step = step
|
||||||
try:
|
|
||||||
import jax
|
|
||||||
|
|
||||||
return int(jax.process_count()) > 1 and int(jax.process_index()) != 0
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def run_train_once(
|
def run_train_once(
|
||||||
@@ -65,9 +59,8 @@ def run_train_once(
|
|||||||
extra_tags: Sequence[str],
|
extra_tags: Sequence[str],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
wandb = get_wandb_module()
|
wandb = get_wandb_module()
|
||||||
if no_wandb or wandb is None or _is_non_primary_jax_worker(spec):
|
if no_wandb or wandb is None:
|
||||||
result = run_train(spec)
|
result = run_train(spec)
|
||||||
if _should_print_local(spec):
|
|
||||||
_print_local_metrics(result.metrics)
|
_print_local_metrics(result.metrics)
|
||||||
return result.metrics
|
return result.metrics
|
||||||
|
|
||||||
@@ -95,6 +88,7 @@ def run_train_once(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
result = run_train(spec)
|
result = run_train(spec)
|
||||||
|
_log_train_events(result.events, spec.runtime.log_freq)
|
||||||
metrics = result.metrics
|
metrics = result.metrics
|
||||||
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
|
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
|
||||||
log_metrics(metrics, step=step)
|
log_metrics(metrics, step=step)
|
||||||
@@ -122,6 +116,7 @@ def run_with_active_sweep_run(
|
|||||||
)
|
)
|
||||||
update_run_config({**spec.to_flat_dict(), **metadata})
|
update_run_config({**spec.to_flat_dict(), **metadata})
|
||||||
result = run_train(spec)
|
result = run_train(spec)
|
||||||
|
_log_train_events(result.events, spec.runtime.log_freq)
|
||||||
metrics = result.metrics
|
metrics = result.metrics
|
||||||
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
|
step = int(metrics.get("train/global_step", spec.runtime.total_timesteps))
|
||||||
log_metrics(metrics, step=step)
|
log_metrics(metrics, step=step)
|
||||||
|
|||||||
@@ -81,44 +81,6 @@
|
|||||||
"command": "bash scripts/nx_research.sh docker-train-publish",
|
"command": "bash scripts/nx_research.sh docker-train-publish",
|
||||||
"cwd": "."
|
"cwd": "."
|
||||||
}
|
}
|
||||||
},
|
|
||||||
"train-tpu-pod": {
|
|
||||||
"executor": "nx:run-commands",
|
|
||||||
"options": {
|
|
||||||
"command": "bash scripts/nx_research.sh train-tpu-pod",
|
|
||||||
"cwd": "."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"train-tpu-vm-prepare": {
|
|
||||||
"executor": "nx:run-commands",
|
|
||||||
"options": {
|
|
||||||
"command": "bash scripts/nx_research.sh train-tpu-vm-prepare",
|
|
||||||
"cwd": "."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"train-tpu-vm-run": {
|
|
||||||
"executor": "nx:run-commands",
|
|
||||||
"options": {
|
|
||||||
"command": "bash scripts/nx_research.sh train-tpu-vm-run",
|
|
||||||
"cwd": "."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"train-tpu-vm": {
|
|
||||||
"executor": "nx:run-commands",
|
|
||||||
"dependsOn": [
|
|
||||||
"train-tpu-vm-prepare"
|
|
||||||
],
|
|
||||||
"options": {
|
|
||||||
"command": "bash scripts/nx_research.sh train-tpu-vm-run",
|
|
||||||
"cwd": "."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"train-tpu-vm-sweep": {
|
|
||||||
"executor": "nx:run-commands",
|
|
||||||
"options": {
|
|
||||||
"command": "bash scripts/nx_research.sh train-tpu-vm-sweep",
|
|
||||||
"cwd": "."
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"tags": [
|
"tags": [
|
||||||
|
|||||||
@@ -106,11 +106,6 @@ class OptimizerSpec:
|
|||||||
eps_decay: float = 0.9995
|
eps_decay: float = 0.9995
|
||||||
arch: str = "small"
|
arch: str = "small"
|
||||||
activation: str = "relu"
|
activation: str = "relu"
|
||||||
jax_num_envs: int = 16
|
|
||||||
jax_num_steps: int = 128
|
|
||||||
jax_num_minibatches: int = 4
|
|
||||||
jax_update_epochs: int = 4
|
|
||||||
jax_anneal_lr: bool = True
|
|
||||||
vf_coef: float = 0.5
|
vf_coef: float = 0.5
|
||||||
max_grad_norm: float = 0.5
|
max_grad_norm: float = 0.5
|
||||||
|
|
||||||
@@ -125,7 +120,6 @@ class RuntimeSpec:
|
|||||||
checkpoint_interval: int = 200_000
|
checkpoint_interval: int = 200_000
|
||||||
model_dir: str = "engine/models"
|
model_dir: str = "engine/models"
|
||||||
log_freq: int = 100
|
log_freq: int = 100
|
||||||
use_jax: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@@ -156,7 +150,6 @@ class TrainSpec:
|
|||||||
"model_dir": self.runtime.model_dir,
|
"model_dir": self.runtime.model_dir,
|
||||||
"backend": self.runtime.backend,
|
"backend": self.runtime.backend,
|
||||||
"device": self.runtime.device,
|
"device": self.runtime.device,
|
||||||
"use_jax": self.runtime.use_jax,
|
|
||||||
"checkpoint_interval": self.runtime.checkpoint_interval,
|
"checkpoint_interval": self.runtime.checkpoint_interval,
|
||||||
"n_products": self.env.n_products,
|
"n_products": self.env.n_products,
|
||||||
"N": self.env.n_sessions,
|
"N": self.env.n_sessions,
|
||||||
@@ -197,11 +190,6 @@ class TrainSpec:
|
|||||||
"eps_decay": self.optimizer.eps_decay,
|
"eps_decay": self.optimizer.eps_decay,
|
||||||
"arch": self.optimizer.arch,
|
"arch": self.optimizer.arch,
|
||||||
"activation": self.optimizer.activation,
|
"activation": self.optimizer.activation,
|
||||||
"jax_num_envs": self.optimizer.jax_num_envs,
|
|
||||||
"jax_num_steps": self.optimizer.jax_num_steps,
|
|
||||||
"jax_num_minibatches": self.optimizer.jax_num_minibatches,
|
|
||||||
"jax_update_epochs": self.optimizer.jax_update_epochs,
|
|
||||||
"jax_anneal_lr": self.optimizer.jax_anneal_lr,
|
|
||||||
"vf_coef": self.optimizer.vf_coef,
|
"vf_coef": self.optimizer.vf_coef,
|
||||||
"max_grad_norm": self.optimizer.max_grad_norm,
|
"max_grad_norm": self.optimizer.max_grad_norm,
|
||||||
"robust_eval_enabled": self.eval.robust_eval_enabled,
|
"robust_eval_enabled": self.eval.robust_eval_enabled,
|
||||||
@@ -223,14 +211,11 @@ class TrainSpec:
|
|||||||
base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto"))
|
base.get("device", runtime_env.get("PHANTOM_DEVICE", "auto"))
|
||||||
)
|
)
|
||||||
|
|
||||||
requested_jax = _truthy(base.get("use_jax")) or _truthy(
|
backend = str(base.get("backend", "sb3")).lower()
|
||||||
runtime_env.get("PHANTOM_USE_JAX")
|
|
||||||
)
|
|
||||||
backend = str(base.get("backend", "jax" if requested_jax else "sb3")).lower()
|
|
||||||
if backend == "auto":
|
if backend == "auto":
|
||||||
backend = "jax" if requested_jax else "sb3"
|
backend = "sb3"
|
||||||
if backend == "jax":
|
if backend != "sb3":
|
||||||
requested_jax = True
|
backend = "sb3"
|
||||||
|
|
||||||
no_robust = _truthy(base.get("no_robust"))
|
no_robust = _truthy(base.get("no_robust"))
|
||||||
if no_robust:
|
if no_robust:
|
||||||
@@ -284,11 +269,6 @@ class TrainSpec:
|
|||||||
eps_decay=float(base["eps_decay"]),
|
eps_decay=float(base["eps_decay"]),
|
||||||
arch=str(base["arch"]),
|
arch=str(base["arch"]),
|
||||||
activation=str(base["activation"]),
|
activation=str(base["activation"]),
|
||||||
jax_num_envs=int(base["jax_num_envs"]),
|
|
||||||
jax_num_steps=int(base["jax_num_steps"]),
|
|
||||||
jax_num_minibatches=int(base["jax_num_minibatches"]),
|
|
||||||
jax_update_epochs=int(base["jax_update_epochs"]),
|
|
||||||
jax_anneal_lr=_truthy(base.get("jax_anneal_lr")),
|
|
||||||
vf_coef=float(base["vf_coef"]),
|
vf_coef=float(base["vf_coef"]),
|
||||||
max_grad_norm=float(base["max_grad_norm"]),
|
max_grad_norm=float(base["max_grad_norm"]),
|
||||||
),
|
),
|
||||||
@@ -301,7 +281,6 @@ class TrainSpec:
|
|||||||
checkpoint_interval=int(base["checkpoint_interval"]),
|
checkpoint_interval=int(base["checkpoint_interval"]),
|
||||||
model_dir=str(base["model_dir"]),
|
model_dir=str(base["model_dir"]),
|
||||||
log_freq=int(base["log_freq"]),
|
log_freq=int(base["log_freq"]),
|
||||||
use_jax=requested_jax,
|
|
||||||
),
|
),
|
||||||
eval=EvalSpec(
|
eval=EvalSpec(
|
||||||
eval_freq=int(base["eval_freq"]),
|
eval_freq=int(base["eval_freq"]),
|
||||||
|
|||||||
@@ -1,93 +0,0 @@
|
|||||||
method: bayes
|
|
||||||
metric:
|
|
||||||
name: objective/score
|
|
||||||
goal: maximize
|
|
||||||
command:
|
|
||||||
- ${env}
|
|
||||||
- python
|
|
||||||
- -m
|
|
||||||
- engine.train
|
|
||||||
parameters:
|
|
||||||
# fixed: always use JAX backend so TPU chips are actually exercised
|
|
||||||
use_jax:
|
|
||||||
value: true
|
|
||||||
# all four algos have JAX implementations
|
|
||||||
algo:
|
|
||||||
values: [ppo, a2c, dqn, qtable]
|
|
||||||
total_timesteps:
|
|
||||||
values: [50000, 80000, 120000]
|
|
||||||
checkpoint_interval:
|
|
||||||
value: 200000
|
|
||||||
seed:
|
|
||||||
values: [13, 42, 77]
|
|
||||||
n_products:
|
|
||||||
values: [8, 10, 12]
|
|
||||||
# COI framework parameters -- primary research variables
|
|
||||||
alpha:
|
|
||||||
distribution: uniform
|
|
||||||
min: 0.1
|
|
||||||
max: 0.6
|
|
||||||
lambda_coi:
|
|
||||||
distribution: uniform
|
|
||||||
min: 0.05
|
|
||||||
max: 0.6
|
|
||||||
robust_radius:
|
|
||||||
distribution: uniform
|
|
||||||
min: 0.0
|
|
||||||
max: 0.3
|
|
||||||
robust_points:
|
|
||||||
values: [3, 5, 7]
|
|
||||||
info_value:
|
|
||||||
distribution: uniform
|
|
||||||
min: 0.5
|
|
||||||
max: 2.0
|
|
||||||
revenue_weight:
|
|
||||||
values: [0.005, 0.01, 0.02]
|
|
||||||
# shared hyperparameters
|
|
||||||
learning_rate:
|
|
||||||
distribution: log_uniform_values
|
|
||||||
min: 1.0e-5
|
|
||||||
max: 1.0e-3
|
|
||||||
gamma:
|
|
||||||
values: [0.97, 0.99, 0.995]
|
|
||||||
# JAX parallelism -- key lever for TPU throughput
|
|
||||||
jax_num_envs:
|
|
||||||
values: [8, 16, 32]
|
|
||||||
jax_num_steps:
|
|
||||||
values: [64, 128, 256]
|
|
||||||
jax_num_minibatches:
|
|
||||||
values: [2, 4, 8]
|
|
||||||
jax_update_epochs:
|
|
||||||
values: [2, 4, 8]
|
|
||||||
# PPO/A2C specific
|
|
||||||
gae_lambda:
|
|
||||||
values: [0.9, 0.95, 0.98]
|
|
||||||
clip_range:
|
|
||||||
values: [0.1, 0.2, 0.3]
|
|
||||||
ent_coef:
|
|
||||||
values: [0.0, 0.005, 0.01]
|
|
||||||
# DQN specific
|
|
||||||
buffer_size:
|
|
||||||
values: [20000, 50000, 100000]
|
|
||||||
batch_size:
|
|
||||||
values: [128, 256, 512]
|
|
||||||
learning_starts:
|
|
||||||
values: [500, 1000, 3000]
|
|
||||||
exploration_fraction:
|
|
||||||
values: [0.1, 0.2, 0.3]
|
|
||||||
exploration_final_eps:
|
|
||||||
values: [0.01, 0.03, 0.05]
|
|
||||||
# QTable specific
|
|
||||||
q_lr:
|
|
||||||
values: [0.03, 0.05, 0.1, 0.2]
|
|
||||||
eps_end:
|
|
||||||
values: [0.02, 0.05, 0.1]
|
|
||||||
eps_decay:
|
|
||||||
values: [0.999, 0.9995, 0.9999]
|
|
||||||
# action space
|
|
||||||
action_levels:
|
|
||||||
values: [7, 9, 11]
|
|
||||||
action_scale_low:
|
|
||||||
values: [0.75, 0.8, 0.85]
|
|
||||||
action_scale_high:
|
|
||||||
values: [1.15, 1.2, 1.25]
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
method: bayes
|
|
||||||
metric:
|
|
||||||
name: objective/score
|
|
||||||
goal: maximize
|
|
||||||
command:
|
|
||||||
- ${env}
|
|
||||||
- python
|
|
||||||
- -m
|
|
||||||
- engine.train
|
|
||||||
parameters:
|
|
||||||
use_jax:
|
|
||||||
value: true
|
|
||||||
# pmap requires all workers to compile the same computation graph shape,
|
|
||||||
# so structural params are fixed -- only research/scalar params are swept
|
|
||||||
algo:
|
|
||||||
values: [ppo, a2c]
|
|
||||||
jax_num_envs:
|
|
||||||
value: 32
|
|
||||||
jax_num_steps:
|
|
||||||
value: 128
|
|
||||||
jax_num_minibatches:
|
|
||||||
value: 4
|
|
||||||
jax_update_epochs:
|
|
||||||
value: 4
|
|
||||||
total_timesteps:
|
|
||||||
value: 100000
|
|
||||||
checkpoint_interval:
|
|
||||||
value: 200000
|
|
||||||
n_products:
|
|
||||||
value: 10
|
|
||||||
action_levels:
|
|
||||||
value: 9
|
|
||||||
# research parameters -- primary sweep targets
|
|
||||||
alpha:
|
|
||||||
distribution: uniform
|
|
||||||
min: 0.1
|
|
||||||
max: 0.6
|
|
||||||
lambda_coi:
|
|
||||||
distribution: uniform
|
|
||||||
min: 0.05
|
|
||||||
max: 0.6
|
|
||||||
robust_radius:
|
|
||||||
distribution: uniform
|
|
||||||
min: 0.0
|
|
||||||
max: 0.3
|
|
||||||
info_value:
|
|
||||||
distribution: uniform
|
|
||||||
min: 0.5
|
|
||||||
max: 2.0
|
|
||||||
revenue_weight:
|
|
||||||
values: [0.005, 0.01, 0.02]
|
|
||||||
# training hyperparameters
|
|
||||||
learning_rate:
|
|
||||||
distribution: log_uniform_values
|
|
||||||
min: 1.0e-5
|
|
||||||
max: 1.0e-3
|
|
||||||
gamma:
|
|
||||||
values: [0.97, 0.99, 0.995]
|
|
||||||
gae_lambda:
|
|
||||||
values: [0.9, 0.95, 0.98]
|
|
||||||
clip_range:
|
|
||||||
values: [0.1, 0.2, 0.3]
|
|
||||||
ent_coef:
|
|
||||||
values: [0.0, 0.005, 0.01]
|
|
||||||
@@ -7,14 +7,6 @@ from .orchestrators import run_benchmark_cli, run_sweep_agent, run_train_once
|
|||||||
from .spec import TrainSpec
|
from .spec import TrainSpec
|
||||||
|
|
||||||
|
|
||||||
def _truthy(value: str | bool | None) -> bool:
|
|
||||||
if isinstance(value, bool):
|
|
||||||
return value
|
|
||||||
if value is None:
|
|
||||||
return False
|
|
||||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_tags(raw: str | None) -> list[str]:
|
def _parse_tags(raw: str | None) -> list[str]:
|
||||||
if raw is None:
|
if raw is None:
|
||||||
return []
|
return []
|
||||||
@@ -55,7 +47,7 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--group", type=str)
|
parser.add_argument("--group", type=str)
|
||||||
parser.add_argument("--tags", type=str)
|
parser.add_argument("--tags", type=str)
|
||||||
|
|
||||||
parser.add_argument("--backend", choices=["auto", "sb3", "jax"], default="auto")
|
parser.add_argument("--backend", choices=["auto", "sb3"], default="auto")
|
||||||
parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"])
|
parser.add_argument("--algo", choices=["ppo", "a2c", "dqn", "qtable", "sac"])
|
||||||
parser.add_argument("--seed", type=int)
|
parser.add_argument("--seed", type=int)
|
||||||
parser.add_argument("--total-timesteps", type=int)
|
parser.add_argument("--total-timesteps", type=int)
|
||||||
@@ -111,13 +103,6 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--eval-freq", type=int)
|
parser.add_argument("--eval-freq", type=int)
|
||||||
parser.add_argument("--eval-episodes", type=int)
|
parser.add_argument("--eval-episodes", type=int)
|
||||||
|
|
||||||
parser.add_argument("--jax", action="store_true")
|
|
||||||
parser.add_argument("--jax-num-envs", type=int)
|
|
||||||
parser.add_argument("--jax-num-steps", type=int)
|
|
||||||
parser.add_argument("--jax-num-minibatches", type=int)
|
|
||||||
parser.add_argument("--jax-update-epochs", type=int)
|
|
||||||
parser.add_argument("--jax-anneal-lr", type=str)
|
|
||||||
|
|
||||||
parser.add_argument("--sweep-agent", action="store_true")
|
parser.add_argument("--sweep-agent", action="store_true")
|
||||||
parser.add_argument("--sweep-id", type=str)
|
parser.add_argument("--sweep-id", type=str)
|
||||||
parser.add_argument("--count", type=int, default=0)
|
parser.add_argument("--count", type=int, default=0)
|
||||||
@@ -127,9 +112,6 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||||||
|
|
||||||
|
|
||||||
def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
|
def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
|
||||||
jax_anneal_lr = (
|
|
||||||
_truthy(args.jax_anneal_lr) if args.jax_anneal_lr is not None else None
|
|
||||||
)
|
|
||||||
backend = None if args.backend == "auto" else args.backend
|
backend = None if args.backend == "auto" else args.backend
|
||||||
|
|
||||||
overrides = {
|
overrides = {
|
||||||
@@ -185,12 +167,6 @@ def _overrides_from_args(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
"max_grad_norm": args.max_grad_norm,
|
"max_grad_norm": args.max_grad_norm,
|
||||||
"eval_freq": args.eval_freq,
|
"eval_freq": args.eval_freq,
|
||||||
"eval_episodes": args.eval_episodes,
|
"eval_episodes": args.eval_episodes,
|
||||||
"use_jax": args.jax or None,
|
|
||||||
"jax_num_envs": args.jax_num_envs,
|
|
||||||
"jax_num_steps": args.jax_num_steps,
|
|
||||||
"jax_num_minibatches": args.jax_num_minibatches,
|
|
||||||
"jax_update_epochs": args.jax_update_epochs,
|
|
||||||
"jax_anneal_lr": jax_anneal_lr,
|
|
||||||
}
|
}
|
||||||
return {key: value for key, value in overrides.items() if value is not None}
|
return {key: value for key, value in overrides.items() if value is not None}
|
||||||
|
|
||||||
|
|||||||
@@ -12,17 +12,14 @@ class TrainResult:
|
|||||||
spec: TrainSpec
|
spec: TrainSpec
|
||||||
metrics: dict[str, Any]
|
metrics: dict[str, Any]
|
||||||
artifacts: dict[str, str]
|
artifacts: dict[str, str]
|
||||||
|
events: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
def run_train(spec: TrainSpec) -> TrainResult:
|
def run_train(spec: TrainSpec) -> TrainResult:
|
||||||
cfg = spec.to_flat_dict()
|
cfg = spec.to_flat_dict()
|
||||||
algo = spec.algorithm.name
|
algo = spec.algorithm.name
|
||||||
|
|
||||||
if spec.runtime.use_jax or spec.runtime.backend == "jax":
|
if algo == "qtable":
|
||||||
from .backends.jax import train_jax_backend
|
|
||||||
|
|
||||||
_, raw_metrics = train_jax_backend(cfg)
|
|
||||||
elif algo == "qtable":
|
|
||||||
from .backends.qtable import train_qtable
|
from .backends.qtable import train_qtable
|
||||||
|
|
||||||
_, raw_metrics = train_qtable(cfg)
|
_, raw_metrics = train_qtable(cfg)
|
||||||
@@ -31,10 +28,13 @@ def run_train(spec: TrainSpec) -> TrainResult:
|
|||||||
|
|
||||||
_, raw_metrics = train_sb3(cfg)
|
_, raw_metrics = train_sb3(cfg)
|
||||||
|
|
||||||
|
events_raw = raw_metrics.pop("_train_events", [])
|
||||||
|
events = [evt for evt in events_raw if isinstance(evt, dict)]
|
||||||
|
|
||||||
metrics = canonicalize_metrics(raw_metrics, spec)
|
metrics = canonicalize_metrics(raw_metrics, spec)
|
||||||
artifacts: dict[str, str] = {}
|
artifacts: dict[str, str] = {}
|
||||||
model_path = raw_metrics.get("model/path")
|
model_path = raw_metrics.get("model/path")
|
||||||
if isinstance(model_path, str):
|
if isinstance(model_path, str):
|
||||||
artifacts["model/path"] = model_path
|
artifacts["model/path"] = model_path
|
||||||
|
|
||||||
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts)
|
return TrainResult(spec=spec, metrics=metrics, artifacts=artifacts, events=events)
|
||||||
|
|||||||
@@ -108,49 +108,6 @@ PY
|
|||||||
image_ref="${TRAIN_IMAGE_REF:-us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer}"
|
image_ref="${TRAIN_IMAGE_REF:-us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer}"
|
||||||
docker build -f docker/Trainer.dockerfile --target gpu -t "$image_ref:gpu-latest" .
|
docker build -f docker/Trainer.dockerfile --target gpu -t "$image_ref:gpu-latest" .
|
||||||
docker push "$image_ref:gpu-latest"
|
docker push "$image_ref:gpu-latest"
|
||||||
docker build -f docker/Trainer.dockerfile --target tpu -t "$image_ref:tpu-latest" .
|
|
||||||
docker push "$image_ref:tpu-latest"
|
|
||||||
;;
|
|
||||||
train-tpu-pod)
|
|
||||||
load_sweep_env
|
|
||||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
|
||||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
|
|
||||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
|
||||||
gcloud compute tpus tpu-vm scp scripts/tpu_pod_run.sh "$TPU_NAME":/tmp/tpu_pod_run.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all
|
|
||||||
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="WANDB_API_KEY='$WANDB_API_KEY' SWEEP_ID='$SWEEP_ID' AGENT_COUNT='${AGENT_COUNT:-0}' sh /tmp/tpu_pod_run.sh"
|
|
||||||
;;
|
|
||||||
train-tpu-vm-prepare)
|
|
||||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
|
||||||
TPU_NAME="$TPU_NAME" \
|
|
||||||
TPU_ZONE="${TPU_ZONE:-us-central2-b}" \
|
|
||||||
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}" \
|
|
||||||
LOCAL_REPO_DIR="$PWD" \
|
|
||||||
REMOTE_REPO_DIR="${TPU_REPO_DIR:-/tmp/PHANTOM}" \
|
|
||||||
sh scripts/tpu_sync_repo.sh
|
|
||||||
gcloud compute tpus tpu-vm scp scripts/tpu_vm_train.sh "$TPU_NAME":/tmp/tpu_vm_train.sh --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all
|
|
||||||
;;
|
|
||||||
train-tpu-vm-run)
|
|
||||||
load_sweep_env
|
|
||||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
|
||||||
require_var LOCAL_TRAIN_ARGS "LOCAL_TRAIN_ARGS required, e.g. --algo ppo --jax --total-timesteps 200000"
|
|
||||||
gcloud compute tpus tpu-vm ssh "$TPU_NAME" --zone="${TPU_ZONE:-us-central2-b}" --project="${TPU_PROJECT:-phantom-trc}" --worker=all --command="REPO_DIR='${TPU_REPO_DIR:-/tmp/PHANTOM}' TRAIN_ARGS='${LOCAL_TRAIN_ARGS}' WANDB_API_KEY='${WANDB_API_KEY:-}' sh /tmp/tpu_vm_train.sh"
|
|
||||||
;;
|
|
||||||
train-tpu-vm-sweep)
|
|
||||||
load_sweep_env
|
|
||||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
|
||||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123"
|
|
||||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
|
||||||
args=(
|
|
||||||
--sweep-id "$SWEEP_ID"
|
|
||||||
--tpu-name "$TPU_NAME"
|
|
||||||
--tpu-zone "${TPU_ZONE:-us-central2-b}"
|
|
||||||
--tpu-project "${TPU_PROJECT:-phantom-trc}"
|
|
||||||
--tpu-repo-dir "${TPU_REPO_DIR:-/tmp/PHANTOM}"
|
|
||||||
)
|
|
||||||
if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then
|
|
||||||
args+=(--count "$AGENT_COUNT")
|
|
||||||
fi
|
|
||||||
WANDB_API_KEY="$WANDB_API_KEY" python3 scripts/tpu_vm_sweep_agent.py "${args[@]}"
|
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
printf '%s\n' "Unknown research command: $cmd" >&2
|
printf '%s\n' "Unknown research command: $cmd" >&2
|
||||||
|
|||||||
@@ -1,32 +0,0 @@
|
|||||||
#!/usr/bin/env sh
|
|
||||||
# Executed on each TPU pod worker via `gcloud tpu-vm scp` + `gcloud tpu-vm ssh --worker=all`.
|
|
||||||
# Authenticates with Artifact Registry using the VM's service account metadata token,
|
|
||||||
# pulls the TPU trainer image, then runs the W&B sweep agent inside Docker.
|
|
||||||
# TPU chip devices (/dev/accel*) are exposed via --privileged + /dev volume mount.
|
|
||||||
# Required env vars: WANDB_API_KEY, SWEEP_ID
|
|
||||||
# Optional: AGENT_COUNT (default 1, 0 = run until sweep ends)
|
|
||||||
set -eu
|
|
||||||
|
|
||||||
IMAGE="us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer:tpu-latest"
|
|
||||||
AGENT_COUNT="${AGENT_COUNT:-1}"
|
|
||||||
|
|
||||||
# use VM service account — no manual key needed on the pod
|
|
||||||
TOKEN=$(curl -sf -H "Metadata-Flavor: Google" \
|
|
||||||
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" \
|
|
||||||
| python3 -c 'import sys, json; print(json.load(sys.stdin)["access_token"])')
|
|
||||||
|
|
||||||
echo "$TOKEN" | sudo docker login -u oauth2accesstoken \
|
|
||||||
--password-stdin https://us-central1-docker.pkg.dev
|
|
||||||
|
|
||||||
sudo docker pull "$IMAGE"
|
|
||||||
|
|
||||||
# --privileged + /dev mount gives the container access to /dev/accel* (TPU chips)
|
|
||||||
# --network host lets JAX reach the other pod workers for distributed init
|
|
||||||
sudo docker run --rm \
|
|
||||||
--privileged \
|
|
||||||
--network host \
|
|
||||||
--volume /dev:/dev \
|
|
||||||
-e WANDB_API_KEY="$WANDB_API_KEY" \
|
|
||||||
-e SWEEP_ID="$SWEEP_ID" \
|
|
||||||
-e AGENT_COUNT="$AGENT_COUNT" \
|
|
||||||
"$IMAGE"
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
#!/usr/bin/env sh
|
|
||||||
set -eu
|
|
||||||
|
|
||||||
TPU_NAME="${TPU_NAME:?TPU_NAME is required}"
|
|
||||||
TPU_ZONE="${TPU_ZONE:-us-central2-b}"
|
|
||||||
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}"
|
|
||||||
LOCAL_REPO_DIR="${LOCAL_REPO_DIR:-$(pwd)}"
|
|
||||||
REMOTE_REPO_DIR="${REMOTE_REPO_DIR:-/tmp/PHANTOM}"
|
|
||||||
ARCHIVE_PATH="${ARCHIVE_PATH:-/tmp/phantom-sync.tgz}"
|
|
||||||
|
|
||||||
FILE_LIST="$(mktemp /tmp/phantom-sync-files.XXXXXX)"
|
|
||||||
CLEANUP_LIST=true
|
|
||||||
|
|
||||||
cleanup() {
|
|
||||||
if [ "$CLEANUP_LIST" = "true" ]; then
|
|
||||||
rm -f "$FILE_LIST"
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
trap cleanup EXIT
|
|
||||||
|
|
||||||
if [ ! -d "$LOCAL_REPO_DIR" ]; then
|
|
||||||
echo "local repo directory not found: $LOCAL_REPO_DIR"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if git -C "$LOCAL_REPO_DIR" rev-parse --is-inside-work-tree >/dev/null 2>&1; then
|
|
||||||
git -C "$LOCAL_REPO_DIR" ls-files -co --exclude-standard > "$FILE_LIST"
|
|
||||||
python3 - "$FILE_LIST" <<'PY'
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
file_list = Path(sys.argv[1])
|
|
||||||
skip_prefixes = (
|
|
||||||
"wandb/",
|
|
||||||
".venv/",
|
|
||||||
"venv/",
|
|
||||||
"node_modules/",
|
|
||||||
".next/",
|
|
||||||
".turbo/",
|
|
||||||
"__pycache__/",
|
|
||||||
".mypy_cache/",
|
|
||||||
".pytest_cache/",
|
|
||||||
".ruff_cache/",
|
|
||||||
"paper/build/",
|
|
||||||
"tests/e2e/test-results/",
|
|
||||||
)
|
|
||||||
|
|
||||||
rows = file_list.read_text().splitlines()
|
|
||||||
kept = [
|
|
||||||
row
|
|
||||||
for row in rows
|
|
||||||
if row and not any(row == p.rstrip("/") or row.startswith(p) for p in skip_prefixes)
|
|
||||||
]
|
|
||||||
file_list.write_text("\n".join(kept) + ("\n" if kept else ""))
|
|
||||||
PY
|
|
||||||
tar -czf "$ARCHIVE_PATH" -C "$LOCAL_REPO_DIR" -T "$FILE_LIST"
|
|
||||||
else
|
|
||||||
tar \
|
|
||||||
--exclude-vcs \
|
|
||||||
--exclude=".venv" --exclude="*/.venv" \
|
|
||||||
--exclude="venv" --exclude="*/venv" \
|
|
||||||
--exclude="node_modules" --exclude="*/node_modules" \
|
|
||||||
--exclude=".next" --exclude="*/.next" \
|
|
||||||
--exclude=".turbo" --exclude="*/.turbo" \
|
|
||||||
--exclude="__pycache__" --exclude="*/__pycache__" \
|
|
||||||
--exclude=".mypy_cache" --exclude="*/.mypy_cache" \
|
|
||||||
--exclude=".pytest_cache" --exclude="*/.pytest_cache" \
|
|
||||||
--exclude=".ruff_cache" --exclude="*/.ruff_cache" \
|
|
||||||
--exclude="wandb" --exclude="*/wandb" \
|
|
||||||
--exclude="paper/build" \
|
|
||||||
--exclude="tests/e2e/test-results" \
|
|
||||||
-czf "$ARCHIVE_PATH" \
|
|
||||||
-C "$LOCAL_REPO_DIR" .
|
|
||||||
fi
|
|
||||||
|
|
||||||
gcloud compute tpus tpu-vm scp "$ARCHIVE_PATH" "$TPU_NAME:/tmp/phantom-sync.tgz" \
|
|
||||||
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all
|
|
||||||
|
|
||||||
gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
|
|
||||||
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all \
|
|
||||||
--command="rm -rf '$REMOTE_REPO_DIR' && mkdir -p '$REMOTE_REPO_DIR' && tar -xzf /tmp/phantom-sync.tgz -C '$REMOTE_REPO_DIR' && rm -f /tmp/phantom-sync.tgz"
|
|
||||||
|
|
||||||
rm -f "$ARCHIVE_PATH"
|
|
||||||
@@ -1,211 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import gc
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
import shlex
|
|
||||||
import shutil
|
|
||||||
import subprocess
|
|
||||||
import time
|
|
||||||
import resource
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import wandb
|
|
||||||
|
|
||||||
|
|
||||||
CLI_MAP: dict[str, str] = {
|
|
||||||
"algo": "--algo",
|
|
||||||
"total_timesteps": "--total-timesteps",
|
|
||||||
"alpha": "--alpha",
|
|
||||||
"N": "--N",
|
|
||||||
"n_products": "--n-products",
|
|
||||||
"lambda_coi": "--lambda-coi",
|
|
||||||
"info_value": "--info-value",
|
|
||||||
"robust_radius": "--robust-radius",
|
|
||||||
"robust_points": "--robust-points",
|
|
||||||
"no_robust": "--no-robust",
|
|
||||||
"learning_rate": "--learning-rate",
|
|
||||||
"gamma": "--gamma",
|
|
||||||
"gae_lambda": "--gae-lambda",
|
|
||||||
"clip_range": "--clip-range",
|
|
||||||
"ent_coef": "--ent-coef",
|
|
||||||
"revenue_weight": "--revenue-weight",
|
|
||||||
"max_steps": "--max-steps",
|
|
||||||
"margin_floor": "--margin-floor",
|
|
||||||
"margin_floor_patience": "--margin-floor-patience",
|
|
||||||
"arch": "--arch",
|
|
||||||
"activation": "--activation",
|
|
||||||
"jax_num_envs": "--jax-num-envs",
|
|
||||||
"jax_num_steps": "--jax-num-steps",
|
|
||||||
"jax_num_minibatches": "--jax-num-minibatches",
|
|
||||||
"jax_update_epochs": "--jax-update-epochs",
|
|
||||||
"jax_anneal_lr": "--jax-anneal-lr",
|
|
||||||
"checkpoint_interval": "--checkpoint-interval",
|
|
||||||
"action_levels": "--action-levels",
|
|
||||||
"action_scale_low": "--action-scale-low",
|
|
||||||
"action_scale_high": "--action-scale-high",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _to_cli_args(cfg: dict) -> str:
|
|
||||||
parts: list[str] = ["--jax", "--no-wandb"]
|
|
||||||
for key, flag in CLI_MAP.items():
|
|
||||||
if key not in cfg:
|
|
||||||
continue
|
|
||||||
value = cfg[key]
|
|
||||||
if value is None:
|
|
||||||
continue
|
|
||||||
if isinstance(value, bool):
|
|
||||||
if key == "jax_anneal_lr":
|
|
||||||
parts.extend([flag, "true" if value else "false"])
|
|
||||||
elif value:
|
|
||||||
parts.append(flag)
|
|
||||||
continue
|
|
||||||
parts.extend([flag, str(value)])
|
|
||||||
return " ".join(shlex.quote(p) for p in parts)
|
|
||||||
|
|
||||||
|
|
||||||
_SENTINEL = "PHANTOM_METRICS:"
|
|
||||||
|
|
||||||
|
|
||||||
def _raise_nofile_limit(min_soft: int = 8192) -> None:
|
|
||||||
try:
|
|
||||||
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
|
|
||||||
target = min(hard, max(soft, min_soft))
|
|
||||||
if target > soft:
|
|
||||||
resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
|
|
||||||
except Exception:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_metrics(output: str) -> dict:
|
|
||||||
# fast path: look for the dedicated sentinel line emitted by run_local
|
|
||||||
for line in output.splitlines():
|
|
||||||
if line.startswith(_SENTINEL):
|
|
||||||
try:
|
|
||||||
return json.loads(line[len(_SENTINEL) :])
|
|
||||||
except Exception:
|
|
||||||
break
|
|
||||||
# fallback: scan for any JSON block containing eval/sweep keys;
|
|
||||||
# use greedy match to capture the largest possible block first
|
|
||||||
for block in re.findall(r"\{[^{}]*\}", output):
|
|
||||||
try:
|
|
||||||
obj = json.loads(block)
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
if isinstance(obj, dict) and (
|
|
||||||
"objective/score" in obj
|
|
||||||
or "eval/reward_mean" in obj
|
|
||||||
or "sweep/score" in obj
|
|
||||||
):
|
|
||||||
return obj
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
_raise_nofile_limit()
|
|
||||||
p = argparse.ArgumentParser(
|
|
||||||
description="Run W&B sweep where each trial uses full TPU pod"
|
|
||||||
)
|
|
||||||
p.add_argument("--sweep-id", required=True)
|
|
||||||
p.add_argument("--tpu-name", required=True)
|
|
||||||
p.add_argument("--tpu-zone", default="us-central2-b")
|
|
||||||
p.add_argument("--tpu-project", default="phantom-trc")
|
|
||||||
p.add_argument("--tpu-repo-dir", default="/tmp/PHANTOM")
|
|
||||||
p.add_argument("--count", type=int, default=0)
|
|
||||||
p.add_argument("--workdir", default=str(Path(__file__).resolve().parents[1]))
|
|
||||||
args = p.parse_args()
|
|
||||||
|
|
||||||
workdir = Path(args.workdir).resolve()
|
|
||||||
env = os.environ.copy()
|
|
||||||
wandb_root = workdir / ".wandb-agent"
|
|
||||||
wandb_root.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
prepare_cmd = [
|
|
||||||
"make",
|
|
||||||
"train.tpu.vm.prepare",
|
|
||||||
f"TPU_NAME={args.tpu_name}",
|
|
||||||
f"TPU_ZONE={args.tpu_zone}",
|
|
||||||
f"TPU_PROJECT={args.tpu_project}",
|
|
||||||
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
|
||||||
]
|
|
||||||
prepare = subprocess.run(
|
|
||||||
prepare_cmd,
|
|
||||||
cwd=workdir,
|
|
||||||
env=env,
|
|
||||||
text=True,
|
|
||||||
capture_output=False,
|
|
||||||
check=False,
|
|
||||||
)
|
|
||||||
if prepare.returncode != 0:
|
|
||||||
raise RuntimeError("Failed to prepare TPU workers for sweep")
|
|
||||||
|
|
||||||
def run_trial() -> None:
|
|
||||||
run = None
|
|
||||||
trial_wandb_dir = wandb_root / f"trial-{time.time_ns()}"
|
|
||||||
trial_wandb_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
try:
|
|
||||||
run = wandb.init(dir=str(trial_wandb_dir))
|
|
||||||
cfg = dict(wandb.config)
|
|
||||||
cli_args = _to_cli_args(cfg)
|
|
||||||
env_trial = dict(env)
|
|
||||||
env_trial["LOCAL_TRAIN_ARGS"] = cli_args
|
|
||||||
env_trial["WANDB_DIR"] = str(trial_wandb_dir)
|
|
||||||
env_trial["WANDB_CACHE_DIR"] = str(trial_wandb_dir / "cache")
|
|
||||||
env_trial["WANDB_DATA_DIR"] = str(trial_wandb_dir / "data")
|
|
||||||
|
|
||||||
cmd = [
|
|
||||||
"make",
|
|
||||||
"train.tpu.vm.run",
|
|
||||||
f"TPU_NAME={args.tpu_name}",
|
|
||||||
f"TPU_ZONE={args.tpu_zone}",
|
|
||||||
f"TPU_PROJECT={args.tpu_project}",
|
|
||||||
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
|
||||||
]
|
|
||||||
|
|
||||||
proc = subprocess.run(
|
|
||||||
cmd,
|
|
||||||
cwd=workdir,
|
|
||||||
env=env_trial,
|
|
||||||
text=True,
|
|
||||||
capture_output=True,
|
|
||||||
check=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
if proc.stdout:
|
|
||||||
print(proc.stdout)
|
|
||||||
if proc.stderr:
|
|
||||||
print(proc.stderr)
|
|
||||||
|
|
||||||
if proc.returncode != 0:
|
|
||||||
if run is not None:
|
|
||||||
run.summary["runner/exit_code"] = proc.returncode
|
|
||||||
raise RuntimeError(f"TPU trial failed with exit code {proc.returncode}")
|
|
||||||
|
|
||||||
metrics = _extract_metrics(proc.stdout)
|
|
||||||
if metrics:
|
|
||||||
wandb.log(metrics)
|
|
||||||
for k, v in metrics.items():
|
|
||||||
run.summary[k] = v
|
|
||||||
run.summary["runner/exit_code"] = 0
|
|
||||||
except Exception:
|
|
||||||
time.sleep(2)
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
if run is not None and wandb.run is not None:
|
|
||||||
wandb.finish()
|
|
||||||
shutil.rmtree(trial_wandb_dir, ignore_errors=True)
|
|
||||||
gc.collect()
|
|
||||||
|
|
||||||
wandb.agent(
|
|
||||||
args.sweep_id,
|
|
||||||
function=run_trial,
|
|
||||||
count=args.count if args.count > 0 else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,43 +0,0 @@
|
|||||||
#!/usr/bin/env sh
|
|
||||||
set -eu
|
|
||||||
|
|
||||||
REPO_DIR="${REPO_DIR:-$HOME/PHANTOM}"
|
|
||||||
PYTHON_BIN="${PYTHON_BIN:-python3}"
|
|
||||||
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --jax --total-timesteps 200000 --jax-num-envs 32 --jax-num-steps 128 --jax-num-minibatches 4 --jax-update-epochs 4}"
|
|
||||||
EXTRA_PIP="${EXTRA_PIP:-flax optax distrax}"
|
|
||||||
INSTALL_FULL_REQUIREMENTS="${INSTALL_FULL_REQUIREMENTS:-0}"
|
|
||||||
|
|
||||||
if [ ! -d "$REPO_DIR" ]; then
|
|
||||||
echo "repo directory not found: $REPO_DIR"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
cd "$REPO_DIR"
|
|
||||||
|
|
||||||
if [ -d "wandb" ]; then
|
|
||||||
rm -rf wandb
|
|
||||||
fi
|
|
||||||
|
|
||||||
# keep install idempotent and avoid re-installing jax/libtpu each run
|
|
||||||
if [ "$INSTALL_FULL_REQUIREMENTS" = "1" ] && [ -f "requirements.txt" ]; then
|
|
||||||
$PYTHON_BIN -m pip install -r requirements.txt
|
|
||||||
fi
|
|
||||||
if ! $PYTHON_BIN -c 'import flax, optax, distrax' >/dev/null 2>&1; then
|
|
||||||
if [ -f "engine/jax/requirements.txt" ]; then
|
|
||||||
$PYTHON_BIN -m pip install -r engine/jax/requirements.txt
|
|
||||||
fi
|
|
||||||
$PYTHON_BIN -m pip install -U $EXTRA_PIP
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -n "${WANDB_API_KEY:-}" ]; then
|
|
||||||
if ! $PYTHON_BIN -c 'import wandb; import inspect; assert hasattr(wandb, "init") and callable(wandb.init)' >/dev/null 2>&1; then
|
|
||||||
$PYTHON_BIN -m pip install -U wandb
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ -n "${WANDB_API_KEY:-}" ]; then
|
|
||||||
export WANDB_API_KEY
|
|
||||||
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS
|
|
||||||
fi
|
|
||||||
|
|
||||||
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS --no-wandb
|
|
||||||
Reference in New Issue
Block a user