mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
adding naive jax and libraries and make adjustments
This commit is contained in:
111
Makefile
111
Makefile
@@ -8,12 +8,30 @@ VENV := .venv
|
|||||||
PYTHON := $(VENV)/bin/python
|
PYTHON := $(VENV)/bin/python
|
||||||
PIP := $(VENV)/bin/pip
|
PIP := $(VENV)/bin/pip
|
||||||
PYTEST := $(VENV)/bin/pytest
|
PYTEST := $(VENV)/bin/pytest
|
||||||
|
# ---------------------------------------------------------------------------
# Cloud TPU settings. Every `?=` knob can be overridden from the environment
# or the command line, e.g. `make tpu.create TPU_ZONE=europe-west4-a`.
# ---------------------------------------------------------------------------

# Identity and placement of the TPU VM.
TPU_NAME ?= phantom-tpu
TPU_ZONE ?= us-central2-b
TPU_TYPE ?= v4-32
TPU_RUNTIME ?= tpu-vm-v4-base
TPU_PROJECT ?= phantom-trc
TPU_NETWORK ?= default
TPU_SUBNETWORK ?= default-us-central2

# Set to 1/true/yes to request spot capacity (see TPU_SPOT_FLAG below).
TPU_USE_SPOT ?= 0
# Free-form extra flags appended to the create command.
TPU_EXTRA_CREATE_FLAGS ?=

# Remote layout: working directory, local paths copied by tpu.deploy, venv name.
TPU_WORKDIR ?= ~/PHANTOM
TPU_SYNC_PATHS ?= engine lib requirements.txt Makefile .env
TPU_VENV ?= .venv-tpu

# Training invocation used by tpu.train.
TPU_TRAIN_ARGS ?= --algo ppo --jax --total-timesteps 20000
TPU_TRAIN_ENV ?= PHANTOM_USE_JAX=1 WANDB_MODE=offline

# Find-links page passed to pip when installing jax[tpu].
TPU_JAX_WHEEL_URL ?= https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Expands to --spot when TPU_USE_SPOT is one of the accepted truthy spellings.
TPU_SPOT_FLAG := $(if $(filter 1 true TRUE yes YES,$(TPU_USE_SPOT)),--spot,)

# Recursive (=) on purpose: the knobs above are re-read each time a recipe
# expands this command.
TPU_CREATE_CMD = gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm create "$(TPU_NAME)" \
  --zone="$(TPU_ZONE)" \
  --accelerator-type="$(TPU_TYPE)" \
  --version="$(TPU_RUNTIME)" \
  --network="$(TPU_NETWORK)" \
  --subnetwork="$(TPU_SUBNETWORK)" \
  $(TPU_SPOT_FLAG) $(TPU_EXTRA_CREATE_FLAGS)
|
||||||
|
|
||||||
.DEFAULT_GOAL := help
|
.DEFAULT_GOAL := help
|
||||||
|
|
||||||
# Print the list of available target families and the TPU presets.
.PHONY: help
help:
	@printf '%s\n' \
		"pdf.build pdf.watch pdf.clean | test.backend test.e2e test.all | web.dev | install | stats.lines | tpu.*" \
		"TPU presets: tpu.create.v4.ondemand | tpu.create.v4.spot"
|
||||||
|
|
||||||
$(BUILDDIR):
|
$(BUILDDIR):
|
||||||
mkdir -p paper/$(BUILDDIR)
|
mkdir -p paper/$(BUILDDIR)
|
||||||
@@ -70,6 +88,97 @@ $(VENV):
|
|||||||
install: $(VENV)
|
install: $(VENV)
|
||||||
$(PIP) install -r requirements.txt
|
$(PIP) install -r requirements.txt
|
||||||
|
|
||||||
|
# One-time local setup: verify the gcloud CLI is installed, authenticate
# (user credentials and application-default credentials) and select the
# project named by TPU_PROJECT.
.PHONY: tpu.setup
tpu.setup:
	@command -v gcloud >/dev/null 2>&1 || { echo "gcloud CLI not found. Install from https://cloud.google.com/sdk/docs/install"; exit 1; }
	@gcloud auth login --update-adc
	@gcloud auth application-default login
	@gcloud config set project "$(TPU_PROJECT)"
|
||||||
|
|
||||||
|
# Guard target: fail fast unless TPU_ZONE is one of the zones on the allow-list.
.PHONY: tpu.check.zone
tpu.check.zone:
	@allowed="europe-west4-a us-central2-b us-central1-a us-east1-d europe-west4-b"; \
	for zone in $$allowed; do \
		if [ "$$zone" = "$(TPU_ZONE)" ]; then exit 0; fi; \
	done; \
	echo "Unsupported TPU_ZONE='$(TPU_ZONE)'. Allowed zones: $$allowed"; \
	exit 1
|
||||||
|
|
||||||
|
# Preset: on-demand v4-32 slice in us-central2-b (delegates to tpu.create).
.PHONY: tpu.create.v4.ondemand
tpu.create.v4.ondemand:
	$(MAKE) tpu.create \
		TPU_ZONE=us-central2-b \
		TPU_TYPE=v4-32 \
		TPU_USE_SPOT=0 \
		TPU_SUBNETWORK=default-us-central2
|
||||||
|
|
||||||
|
# Preset: spot v4-32 slice in us-central2-b (delegates to tpu.create).
.PHONY: tpu.create.v4.spot
tpu.create.v4.spot:
	$(MAKE) tpu.create \
		TPU_ZONE=us-central2-b \
		TPU_TYPE=v4-32 \
		TPU_USE_SPOT=1 \
		TPU_SUBNETWORK=default-us-central2
|
||||||
|
|
||||||
|
# Create the TPU VM unless `describe` already finds one with this name/zone,
# in which case its current state is reported and nothing is created.
.PHONY: tpu.create
tpu.create: tpu.check.zone
	@if ! gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" >/dev/null 2>&1; then \
		$(TPU_CREATE_CMD); \
	else \
		STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)'); \
		echo "TPU VM $(TPU_NAME) already exists in $(TPU_ZONE) with state=$$STATE, skipping create"; \
	fi
|
||||||
|
|
||||||
|
# Make sure a usable TPU VM exists:
#   - not found                   -> create it
#   - READY                       -> nothing to do
#   - PREEMPTED/TERMINATED/FAILED -> delete (best effort) and recreate
#   - any other state             -> stop and let the operator decide
.PHONY: tpu.ensure
tpu.ensure: tpu.check.zone
	@set -e; \
	STATE=$$(gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm describe "$(TPU_NAME)" --zone="$(TPU_ZONE)" --format='value(state)' 2>/dev/null || true); \
	case "$$STATE" in \
		"") \
			echo "TPU VM $(TPU_NAME) not found in $(TPU_ZONE), creating"; \
			$(TPU_CREATE_CMD) ;; \
		READY) \
			echo "TPU VM $(TPU_NAME) is READY" ;; \
		PREEMPTED|TERMINATED|FAILED) \
			echo "TPU VM $(TPU_NAME) is in terminal state $$STATE, recreating"; \
			gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm delete "$(TPU_NAME)" --zone="$(TPU_ZONE)" --quiet || true; \
			$(TPU_CREATE_CMD) ;; \
		*) \
			echo "TPU VM $(TPU_NAME) is in state $$STATE; wait or recreate manually"; \
			exit 1 ;; \
	esac
|
||||||
|
|
||||||
|
# Show the full `describe` output for the TPU VM.
.PHONY: tpu.status
tpu.status:
	gcloud --project="$(TPU_PROJECT)" \
		compute tpus tpu-vm describe "$(TPU_NAME)" \
		--zone="$(TPU_ZONE)"
|
||||||
|
|
||||||
|
# Open an interactive SSH session on the TPU VM.
.PHONY: tpu.ssh
tpu.ssh:
	gcloud --project="$(TPU_PROJECT)" \
		compute tpus tpu-vm ssh "$(TPU_NAME)" \
		--zone="$(TPU_ZONE)"
|
||||||
|
|
||||||
|
# Create the remote working directory (after ensuring the VM exists).
.PHONY: tpu.prepare
tpu.prepare: tpu.ensure
	gcloud --project="$(TPU_PROJECT)" \
		compute tpus tpu-vm ssh "$(TPU_NAME)" \
		--zone="$(TPU_ZONE)" \
		--command "mkdir -p $(TPU_WORKDIR)"
|
||||||
|
|
||||||
|
# Copy every existing entry of TPU_SYNC_PATHS into TPU_WORKDIR on the TPU VM.
#
# Directories are copied INTO their parent directory on the remote side
# (".../$(dirname p)/") rather than onto ".../p": with scp semantics a
# recursive copy onto an already existing directory nests the source inside it
# (engine/engine) on every redeploy. The parent directory must already exist
# remotely; tpu.prepare creates TPU_WORKDIR, which covers top-level paths.
#
# `set -e` makes the first failed copy abort the recipe; without it only the
# status of the last loop iteration would be reported to make.
#
# NOTE(review): no --worker flag is passed, so only the default worker is
# targeted. Confirm whether multi-host slices need the files on every worker.
.PHONY: tpu.deploy
tpu.deploy: tpu.prepare
	@set -e; \
	for p in $(TPU_SYNC_PATHS); do \
		if [ ! -e "$$p" ]; then continue; fi; \
		if [ -d "$$p" ]; then \
			gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp --recurse "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$(dirname "$$p")/" --zone="$(TPU_ZONE)"; \
		else \
			gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm scp "$$p" "$(TPU_NAME):$(TPU_WORKDIR)/$$p" --zone="$(TPU_ZONE)"; \
		fi; \
	done
|
||||||
|
|
||||||
|
# Build the remote virtualenv inside TPU_WORKDIR and install dependencies:
# the newest available interpreter among python3.11 / python3.10 / python3
# creates TPU_VENV, then pip installs the base requirements, the JAX extras
# and finally jax[tpu] using TPU_JAX_WHEEL_URL as the find-links page.
# The remote command is single-quoted so `$$PYBIN` is expanded remotely.
.PHONY: tpu.install
tpu.install: tpu.ensure
	gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" \
		--zone="$(TPU_ZONE)" \
		--command 'cd $(TPU_WORKDIR) && PYBIN=$$(command -v python3.11 || command -v python3.10 || command -v python3) && $$PYBIN -m venv $(TPU_VENV) && $(TPU_VENV)/bin/pip install --upgrade pip setuptools wheel && $(TPU_VENV)/bin/pip install -r requirements.txt && $(TPU_VENV)/bin/pip install -r engine/jax/requirements.txt && $(TPU_VENV)/bin/pip install "jax[tpu]" -f $(TPU_JAX_WHEEL_URL)'
|
||||||
|
|
||||||
|
# Remote sanity check before training: exits 2 when engine/train.py is missing
# (code not deployed) and 3 when the venv interpreter is missing.
.PHONY: tpu.check.remote
tpu.check.remote: tpu.ensure
	gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" \
		--zone="$(TPU_ZONE)" \
		--command 'set -e; mkdir -p $(TPU_WORKDIR); cd $(TPU_WORKDIR); test -f engine/train.py || (echo "Missing code on TPU VM. Run: make tpu.deploy" && exit 2); test -x $(TPU_VENV)/bin/python || (echo "Missing TPU venv. Run: make tpu.install" && exit 3)'
|
||||||
|
|
||||||
|
# Launch training on the TPU VM: source .env (exporting its variables) when it
# exists, then run `python -m engine.train` from the remote venv with
# TPU_TRAIN_ENV prepended and TPU_TRAIN_ARGS appended.
.PHONY: tpu.train
tpu.train: tpu.check.remote
	gcloud --project="$(TPU_PROJECT)" compute tpus tpu-vm ssh "$(TPU_NAME)" \
		--zone="$(TPU_ZONE)" \
		--command 'cd $(TPU_WORKDIR) && if [ -f .env ]; then set -a && . ./.env && set +a; fi && $(TPU_TRAIN_ENV) $(TPU_VENV)/bin/python -m engine.train $(TPU_TRAIN_ARGS)'
|
||||||
|
|
||||||
|
# Full bring-up: ensure the VM, deploy the code, then install dependencies.
#
# tpu.install only declares tpu.ensure as a prerequisite, so listing
# `tpu.deploy tpu.install` side by side lets `make -j` start the install
# before requirements.txt has been copied. Running the install as a recipe
# step after the tpu.deploy prerequisite (which itself pulls in tpu.prepare
# and tpu.ensure) makes the required order explicit.
.PHONY: tpu.bootstrap
tpu.bootstrap: tpu.deploy
	$(MAKE) tpu.install
|
||||||
|
|
||||||
|
# Delete the TPU VM without an interactive confirmation prompt.
.PHONY: tpu.delete
tpu.delete:
	gcloud --project="$(TPU_PROJECT)" \
		compute tpus tpu-vm delete "$(TPU_NAME)" \
		--zone="$(TPU_ZONE)" \
		--quiet
|
||||||
|
|
||||||
.PHONY: stats.lines
|
.PHONY: stats.lines
|
||||||
stats.lines:
|
stats.lines:
|
||||||
@find . \( -path '*/node_modules' -o -path '*/.venv' -o -path '*/venv' \) -prune -o \
|
@find . \( -path '*/node_modules' -o -path '*/.venv' -o -path '*/venv' \) -prune -o \
|
||||||
|
|||||||
13
engine/jax/__init__.py
Normal file
13
engine/jax/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
"""JAX-compatible training and environment modules for PHANTOM."""

from __future__ import annotations

# Probe for JAX once at import time; downstream code checks JAX_AVAILABLE
# instead of repeating the import dance.
try:
    import jax  # noqa: F401
    import jax.numpy as jnp  # noqa: F401
except ImportError:
    JAX_AVAILABLE = False
else:
    JAX_AVAILABLE = True

__all__ = ["JAX_AVAILABLE"]
|
||||||
49
engine/jax/checkpoint.py
Normal file
49
engine/jax/checkpoint.py
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
"""Orbax checkpoint helpers for JAX training runs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
try:
|
||||||
|
import orbax.checkpoint as ocp
|
||||||
|
|
||||||
|
HAS_ORBAX = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_ORBAX = False
|
||||||
|
|
||||||
|
|
||||||
|
def _require_orbax() -> None:
    """Raise ``ImportError`` if the optional orbax dependency failed to import."""
    if HAS_ORBAX:
        return
    message = (
        "orbax-checkpoint is required for checkpoint support. "
        "Install engine/jax/requirements.txt first."
    )
    raise ImportError(message)
|
||||||
|
|
||||||
|
|
||||||
|
def create_manager(directory: str | Path, max_to_keep: int = 5):
    """Create an Orbax ``CheckpointManager`` rooted at ``directory``.

    Args:
        directory: Checkpoint root. ``~`` is expanded and relative paths are
            resolved against the current working directory, because Orbax
            rejects non-absolute checkpoint paths. The directory is created
            if it does not exist.
        max_to_keep: Number of most recent checkpoints to retain; values
            below 1 are raised to 1.

    Returns:
        The configured ``ocp.CheckpointManager``.

    Raises:
        ImportError: If orbax-checkpoint is not installed.
    """
    _require_orbax()
    root = Path(directory).expanduser().resolve()
    root.mkdir(parents=True, exist_ok=True)
    keep = max(1, int(max_to_keep))
    options = ocp.CheckpointManagerOptions(max_to_keep=keep, create=True)
    return ocp.CheckpointManager(root.as_posix(), ocp.PyTreeCheckpointer(), options)
|
||||||
|
|
||||||
|
|
||||||
|
def save(manager, *, step: int, payload: Any) -> bool:
    """Save ``payload`` under ``step`` and return the manager's result as a bool."""
    _require_orbax()
    step_id = int(step)
    outcome = manager.save(step_id, payload)
    return bool(outcome)
|
||||||
|
|
||||||
|
|
||||||
|
def latest_step(manager) -> int | None:
    """Return whatever ``manager.latest_step()`` reports (``None`` if nothing saved)."""
    _require_orbax()
    newest = manager.latest_step()
    return newest
|
||||||
|
|
||||||
|
|
||||||
|
def restore(manager, *, target: Any, step: int | None = None) -> Any:
    """Restore a checkpoint into the structure described by ``target``.

    When ``step`` is omitted the manager's latest step is used. If there is
    no step to restore, ``target`` is returned unchanged.
    """
    _require_orbax()
    if step is not None:
        chosen = int(step)
    else:
        chosen = manager.latest_step()
    if chosen is None:
        # Nothing has been saved yet: hand the caller's template back as-is.
        return target
    return manager.restore(chosen, items=target)
|
||||||
287
engine/jax/env.py
Normal file
287
engine/jax/env.py
Normal file
@@ -0,0 +1,287 @@
|
|||||||
|
"""JAX-native PHANTOM environment with robust contamination step."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
except ImportError as exc: # pragma: no cover
|
||||||
|
raise ImportError("engine.jax.env requires JAX") from exc
|
||||||
|
|
||||||
|
from .primitives import (
|
||||||
|
_sample_sessions_jax,
|
||||||
|
agent_probability_from_kl,
|
||||||
|
batch_kl,
|
||||||
|
compute_session_transitions,
|
||||||
|
load_transition_data,
|
||||||
|
purchase_flags,
|
||||||
|
reward_with_coi_penalty,
|
||||||
|
revenue_from_demand,
|
||||||
|
weighted_demand,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EnvParams(NamedTuple):
    """Immutable bundle of settings read by the environment functions.

    Field order is part of the interface (positional construction and tuple
    unpacking), so new fields must only ever be appended.
    """

    # Problem size and horizons.
    n_products: int
    n_sessions: int
    max_episode_steps: int
    max_session_steps: int

    # Bounds applied to prices (initial sampling and post-action clipping).
    price_low: float
    price_high: float

    # Arguments forwarded to the reward penalty computation.
    lambda_coi: float
    info_value: float

    # Radius used when the alpha candidate grid was built.
    robust_radius: float

    # Early-termination rule evaluated on the average margin in step_env.
    margin_floor: float
    margin_floor_patience: int

    # Multipliers selectable by the discrete action index.
    action_scales: jax.Array

    # Nominal alpha and the candidate grid evaluated at every step.
    alpha_nominal: float
    alpha_candidates: jax.Array

    # Transition kernels and per-state metadata.
    human_T: jax.Array
    agent_T: jax.Array
    terminal_mask: jax.Array
    purchase_mask: jax.Array
    event_weights: jax.Array
    start_idx: int
    term_idx: int
|
||||||
|
|
||||||
|
|
||||||
|
class EnvState(NamedTuple):
    """Per-episode state threaded through ``reset_env`` and ``step_env``."""

    prices: jax.Array  # current price vector
    demand: jax.Array  # demand vector produced by the latest step
    step_count: jax.Array  # number of steps taken so far
    low_margin_streak: jax.Array  # consecutive steps below the margin floor
    last_agent_prob: jax.Array  # agent probability from the latest step
    last_alpha_adv: jax.Array  # alpha candidate selected at the latest step
|
||||||
|
|
||||||
|
|
||||||
|
class CandidateEval(NamedTuple):
    """Outputs of evaluating one alpha candidate (see ``_evaluate_candidate``)."""

    reward: jax.Array
    revenue: jax.Array
    demand: jax.Array
    agent_prob: jax.Array  # mean of the per-session agent probabilities
    leakage: jax.Array
    discount: jax.Array
    n_purchases: jax.Array  # sum over the purchase flags
    n_agents: jax.Array  # sum over the sampled actor indicators
|
||||||
|
|
||||||
|
|
||||||
|
def make_env_params(
    *,
    n_products: int,
    alpha: float,
    n_sessions: int,
    lambda_coi: float,
    robust_radius: float,
    robust_points: int,
    info_value: float,
    action_levels: int,
    action_scale_low: float,
    action_scale_high: float,
    price_low: float,
    price_high: float,
    max_episode_steps: int,
    max_session_steps: int = 40,
    margin_floor: float = 0.05,
    margin_floor_patience: int = 5,
    prefer_behavior_data: bool = True,
) -> EnvParams:
    """Assemble an ``EnvParams`` from scalar settings and transition data.

    The alpha candidate grid holds only ``alpha`` when ``robust_radius`` is
    non-positive or ``robust_points`` is at most 1; otherwise it is an evenly
    spaced grid of ``robust_points`` values over ``alpha +/- robust_radius``
    clipped to [0, 1]. ``action_scales`` is an evenly spaced grid of
    ``action_levels`` multipliers between the two scale bounds.
    """
    transition = load_transition_data(prefer_data=prefer_behavior_data).to_jax()

    nominal_alpha = float(alpha)
    single_candidate = robust_radius <= 0.0 or robust_points <= 1
    if single_candidate:
        alpha_candidates = jnp.asarray([nominal_alpha], dtype=jnp.float32)
    else:
        radius = float(robust_radius)
        grid_lo = max(0.0, nominal_alpha - radius)
        grid_hi = min(1.0, nominal_alpha + radius)
        alpha_candidates = jnp.linspace(
            grid_lo, grid_hi, int(robust_points), dtype=jnp.float32
        )

    action_scales = jnp.linspace(
        float(action_scale_low),
        float(action_scale_high),
        int(action_levels),
        dtype=jnp.float32,
    )

    return EnvParams(
        # Sizes and horizons.
        n_products=int(n_products),
        n_sessions=int(n_sessions),
        max_episode_steps=int(max_episode_steps),
        max_session_steps=int(max_session_steps),
        # Pricing and reward settings.
        price_low=float(price_low),
        price_high=float(price_high),
        lambda_coi=float(lambda_coi),
        info_value=float(info_value),
        robust_radius=float(robust_radius),
        margin_floor=float(margin_floor),
        margin_floor_patience=int(margin_floor_patience),
        # Action and robustness grids.
        action_scales=action_scales,
        alpha_nominal=nominal_alpha,
        alpha_candidates=alpha_candidates,
        # Transition kernels and state metadata.
        human_T=jnp.asarray(transition.human_T),
        agent_T=jnp.asarray(transition.agent_T),
        terminal_mask=jnp.asarray(transition.terminal_mask),
        purchase_mask=jnp.asarray(transition.purchase_mask),
        event_weights=jnp.asarray(transition.event_weights),
        start_idx=int(transition.start_idx),
        term_idx=int(transition.term_idx),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _flatten_obs(demand: jax.Array, prices: jax.Array) -> jax.Array:
    """Concatenate demand then prices into one float32 observation vector."""
    parts = [component.astype(jnp.float32) for component in (demand, prices)]
    return jnp.concatenate(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _decode_action(
    prices: jax.Array, action: jax.Array, params: EnvParams
) -> jax.Array:
    """Turn a discrete action index into the next (clipped) price vector.

    The index is clamped into the valid range of ``params.action_scales``, the
    selected multiplier is applied to ``prices`` and the result is clipped to
    ``[params.price_low, params.price_high]``.
    """
    last_level = params.action_scales.shape[0] - 1
    level = jnp.clip(action.astype(jnp.int32), 0, last_level)
    scaled = prices * params.action_scales[level]
    return jnp.clip(scaled, params.price_low, params.price_high)
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluate_candidate(
    key: jax.Array,
    alpha_candidate: jax.Array,
    prices: jax.Array,
    params: EnvParams,
) -> CandidateEval:
    """Simulate one batch of sessions for ``alpha_candidate`` and score it.

    Sessions are sampled with the candidate alpha, the per-session KL
    quantities are turned into agent probabilities (averaged into a single
    value), and demand, revenue and the penalised reward are derived from the
    sampled sessions and ``prices``.
    """
    n_states = int(params.human_T.shape[0])

    states, products, actors, lengths = _sample_sessions_jax(
        key,
        params.human_T,
        params.agent_T,
        params.terminal_mask,
        params.start_idx,
        params.term_idx,
        alpha_candidate,
        params.n_products,
        params.n_sessions,
        params.max_session_steps,
        n_states,
    )

    # Separability signal: per-session transitions -> KL pair -> agent probability.
    session_trans = compute_session_transitions(states, lengths, n_states)
    delta_h, delta_a = batch_kl(session_trans, params.human_T, params.agent_T)
    mean_agent_prob = jnp.mean(agent_probability_from_kl(delta_h, delta_a))

    # Economic signal: demand -> revenue -> penalised reward.
    demand = weighted_demand(states, products, params.n_products, params.event_weights)
    revenue = revenue_from_demand(prices, demand)
    reward, leakage, discount = reward_with_coi_penalty(
        revenue,
        mean_agent_prob,
        params.lambda_coi,
        params.info_value,
    )

    purchase_total = jnp.sum(
        purchase_flags(states, params.purchase_mask).astype(jnp.float32)
    )
    agent_total = jnp.sum(actors.astype(jnp.float32))

    return CandidateEval(
        reward=reward,
        revenue=revenue,
        demand=demand,
        agent_prob=mean_agent_prob,
        leakage=leakage,
        discount=discount,
        n_purchases=purchase_total,
        n_agents=agent_total,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def reset_env(key: jax.Array, params: EnvParams) -> tuple[jax.Array, EnvState]:
    """Start a new episode and return ``(observation, state)``.

    Prices are drawn uniformly between ``params.price_low`` and
    ``params.price_high``; demand and both counters start at zero; the
    agent-probability and adversarial-alpha trackers start at
    ``params.alpha_nominal``.
    """
    n = params.n_products
    initial_prices = jax.random.uniform(
        key, shape=(n,), minval=params.price_low, maxval=params.price_high
    )
    zero_demand = jnp.zeros((n,), dtype=jnp.float32)
    zero_counter = jnp.asarray(0, dtype=jnp.int32)
    nominal = jnp.asarray(params.alpha_nominal, dtype=jnp.float32)

    state = EnvState(
        prices=initial_prices,
        demand=zero_demand,
        step_count=zero_counter,
        low_margin_streak=zero_counter,
        last_agent_prob=nominal,
        last_alpha_adv=nominal,
    )
    return _flatten_obs(zero_demand, initial_prices), state
|
||||||
|
|
||||||
|
|
||||||
|
def step_env(
    key: jax.Array,
    state: EnvState,
    action: jax.Array,
    params: EnvParams,
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
    """Advance the environment by one step.

    Every alpha candidate is evaluated (vectorised with ``vmap``) on the
    prices produced by ``action``; the candidate with the lowest reward is the
    one whose outcome is returned. The episode ends when the step budget is
    exhausted or the average margin has stayed below ``params.margin_floor``
    for ``params.margin_floor_patience`` consecutive steps.

    Returns:
        ``(obs, next_state, reward, done, info)``.
    """
    prices = _decode_action(state.prices, action, params)

    candidates = params.alpha_candidates
    per_candidate_keys = jax.random.split(key, candidates.shape[0])

    def _score(candidate_key, candidate_alpha):
        return _evaluate_candidate(candidate_key, candidate_alpha, prices, params)

    evals = jax.vmap(_score, in_axes=(0, 0))(per_candidate_keys, candidates)

    # Select the candidate with the minimum reward and slice every field of
    # the stacked evaluation at that index.
    idx = jnp.argmin(evals.reward)
    worst = jax.tree_util.tree_map(lambda leaf: leaf[idx], evals)
    alpha_adv = candidates[idx]

    # Margin tracking uses the mean price relative to the price floor.
    avg_price = jnp.maximum(jnp.mean(prices), 1e-6)
    avg_margin = (avg_price - params.price_low) / avg_price
    below_floor = avg_margin < params.margin_floor
    next_streak = jnp.where(below_floor, state.low_margin_streak + 1, 0)

    step_count = state.step_count + 1
    out_of_steps = step_count >= params.max_episode_steps
    margin_collapsed = next_streak >= params.margin_floor_patience
    done = out_of_steps | margin_collapsed

    next_state = EnvState(
        prices=prices,
        demand=worst.demand,
        step_count=step_count,
        low_margin_streak=next_streak,
        last_agent_prob=worst.agent_prob,
        last_alpha_adv=alpha_adv,
    )
    info = {
        "revenue": worst.revenue,
        "agent_prob": worst.agent_prob,
        "alpha_adv": alpha_adv,
        "coi_leakage": worst.leakage,
        "coi_discount": worst.discount,
        "n_purchases": worst.n_purchases,
        "n_agents": worst.n_agents,
        "avg_margin": avg_margin,
    }
    obs = _flatten_obs(worst.demand, prices)
    return obs, next_state, worst.reward, done, info
|
||||||
|
|
||||||
|
|
||||||
|
class PHANTOMJAXEnv:
    """Thin object wrapper around the functional ``reset_env`` / ``step_env`` API.

    Every method accepts an optional ``params`` override; when it is omitted
    the parameters given at construction time are used.
    """

    def __init__(self, params: EnvParams):
        self.params = params

    def _resolve(self, params: EnvParams | None) -> EnvParams:
        """Return the override if one was given, else the stored parameters."""
        if params is None:
            return self.params
        return params

    def reset(self, key: jax.Array, params: EnvParams | None = None):
        """Delegate to ``reset_env``."""
        return reset_env(key, self._resolve(params))

    def step(
        self,
        key: jax.Array,
        state: EnvState,
        action: jax.Array,
        params: EnvParams | None = None,
    ):
        """Delegate to ``step_env``."""
        return step_env(key, state, action, self._resolve(params))

    def action_space_n(self, params: EnvParams | None = None) -> int:
        """Number of discrete actions, i.e. the length of ``action_scales``."""
        return int(self._resolve(params).action_scales.shape[0])

    def observation_dim(self, params: EnvParams | None = None) -> int:
        """Observation length: twice the number of products."""
        return int(self._resolve(params).n_products * 2)
|
||||||
493
engine/jax/primitives.py
Normal file
493
engine/jax/primitives.py
Normal file
@@ -0,0 +1,493 @@
|
|||||||
|
"""JAX-compatible primitives for PHANTOM session simulation and separability."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
from typing import Mapping, Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import jax
|
||||||
|
import jax.numpy as jnp
|
||||||
|
|
||||||
|
JAX_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
jax = None # type: ignore[assignment]
|
||||||
|
jnp = np # type: ignore[assignment]
|
||||||
|
JAX_AVAILABLE = False
|
||||||
|
|
||||||
|
|
||||||
|
# Event names accepted as the initial state, in order of preference.
STATE_START_KEYS = ("session_start", "start")

# Substrings that mark an event name as terminal (matched with ``in``).
TERMINAL_EVENT_TOKENS = (
    "session_end",
    "end",
    "purchase_complete",
    "checkout_start",
    "checkout",
)

# Substrings that mark an event name as a purchase (matched with ``in``).
PURCHASE_EVENT_TOKENS = (
    "purchase_complete",
    "purchase",
    "checkout_start",
    "checkout",
)

# Weight assigned to each event category.
CATEGORY_WEIGHTS = {"cart": 4.0, "dwell": 2.0, "nav": 1.0, "filter": 0.5}

# Known event names grouped by category.
ACTION_CATEGORIES = {
    "cart": {"add_item", "add_to_cart", "remove", "checkout", "purchase"},
    "dwell": {
        "hover_title",
        "hover_paragraph",
        "hover_link",
        "hover_over_title",
        "hover_over_paragraph",
        "hover_over_link",
        "hover_over_button",
    },
    "nav": {
        "page_view",
        "view_item",
        "view",
        "learn_more",
        "learn_more_about_item",
        "view_item_page",
        "session_start",
    },
    "filter": {
        "search",
        "filter_date",
        "filter_price",
        "sort",
        "filter_for_date",
        "filter_for_price",
        "filter_for_amenities",
        "sort_change",
    },
}

# Flat lookup: event name -> weight of the category that lists it.
DEFAULT_ACTION_WEIGHTS = {
    event_name: CATEGORY_WEIGHTS[category]
    for category, members in ACTION_CATEGORIES.items()
    for event_name in members
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class TransitionData:
    """Dense transition kernels and per-state metadata."""

    human_T: np.ndarray
    agent_T: np.ndarray
    terminal_mask: np.ndarray
    purchase_mask: np.ndarray
    event_weights: np.ndarray
    event_names: tuple[str, ...]
    start_idx: int
    term_idx: int

    def to_jax(self) -> "TransitionData":
        """Return a copy whose array fields went through ``jnp.asarray``.

        When JAX is unavailable the instance itself is returned unchanged.
        """
        if not JAX_AVAILABLE:
            return self
        array_fields = (
            "human_T",
            "agent_T",
            "terminal_mask",
            "purchase_mask",
            "event_weights",
        )
        converted = {name: jnp.asarray(getattr(self, name)) for name in array_fields}
        return TransitionData(
            event_names=self.event_names,
            start_idx=int(self.start_idx),
            term_idx=int(self.term_idx),
            **converted,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class SessionBatch:
    """Immutable container for one batch of sampled sessions."""

    states: np.ndarray  # per-session sequence of emitted state indices
    products: np.ndarray  # one product index per session
    actors: np.ndarray  # one actor indicator per session
    lengths: np.ndarray  # one length per session
|
||||||
|
|
||||||
|
|
||||||
|
def _event_weight(name: str) -> float:
    """Return the weight of event ``name``.

    Exact matches in ``DEFAULT_ACTION_WEIGHTS`` win. Otherwise the category is
    guessed from the name, checked in this order: hover* (dwell), filter/search
    /sort (filter), add*/checkout/purchase/remove_item (cart). Names that
    contain a terminal token weigh 0.0 and everything else counts as nav.
    """
    exact = DEFAULT_ACTION_WEIGHTS.get(name)
    if exact is not None:
        return float(exact)

    filter_names = {"search", "sort", "sort_change"}
    cart_names = {
        "checkout",
        "checkout_start",
        "purchase",
        "remove_item",
        "purchase_complete",
    }

    if name.startswith("hover"):
        category = "dwell"
    elif name.startswith("filter") or name in filter_names:
        category = "filter"
    elif name.startswith("add") or name in cart_names:
        category = "cart"
    elif any(token in name for token in TERMINAL_EVENT_TOKENS):
        return 0.0
    else:
        category = "nav"
    return float(CATEGORY_WEIGHTS[category])
|
||||||
|
|
||||||
|
|
||||||
|
def _is_terminal(name: str) -> bool:
    """True when ``name`` contains any terminal token as a substring."""
    for token in TERMINAL_EVENT_TOKENS:
        if token in name:
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_purchase(name: str) -> bool:
    """True when ``name`` contains any purchase token as a substring."""
    for token in PURCHASE_EVENT_TOKENS:
        if token in name:
            return True
    return False
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_events(*transitions: Mapping[str, Mapping[str, float]]) -> tuple[str, ...]:
    """Return the sorted, de-duplicated event names used by ``transitions``.

    Both source and destination names are collected; the reserved
    ``"__terminal__"`` name is excluded.
    """
    seen: set[str] = set()
    for table in transitions:
        for source, destinations in table.items():
            seen.add(source)
            seen.update(destinations.keys())
    seen.discard("__terminal__")
    return tuple(sorted(seen))
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_rows(matrix: np.ndarray, term_idx: int) -> np.ndarray:
    """Scale each row of ``matrix`` to sum to one.

    Rows whose sum is (close to) zero are rewritten in place to put all mass
    on column ``term_idx`` before normalising. The normalised matrix is
    returned as a new array.
    """
    totals = matrix.sum(axis=1, keepdims=True)
    empty_rows = np.isclose(totals[:, 0], 0.0)
    if empty_rows.any():
        matrix[empty_rows] = 0.0
        matrix[empty_rows, term_idx] = 1.0
        totals = matrix.sum(axis=1, keepdims=True)
    # The floor guards the division; rewritten rows already sum to one.
    return matrix / np.maximum(totals, 1e-8)
|
||||||
|
|
||||||
|
|
||||||
|
def _dense_from_dict(
    transitions: Mapping[str, Mapping[str, float]],
    event_to_idx: Mapping[str, int],
    term_idx: int,
) -> np.ndarray:
    """Build a row-normalised dense float32 matrix from nested dicts.

    Source or destination names missing from ``event_to_idx`` are skipped;
    repeated (source, destination) pairs accumulate.
    """
    size = len(event_to_idx)
    dense = np.zeros((size, size), dtype=np.float32)
    for source, destinations in transitions.items():
        row = event_to_idx.get(source)
        if row is None:
            continue
        for destination, prob in destinations.items():
            col = event_to_idx.get(destination)
            if col is not None:
                dense[row, col] += float(prob)
    return _normalize_rows(dense, term_idx)
|
||||||
|
|
||||||
|
|
||||||
|
def compile_transition_data(
    human_transitions: Mapping[str, Mapping[str, float]],
    agent_transitions: Mapping[str, Mapping[str, float]],
) -> TransitionData:
    """Compile nested transition dicts into dense kernels plus metadata.

    A synthetic ``"__terminal__"`` state is appended as the last index. Every
    terminal state (by name, plus the synthetic one) becomes absorbing in both
    kernels. If neither mapping names any event, the built-in fallback model
    is returned instead.

    Args:
        human_transitions: ``{source: {destination: probability}}`` for humans.
        agent_transitions: Same structure for agents.

    Returns:
        A ``TransitionData`` with NumPy arrays; the start index is the first
        of ``STATE_START_KEYS`` that is present, else 0.
    """
    event_names = _collect_events(human_transitions, agent_transitions)
    if not event_names:
        return fallback_transition_data()

    event_names = tuple([*event_names, "__terminal__"])
    term_idx = len(event_names) - 1
    event_to_idx = {name: i for i, name in enumerate(event_names)}

    human_T = _dense_from_dict(human_transitions, event_to_idx, term_idx)
    agent_T = _dense_from_dict(agent_transitions, event_to_idx, term_idx)

    terminal_mask = np.array([_is_terminal(name) for name in event_names], dtype=bool)
    purchase_mask = np.array([_is_purchase(name) for name in event_names], dtype=bool)
    event_weights = np.array(
        [_event_weight(name) for name in event_names], dtype=np.float32
    )

    # The sentinel's name matches no terminal token, so flag it explicitly.
    terminal_mask[term_idx] = True
    # It is bookkeeping, not an observed event: by name it would fall through
    # to the default "nav" weight, and dead rows route sessions into it, so
    # its weight is forced to zero.
    event_weights[term_idx] = 0.0

    # Make every terminal state absorbing in both kernels.
    for idx in np.flatnonzero(terminal_mask):
        human_T[idx] = 0.0
        agent_T[idx] = 0.0
        human_T[idx, idx] = 1.0
        agent_T[idx, idx] = 1.0

    start_idx = 0
    for key in STATE_START_KEYS:
        if key in event_to_idx:
            start_idx = int(event_to_idx[key])
            break

    return TransitionData(
        human_T=human_T,
        agent_T=agent_T,
        terminal_mask=terminal_mask,
        purchase_mask=purchase_mask,
        event_weights=event_weights,
        event_names=event_names,
        start_idx=start_idx,
        term_idx=term_idx,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def fallback_transition_data() -> TransitionData:
    """Compile the hard-coded human and agent transition tables.

    Used whenever no behaviour-derived transition model is available.
    """
    human_table = {
        "session_start": {"page_view": 0.80, "view_item_page": 0.15, "session_end": 0.05},
        "page_view": {"view_item_page": 0.55, "search": 0.25, "session_end": 0.20},
        "view_item_page": {"learn_more_about_item": 0.40, "add_item_to_cart": 0.28, "session_end": 0.32},
        "learn_more_about_item": {"add_item_to_cart": 0.50, "view_item_page": 0.30, "session_end": 0.20},
        "add_item_to_cart": {"checkout_start": 0.58, "view_item_page": 0.24, "session_end": 0.18},
        "checkout_start": {"purchase_complete": 0.70, "session_end": 0.30},
        "purchase_complete": {"session_end": 1.0},
    }
    agent_table = {
        "session_start": {"page_view": 0.90, "view_item_page": 0.08, "session_end": 0.02},
        "page_view": {"view_item_page": 0.40, "search": 0.35, "session_end": 0.25},
        "view_item_page": {"learn_more_about_item": 0.55, "add_item_to_cart": 0.15, "session_end": 0.30},
        "learn_more_about_item": {"view_item_page": 0.45, "add_item_to_cart": 0.20, "session_end": 0.35},
        "add_item_to_cart": {"checkout_start": 0.42, "view_item_page": 0.28, "session_end": 0.30},
        "checkout_start": {"purchase_complete": 0.52, "session_end": 0.48},
        "purchase_complete": {"session_end": 1.0},
    }
    return compile_transition_data(human_table, agent_table)
|
||||||
|
|
||||||
|
|
||||||
|
def load_transition_data(prefer_data: bool = True) -> TransitionData:
    """Load transition kernels, preferring behaviour-derived data.

    Args:
        prefer_data: When true, try to build the kernels from
            ``get_transition_models()``; when false, use the built-in
            fallback tables directly.

    Returns:
        The compiled ``TransitionData``. Any failure while importing or
        compiling the behaviour models falls back to the built-in tables; the
        failure is logged so the substitution is not silent.
    """
    if not prefer_data:
        return fallback_transition_data()
    try:
        from ..lib.behavior import get_transition_models

        human_trans, agent_trans = get_transition_models()
        return compile_transition_data(human_trans, agent_trans)
    except Exception as exc:
        # Still best effort, but record why the real data was not used.
        import logging

        logging.getLogger(__name__).warning(
            "Behaviour transition models unavailable (%s: %s); "
            "using built-in fallback transitions.",
            type(exc).__name__,
            exc,
        )
        return fallback_transition_data()
|
||||||
|
|
||||||
|
|
||||||
|
if JAX_AVAILABLE:

    @partial(jax.jit, static_argnums=(8, 9, 10))
    def _sample_sessions_jax(
        key: jax.Array,
        human_T: jax.Array,
        agent_T: jax.Array,
        terminal_mask: jax.Array,
        start_idx: int,
        term_idx: int,
        alpha: float,
        n_products: int,
        n_sessions: int,
        max_steps: int,
        n_states: int,
    ) -> tuple[jax.Array, jax.Array, jax.Array, jax.Array]:
        """Sample a batch of sessions from the human/agent transition kernels.

        Only ``n_sessions``, ``max_steps`` and ``n_states`` are static under
        ``jit``; every other argument, including ``start_idx`` and
        ``term_idx``, is traced. Traced values must not go through Python
        ``int()`` (that raises at trace time), so both indices are used as-is.

        Args:
            key: PRNG key.
            human_T: Transition matrix used for sessions with actor == 0.
            agent_T: Transition matrix used for sessions with actor == 1.
            terminal_mask: Boolean mask over states; reaching a masked state
                ends the session.
            start_idx: State every session starts from.
            term_idx: State substituted once a session is no longer active.
            alpha: A session is an agent when its uniform draw is below this.
            n_products: Exclusive upper bound for the per-session product index.
            n_sessions: Number of sessions (static).
            max_steps: Number of scan steps (static).
            n_states: Number of states, used to clamp mask lookups (static).

        Returns:
            ``(states, products, actors, lengths)`` where ``states`` has shape
            ``(n_sessions, max_steps)`` with ``-1`` in slots after a session
            ended, and ``lengths`` counts the non-negative entries per row.
        """
        k_actor, k_product, k_step = jax.random.split(key, 3)
        actor_draw = jax.random.uniform(k_actor, (n_sessions,))
        actors = (actor_draw < alpha).astype(jnp.int32)
        products = jax.random.randint(
            k_product, (n_sessions,), 0, n_products, dtype=jnp.int32
        )

        active_init = jnp.ones((n_sessions,), dtype=jnp.bool_)
        # No int() here: start_idx is a tracer under jit.
        state_init = jnp.full((n_sessions,), start_idx, dtype=jnp.int32)

        def _scan_step(carry, _):
            states, active, rng = carry
            rng, k = jax.random.split(rng)
            probs_h = human_T[states]
            probs_a = agent_T[states]
            probs = jnp.where(actors[:, None] == 0, probs_h, probs_a)
            # The epsilon keeps log() finite for zero-probability entries.
            next_state = jax.random.categorical(k, jnp.log(probs + 1e-10), axis=-1)
            next_state = jnp.where(active, next_state, term_idx)
            # Finished sessions emit -1 so lengths can be recovered afterwards.
            emitted = jnp.where(active, next_state, -1)
            is_terminal = terminal_mask[jnp.clip(next_state, 0, n_states - 1)]
            next_active = active & (~is_terminal)
            carry_states = jnp.where(next_active, next_state, term_idx)
            return (carry_states, next_active, rng), emitted

        _, state_t = jax.lax.scan(
            _scan_step, (state_init, active_init, k_step), None, length=max_steps
        )
        # scan stacks along time; transpose to (n_sessions, max_steps).
        states = state_t.T
        lengths = jnp.sum(states >= 0, axis=1, dtype=jnp.int32)
        return states, products, actors, lengths
|
||||||
|
|
||||||
|
|
||||||
|
def sample_sessions(
    key,
    transition_data: TransitionData,
    alpha: float,
    n_products: int,
    n_sessions: int,
    max_steps: int,
) -> SessionBatch:
    """Sample ``n_sessions`` sessions, on JAX when available, else on NumPy.

    The JAX path converts ``transition_data`` with ``to_jax()`` and delegates
    to ``_sample_sessions_jax``. The NumPy path seeds a generator from the
    first element of ``key`` and walks each session state by state, stopping
    at the first terminal state; sessions that never reach one get length
    ``max_steps``. Unused trailing steps hold ``-1``.

    Returns:
        A ``SessionBatch`` with ``states``, ``products``, ``actors`` and
        ``lengths``.
    """
    if JAX_AVAILABLE:
        jax_td = transition_data.to_jax()
        sampled = _sample_sessions_jax(
            key,
            jax_td.human_T,
            jax_td.agent_T,
            jax_td.terminal_mask,
            int(jax_td.start_idx),
            int(jax_td.term_idx),
            float(alpha),
            int(n_products),
            int(n_sessions),
            int(max_steps),
            int(jax_td.human_T.shape[0]),
        )
        states, products, actors, lengths = sampled
        return SessionBatch(
            states=states, products=products, actors=actors, lengths=lengths
        )

    seed = int(np.asarray(key).reshape(-1)[0])
    rng = np.random.default_rng(seed)
    n_states = transition_data.human_T.shape[0]
    # Draw order (products, actors, then per-step choices) fixes the stream.
    products = rng.integers(0, n_products, size=n_sessions, dtype=np.int32)
    actors = (rng.random(size=n_sessions) < alpha).astype(np.int32)
    states = np.full((n_sessions, max_steps), -1, dtype=np.int32)
    lengths = np.zeros((n_sessions,), dtype=np.int32)

    for row in range(n_sessions):
        current = int(transition_data.start_idx)
        if actors[row] == 1:
            matrix = transition_data.agent_T
        else:
            matrix = transition_data.human_T
        for step in range(max_steps):
            following = int(rng.choice(n_states, p=matrix[current]))
            states[row, step] = following
            if transition_data.terminal_mask[following]:
                lengths[row] = step + 1
                break
            current = following
        else:
            # No terminal state was reached within the step budget.
            lengths[row] = max_steps

    return SessionBatch(
        states=states, products=products, actors=actors, lengths=lengths
    )
|
||||||
|
|
||||||
|
|
||||||
|
if JAX_AVAILABLE:

    @partial(jax.jit, static_argnums=(2,))
    def compute_session_transitions(states, lengths, n_states: int):
        """Per-session empirical transition matrices (JAX version).

        Consecutive state pairs are counted when both entries are
        non-negative and the pair lies within the first ``lengths - 1``
        steps. Each row of the count matrix is divided by its sum plus
        ``1e-10``, so rows with no observations stay all-zero.

        Returns:
            Array of shape ``(n_sessions, n_states, n_states)``.
        """
        origin, target = states[:, :-1], states[:, 1:]
        step_index = jnp.arange(origin.shape[1])[None, :]
        in_session = step_index < (lengths[:, None] - 1)
        valid = (origin >= 0) & (target >= 0) & in_session

        def _encode(idx):
            # Clip so the -1 padding is a legal index; `valid` masks it out.
            return jax.nn.one_hot(jnp.clip(idx, 0, n_states - 1), n_states)

        origin_oh = _encode(origin)
        target_oh = _encode(target)
        counts = jnp.einsum(
            "nti,ntj,nt->nij", origin_oh, target_oh, valid.astype(jnp.float32)
        )
        totals = jnp.sum(counts, axis=-1, keepdims=True)
        return counts / (totals + 1e-10)

else:

    def compute_session_transitions(states, lengths, n_states: int):
        """Per-session empirical transition matrices (NumPy fallback).

        Counts consecutive non-negative state pairs within the first
        ``lengths[i] - 1`` steps of each session and normalizes every row by
        its sum plus ``1e-10``.

        Returns:
            ``float32`` array of shape ``(n_sessions, n_states, n_states)``.
        """
        n_sessions = states.shape[0]
        counts = np.zeros((n_sessions, n_states, n_states), dtype=np.float32)
        for row in range(n_sessions):
            n_pairs = max(int(lengths[row]) - 1, 0)
            for step in range(n_pairs):
                origin = int(states[row, step])
                target = int(states[row, step + 1])
                if origin < 0 or target < 0:
                    continue
                counts[row, origin, target] += 1.0
        totals = counts.sum(axis=-1, keepdims=True)
        return counts / (totals + 1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def batch_kl(P, Q_human, Q_agent, eps: float = 1e-10):
    """KL divergence of each session matrix in ``P`` to two references.

    ``P`` is smoothed with ``eps`` and renormalized along its last axis; the
    references are smoothed with ``eps`` (not renormalized) and broadcast
    over the leading axis of ``P``. Divergences are summed over axes 1 and 2.

    Returns:
        ``(delta_h, delta_a)``: divergence to ``Q_human`` and to ``Q_agent``.
    """
    smoothed = P + eps
    session_p = smoothed / jnp.sum(smoothed, axis=-1, keepdims=True)

    def _divergence_to(reference):
        ref = reference[None, ...] + eps
        return jnp.sum(session_p * jnp.log(session_p / ref), axis=(1, 2))

    return _divergence_to(Q_human), _divergence_to(Q_agent)
|
||||||
|
|
||||||
|
|
||||||
|
if JAX_AVAILABLE:
    # Rebind the module-level name to the jit-compiled version when JAX is installed.
    batch_kl = jax.jit(batch_kl)
|
||||||
|
|
||||||
|
|
||||||
|
def agent_probability_from_kl(delta_h, delta_a, temperature: float = 1.0):
    """Softmax probability that a session is agent-driven, from two KLs.

    Computes ``exp(-delta_a / t) / (exp(-delta_h / t) + exp(-delta_a / t))``
    with ``t = max(temperature, 1e-6)``; a lower divergence to the agent
    model gives a higher probability.

    Both divergences are shifted by the smaller one before exponentiating.
    Without the shift, any pair of divergences above roughly 23 makes both
    exponentials smaller than the ``1e-10`` guard in the denominator, and the
    result collapses towards 0 regardless of which model fits better.

    Args:
        delta_h: KL divergence(s) to the human model.
        delta_a: KL divergence(s) to the agent model.
        temperature: Softmax temperature, floored at ``1e-6``.

    Returns:
        Agent probability with the broadcast shape of the inputs.
    """
    t = jnp.maximum(float(temperature), 1e-6)
    shift = jnp.minimum(delta_h, delta_a)
    # Non-finite shifts (e.g. both divergences infinite) would give inf - inf;
    # fall back to no shift so those inputs behave as they did before.
    shift = jnp.where(jnp.isfinite(shift), shift, 0.0)
    exp_h = jnp.exp(-(delta_h - shift) / t)
    exp_a = jnp.exp(-(delta_a - shift) / t)
    # After shifting the denominator is >= 1, so the guard is negligible; it
    # is kept so results match the unshifted formula when min(delta) == 0.
    return exp_a / (exp_h + exp_a + 1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_alpha_from_kl(delta_h, delta_a, beta: float = 2.0):
    """Logistic estimate from the KL gap ``delta_h - delta_a``.

    Returns ``sigmoid(beta * (delta_h - delta_a))``: above 0.5 when the
    divergence to the human model exceeds the divergence to the agent model.
    """
    scaled_gap = beta * (delta_h - delta_a)
    denominator = 1.0 + jnp.exp(-scaled_gap)
    return 1.0 / denominator
|
||||||
|
|
||||||
|
|
||||||
|
def weighted_demand(states, products, n_products: int, event_weights):
    """Aggregate event-weighted session activity into per-product demand.

    Each non-negative state contributes ``event_weights[state]``; the
    per-session totals are scatter-added onto the session's product index.
    When the overall total is positive the vector is rescaled to sum to 100,
    otherwise it is returned unscaled.

    Returns:
        ``float32`` vector of length ``n_products``.
    """
    observed = states >= 0
    last_weight = event_weights.shape[0] - 1
    # Clip so the -1 padding indexes legally; `observed` zeroes it out.
    step_weights = event_weights[jnp.clip(states, 0, last_weight)] * observed
    session_totals = jnp.sum(step_weights, axis=1)
    raw = jnp.zeros((n_products,), dtype=jnp.float32).at[products].add(session_totals)
    grand_total = jnp.sum(raw)
    return jnp.where(grand_total > 0.0, (raw / grand_total) * 100.0, raw)
|
||||||
|
|
||||||
|
|
||||||
|
if JAX_AVAILABLE:
    # n_products (argument 2) sets the output shape, so it is marked static.
    weighted_demand = jax.jit(weighted_demand, static_argnums=(2,))
|
||||||
|
|
||||||
|
|
||||||
|
def purchase_flags(states, purchase_mask):
    """Flag sessions that contain at least one purchase state.

    Args:
        states: Session state indices; negative entries are padding.
        purchase_mask: Boolean vector indexed by state.

    Returns:
        Boolean vector with one entry per session (reduction over axis 1).
    """
    real_step = states >= 0
    # Clip keeps the padding value a legal index; it is masked by real_step.
    safe_index = jnp.clip(states, 0, purchase_mask.shape[0] - 1)
    is_purchase = purchase_mask[safe_index] & real_step
    return jnp.any(is_purchase, axis=1)
|
||||||
|
|
||||||
|
|
||||||
|
if JAX_AVAILABLE:
    # Rebind the module-level name to the jit-compiled version when JAX is installed.
    purchase_flags = jax.jit(purchase_flags)
|
||||||
|
|
||||||
|
|
||||||
|
def revenue_from_demand(prices, demand):
    """Return the dot product of ``prices`` and ``demand``."""
    return jnp.dot(prices, demand)


if JAX_AVAILABLE:
    # Rebind the module-level name to the jit-compiled version when JAX is installed.
    revenue_from_demand = jax.jit(revenue_from_demand)
|
||||||
|
|
||||||
|
|
||||||
|
def reward_with_coi_penalty(
    revenue, agent_prob: float, lambda_coi: float, info_value: float
):
    """Discount revenue by an information-leakage penalty.

    ``leakage = agent_prob * info_value`` and the multiplicative discount is
    ``1 - lambda_coi * leakage`` clipped to ``[0, 1]``.

    Returns:
        ``(revenue * discount, leakage, discount)``.
    """
    leakage = agent_prob * info_value
    retained_share = 1.0 - lambda_coi * leakage
    discount = jnp.clip(retained_share, 0.0, 1.0)
    penalized_revenue = revenue * discount
    return penalized_revenue, leakage, discount
|
||||||
|
|
||||||
|
|
||||||
|
if JAX_AVAILABLE:
    # Rebind the module-level name to the jit-compiled version when JAX is installed.
    reward_with_coi_penalty = jax.jit(reward_with_coi_penalty)
|
||||||
5
engine/jax/requirements.txt
Normal file
5
engine/jax/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
flax>=0.8.0
|
||||||
|
optax>=0.2.0
|
||||||
|
distrax>=0.1.5
|
||||||
|
orbax-checkpoint>=0.5.0
|
||||||
|
chex>=0.1.8
|
||||||
471
engine/jax/train.py
Normal file
471
engine/jax/train.py
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
"""Pure JAX PPO trainer for the PHANTOM environment."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, NamedTuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Optional dependency guard: import the JAX training stack if present,
# otherwise install inert placeholders so this module can still be imported.
# train_jax() checks HAS_JAX_STACK and raises before any placeholder is used.
try:
    import jax
    import jax.numpy as jnp
    import distrax
    import flax.linen as nn
    import optax
    from flax import serialization
    from flax.linen.initializers import constant, orthogonal
    from flax.training.train_state import TrainState

    HAS_JAX_STACK = True
except ImportError:
    jax = None  # type: ignore[assignment]
    jnp = None  # type: ignore[assignment]
    distrax = None  # type: ignore[assignment]
    optax = None  # type: ignore[assignment]
    serialization = None  # type: ignore[assignment]

    # Placeholder base class so `class ActorCritic(nn.Module)` still parses.
    class _ModuleStub:
        pass

    # Minimal stand-in for `flax.linen`: exposes `Module` and a pass-through
    # `compact` decorator, the two attributes used at class-definition time.
    class _NNStub:
        Module = _ModuleStub

        @staticmethod
        def compact(fn):
            return fn

    nn = _NNStub()  # type: ignore[assignment]

    # Initializer placeholders; accept any arguments and return None.
    def constant(*_args, **_kwargs):  # type: ignore[override]
        return None

    def orthogonal(*_args, **_kwargs):  # type: ignore[override]
        return None

    # Placeholder so the name `TrainState` exists at import time.
    class TrainState:  # type: ignore[override]
        pass

    HAS_JAX_STACK = False
|
||||||
|
|
||||||
|
from .env import PHANTOMJAXEnv, make_env_params
|
||||||
|
|
||||||
|
|
||||||
|
class ActorCritic(nn.Module):
    """Separate actor and critic MLPs over a shared observation input.

    Each branch has two 64-unit hidden layers. The actor ends in an
    ``action_dim``-way logit layer wrapped in a categorical distribution;
    the critic ends in a single value unit whose last axis is squeezed away.
    """

    action_dim: int
    activation: str = "tanh"

    @nn.compact
    def __call__(self, x):
        # Any value other than "relu" selects tanh.
        nonlinearity = nn.relu if self.activation == "relu" else nn.tanh
        hidden_gain = np.sqrt(2.0)

        def _dense(features, gain):
            # Orthogonal kernel with the given gain, zero bias.
            return nn.Dense(
                features,
                kernel_init=orthogonal(gain),
                bias_init=constant(0.0),
            )

        # Layers are created in the same order as before (actor first, then
        # critic) so the auto-generated submodule names are unchanged.
        pi_hidden = nonlinearity(_dense(64, hidden_gain)(x))
        pi_hidden = nonlinearity(_dense(64, hidden_gain)(pi_hidden))
        logits = _dense(self.action_dim, 0.01)(pi_hidden)

        v_hidden = nonlinearity(_dense(64, hidden_gain)(x))
        v_hidden = nonlinearity(_dense(64, hidden_gain)(v_hidden))
        value = _dense(1, 1.0)(v_hidden)

        return distrax.Categorical(logits=logits), jnp.squeeze(value, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
class Transition(NamedTuple):
    """One rollout step across all parallel environments, as stored by the PPO loop."""

    # Episode-termination flag returned by env.step.
    done: jax.Array
    # Action sampled from the policy.
    action: jax.Array
    # Critic output for the observation the action was taken from.
    value: jax.Array
    # Reward returned by env.step.
    reward: jax.Array
    # Log-probability of `action` under the policy at collection time.
    log_prob: jax.Array
    # Observation the policy was evaluated on (before the step).
    obs: jax.Array
    # Extra per-step outputs from env.step, keyed by name.
    info: dict[str, jax.Array]
|
||||||
|
|
||||||
|
|
||||||
|
def _jax_cfg(cfg: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
out = {
|
||||||
|
"algo": str(cfg.get("algo", "ppo")).lower(),
|
||||||
|
"seed": int(cfg.get("seed", 42)),
|
||||||
|
"learning_rate": float(cfg.get("learning_rate", 3e-4)),
|
||||||
|
"gamma": float(cfg.get("gamma", 0.99)),
|
||||||
|
"gae_lambda": float(cfg.get("gae_lambda", 0.95)),
|
||||||
|
"clip_range": float(cfg.get("clip_range", 0.2)),
|
||||||
|
"ent_coef": float(cfg.get("ent_coef", 0.01)),
|
||||||
|
"vf_coef": float(cfg.get("vf_coef", 0.5)),
|
||||||
|
"max_grad_norm": float(cfg.get("max_grad_norm", 0.5)),
|
||||||
|
"activation": str(cfg.get("activation", "relu")),
|
||||||
|
"total_timesteps": int(cfg.get("total_timesteps", 50_000)),
|
||||||
|
"eval_episodes": int(cfg.get("eval_episodes", 5)),
|
||||||
|
"model_dir": str(cfg.get("model_dir", "engine/models")),
|
||||||
|
"n_products": int(cfg.get("n_products", 10)),
|
||||||
|
"N": int(cfg.get("N", 100)),
|
||||||
|
"alpha": float(cfg.get("alpha", 0.3)),
|
||||||
|
"lambda_coi": float(cfg.get("lambda_coi", 0.2)),
|
||||||
|
"robust_radius": float(cfg.get("robust_radius", 0.15)),
|
||||||
|
"robust_points": int(cfg.get("robust_points", 5)),
|
||||||
|
"info_value": float(cfg.get("info_value", 1.0)),
|
||||||
|
"price_low": float(cfg.get("price_low", 10.0)),
|
||||||
|
"price_high": float(cfg.get("price_high", 150.0)),
|
||||||
|
"action_levels": int(cfg.get("action_levels", 9)),
|
||||||
|
"action_scale_low": float(cfg.get("action_scale_low", 0.8)),
|
||||||
|
"action_scale_high": float(cfg.get("action_scale_high", 1.2)),
|
||||||
|
"max_episode_steps": int(cfg.get("max_steps", 100)),
|
||||||
|
"max_session_steps": int(cfg.get("max_session_steps", 40)),
|
||||||
|
"margin_floor": float(cfg.get("margin_floor", 0.05)),
|
||||||
|
"margin_floor_patience": int(cfg.get("margin_floor_patience", 5)),
|
||||||
|
"prefer_behavior_data": bool(cfg.get("prefer_behavior_data", True)),
|
||||||
|
"num_envs": int(cfg.get("jax_num_envs", 16)),
|
||||||
|
"num_steps": int(cfg.get("jax_num_steps", 128)),
|
||||||
|
"num_minibatches": int(cfg.get("jax_num_minibatches", 4)),
|
||||||
|
"update_epochs": int(cfg.get("jax_update_epochs", 4)),
|
||||||
|
"anneal_lr": bool(cfg.get("jax_anneal_lr", True)),
|
||||||
|
}
|
||||||
|
rollout = out["num_envs"] * out["num_steps"]
|
||||||
|
out["num_updates"] = max(1, out["total_timesteps"] // max(rollout, 1))
|
||||||
|
out["minibatch_size"] = max(1, rollout // max(out["num_minibatches"], 1))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _select_env_state(done: jax.Array, keep: jax.Array, reset: jax.Array) -> jax.Array:
|
||||||
|
mask = done
|
||||||
|
while mask.ndim < keep.ndim:
|
||||||
|
mask = mask[..., None]
|
||||||
|
return jnp.where(mask, reset, keep)
|
||||||
|
|
||||||
|
|
||||||
|
def make_train(config: dict[str, Any]):
    """Build the PPO training function for the PHANTOM JAX environment.

    The config is normalized with ``_jax_cfg``, then used to construct the
    environment and the actor-critic network.

    Returns:
        ``(train, network, env, cfg)`` where ``train(rng)`` runs the whole
        training loop and returns ``{"runner_state": ..., "metrics": ...}``,
        and ``cfg`` is the normalized config.
    """
    cfg = _jax_cfg(config)
    env_params = make_env_params(
        n_products=cfg["n_products"],
        alpha=cfg["alpha"],
        n_sessions=cfg["N"],
        lambda_coi=cfg["lambda_coi"],
        robust_radius=cfg["robust_radius"],
        robust_points=cfg["robust_points"],
        info_value=cfg["info_value"],
        action_levels=cfg["action_levels"],
        action_scale_low=cfg["action_scale_low"],
        action_scale_high=cfg["action_scale_high"],
        price_low=cfg["price_low"],
        price_high=cfg["price_high"],
        max_episode_steps=cfg["max_episode_steps"],
        max_session_steps=cfg["max_session_steps"],
        margin_floor=cfg["margin_floor"],
        margin_floor_patience=cfg["margin_floor_patience"],
        prefer_behavior_data=cfg["prefer_behavior_data"],
    )
    env = PHANTOMJAXEnv(env_params)
    network = ActorCritic(env.action_space_n(), activation=cfg["activation"])

    def linear_schedule(count: jax.Array) -> jax.Array:
        # `count` is the optimizer step; each PPO update performs
        # num_minibatches * update_epochs optimizer steps.
        updates_done = count // (cfg["num_minibatches"] * cfg["update_epochs"])
        frac = 1.0 - updates_done / max(cfg["num_updates"], 1)
        return cfg["learning_rate"] * frac

    def train(rng: jax.Array):
        # Initialize network parameters from a zero observation.
        rng, init_key = jax.random.split(rng)
        init_obs = jnp.zeros((env.observation_dim(),), dtype=jnp.float32)
        params = network.init(init_key, init_obs)

        # Gradient clipping followed by Adam, with or without LR annealing.
        if cfg["anneal_lr"]:
            tx = optax.chain(
                optax.clip_by_global_norm(cfg["max_grad_norm"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(cfg["max_grad_norm"]),
                optax.adam(cfg["learning_rate"], eps=1e-5),
            )
        train_state = TrainState.create(apply_fn=network.apply, params=params, tx=tx)

        # Reset all parallel environments with independent keys.
        rng, reset_key = jax.random.split(rng)
        reset_keys = jax.random.split(reset_key, cfg["num_envs"])
        obs, env_state = jax.vmap(env.reset)(reset_keys)

        def _update_step(runner_state, _):
            # One PPO update: collect a rollout, compute GAE, run the epochs.
            def _env_step(runner_state, _):
                train_state, env_state, last_obs, rng = runner_state
                rng, action_key = jax.random.split(rng)
                policy, value = network.apply(train_state.params, last_obs)
                action = policy.sample(seed=action_key)
                log_prob = policy.log_prob(action)

                rng, step_key = jax.random.split(rng)
                step_keys = jax.random.split(step_key, cfg["num_envs"])
                nxt_obs, nxt_state, reward, done, info = jax.vmap(
                    env.step,
                    in_axes=(0, 0, 0),
                )(step_keys, env_state, action)

                # Auto-reset: a fresh reset is computed for every environment
                # each step and selected only where `done` is set.
                rng, reset_key = jax.random.split(rng)
                reset_keys = jax.random.split(reset_key, cfg["num_envs"])
                rst_obs, rst_state = jax.vmap(env.reset)(reset_keys)
                obs_next = jnp.where(done[:, None], rst_obs, nxt_obs)
                env_next = jax.tree_util.tree_map(
                    lambda keep, reset: _select_env_state(done, keep, reset),
                    nxt_state,
                    rst_state,
                )
                transition = Transition(
                    done=done,
                    action=action,
                    value=value,
                    reward=reward,
                    log_prob=log_prob,
                    obs=last_obs,
                    info=info,
                )
                return (train_state, env_next, obs_next, rng), transition

            runner_state, traj_batch = jax.lax.scan(
                _env_step,
                runner_state,
                None,
                length=cfg["num_steps"],
            )

            # Bootstrap value for the observation after the last rollout step.
            train_state, env_state, last_obs, rng = runner_state
            _, last_value = network.apply(train_state.params, last_obs)

            def _compute_gae(traj_batch, last_value):
                def _gae_step(carry, transition):
                    gae, next_value = carry
                    # TD residual; the bootstrap term is dropped at episode ends.
                    delta = (
                        transition.reward
                        + cfg["gamma"] * next_value * (1.0 - transition.done)
                        - transition.value
                    )
                    gae = (
                        delta
                        + cfg["gamma"]
                        * cfg["gae_lambda"]
                        * (1.0 - transition.done)
                        * gae
                    )
                    return (gae, transition.value), gae

                # Scan backwards in time, starting from the bootstrap value.
                _, advantages = jax.lax.scan(
                    _gae_step,
                    (jnp.zeros_like(last_value), last_value),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                targets = advantages + traj_batch.value
                return advantages, targets

            advantages, targets = _compute_gae(traj_batch, last_value)

            def _update_epoch(update_state, _):
                def _update_minibatch(train_state, batch_info):
                    traj_b, adv_b, tgt_b = batch_info

                    def _loss_fn(params, traj_b, adv_b, tgt_b):
                        policy, value = network.apply(params, traj_b.obs)
                        log_prob = policy.log_prob(traj_b.action)

                        # Clipped value loss: limit how far the new value may
                        # move from the value recorded during the rollout.
                        value_clipped = traj_b.value + (value - traj_b.value).clip(
                            -cfg["clip_range"], cfg["clip_range"]
                        )
                        value_loss = (
                            0.5
                            * jnp.maximum(
                                jnp.square(value - tgt_b),
                                jnp.square(value_clipped - tgt_b),
                            ).mean()
                        )

                        # Clipped surrogate objective on per-minibatch
                        # normalized advantages.
                        adv_norm = (adv_b - adv_b.mean()) / (adv_b.std() + 1e-8)
                        ratio = jnp.exp(log_prob - traj_b.log_prob)
                        loss_actor = -jnp.minimum(
                            ratio * adv_norm,
                            jnp.clip(
                                ratio,
                                1.0 - cfg["clip_range"],
                                1.0 + cfg["clip_range"],
                            )
                            * adv_norm,
                        ).mean()
                        entropy = policy.entropy().mean()
                        total_loss = (
                            loss_actor
                            + cfg["vf_coef"] * value_loss
                            - cfg["ent_coef"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    # Loss values are discarded; only the gradients are used.
                    (_, _), grads = grad_fn(train_state.params, traj_b, adv_b, tgt_b)
                    train_state = train_state.apply_gradients(grads=grads)
                    # Constant placeholder as the per-minibatch scan output.
                    return train_state, jnp.asarray(0.0, dtype=jnp.float32)

                train_state, traj_batch, advantages, targets, rng = update_state
                rng, perm_key = jax.random.split(rng)
                batch_size = cfg["num_envs"] * cfg["num_steps"]
                permutation = jax.random.permutation(perm_key, batch_size)
                batch = (traj_batch, advantages, targets)
                # Merge the two leading axes into one flat batch axis.
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]),
                    batch,
                )
                shuffled = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0),
                    batch,
                )
                # NOTE(review): this reshape needs num_minibatches * minibatch_size
                # to equal batch_size, i.e. the rollout size must be divisible
                # by num_minibatches - confirm configs guarantee that.
                minibatches = jax.tree_util.tree_map(
                    lambda x: x.reshape(
                        (cfg["num_minibatches"], cfg["minibatch_size"]) + x.shape[1:]
                    ),
                    shuffled,
                )
                train_state, _ = jax.lax.scan(
                    _update_minibatch, train_state, minibatches
                )
                return (train_state, traj_batch, advantages, targets, rng), None

            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, _ = jax.lax.scan(
                _update_epoch,
                update_state,
                None,
                length=cfg["update_epochs"],
            )
            train_state = update_state[0]
            rng = update_state[-1]

            # Rollout-level means reported once per update.
            metric = {
                "reward": jnp.mean(traj_batch.reward),
                "revenue": jnp.mean(traj_batch.info["revenue"]),
                "agent_prob": jnp.mean(traj_batch.info["agent_prob"]),
                "alpha_adv": jnp.mean(traj_batch.info["alpha_adv"]),
                "coi_leakage": jnp.mean(traj_batch.info["coi_leakage"]),
            }
            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric

        runner_state = (train_state, env_state, obs, rng)
        runner_state, metric = jax.lax.scan(
            _update_step,
            runner_state,
            None,
            length=cfg["num_updates"],
        )
        return {
            "runner_state": runner_state,
            "metrics": metric,
        }

    return train, network, env, cfg
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_policy(
    *,
    network: ActorCritic,
    params: Any,
    env: PHANTOMJAXEnv,
    episodes: int,
    seed: int,
) -> dict[str, float]:
    """Roll out the greedy (argmax-logit) policy and summarize the results.

    Each episode runs until the environment reports done or
    ``env.params.max_episode_steps`` steps have been taken.

    Returns:
        Mean and standard deviation of the per-episode reward and revenue
        totals under the keys ``eval/reward``, ``eval/revenue``,
        ``eval/reward_std`` and ``eval/revenue_std``.
    """
    key = jax.random.PRNGKey(seed)
    reward_totals: list[float] = []
    revenue_totals: list[float] = []

    for _episode in range(int(episodes)):
        key, reset_key = jax.random.split(key)
        obs, state = env.reset(reset_key)
        reward_sum = 0.0
        revenue_sum = 0.0
        horizon = int(env.params.max_episode_steps)

        for _step in range(horizon):
            policy, _value = network.apply(params, obs)
            greedy_action = jnp.argmax(policy.logits)
            key, step_key = jax.random.split(key)
            obs, state, reward, done_flag, info = env.step(
                step_key, state, greedy_action
            )
            reward_sum += float(np.asarray(reward))
            revenue_sum += float(np.asarray(info["revenue"]))
            if bool(np.asarray(done_flag)):
                break

        reward_totals.append(reward_sum)
        revenue_totals.append(revenue_sum)

    summary = {
        "eval/reward": float(np.mean(reward_totals)),
        "eval/revenue": float(np.mean(revenue_totals)),
        "eval/reward_std": float(np.std(reward_totals)),
        "eval/revenue_std": float(np.std(revenue_totals)),
    }
    return summary
|
||||||
|
|
||||||
|
|
||||||
|
def train_jax(cfg: dict[str, Any]) -> tuple[dict[str, Any], dict[str, float]]:
    """Train PPO on the JAX backend, evaluate it, and save the parameters.

    Args:
        cfg: Raw training config (same keys that ``_jax_cfg`` reads).

    Returns:
        ``({"params": trained_params}, metrics)`` where ``metrics`` holds the
        averaged training metrics, the evaluation summary and the path of the
        serialized model under ``model/path``.

    Raises:
        ImportError: If the JAX stack (jax, flax, optax, distrax) is missing.
        ValueError: If the requested algorithm is not ``ppo``.
    """
    if not HAS_JAX_STACK:
        raise ImportError(
            "JAX PPO path requires jax, flax, optax, and distrax. "
            "Install engine/jax/requirements.txt on this machine first."
        )

    # Normalize once here only to validate the algorithm name.
    algo = _jax_cfg(cfg)["algo"]
    if algo != "ppo":
        raise ValueError(
            f"JAX backend currently supports algo='ppo' only, got '{algo}'"
        )

    # Pass the RAW config: make_train normalizes it itself. Feeding it an
    # already-normalized dict would drop the renamed options (jax_num_envs,
    # jax_num_steps, jax_num_minibatches, jax_update_epochs, jax_anneal_lr,
    # max_steps), because the second pass looks them up under their original
    # names and would silently fall back to defaults.
    train_fn, network, env, run_cfg = make_train(cfg)
    train_jit = jax.jit(train_fn)
    rng = jax.random.PRNGKey(run_cfg["seed"])
    out = train_jit(rng)

    train_state = out["runner_state"][0]
    metric = out["metrics"]
    # Each metric has one entry per update; report the mean over updates.
    metrics = {
        "train/reward": float(np.mean(np.asarray(metric["reward"]))),
        "train/revenue": float(np.mean(np.asarray(metric["revenue"]))),
        "train/agent_prob": float(np.mean(np.asarray(metric["agent_prob"]))),
        "train/alpha_adv": float(np.mean(np.asarray(metric["alpha_adv"]))),
        "train/coi_leakage": float(np.mean(np.asarray(metric["coi_leakage"]))),
        "train/global_step": int(
            run_cfg["num_updates"] * run_cfg["num_steps"] * run_cfg["num_envs"]
        ),
    }

    eval_metrics = evaluate_policy(
        network=network,
        params=train_state.params,
        env=env,
        episodes=run_cfg["eval_episodes"],
        # Offset so evaluation does not reuse the training seed.
        seed=run_cfg["seed"] + 7,
    )
    metrics.update(eval_metrics)

    model_dir = Path(run_cfg["model_dir"])
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "phantom_ppo_jax.msgpack"
    model_path.write_bytes(serialization.to_bytes(train_state.params))
    metrics["model/path"] = str(model_path)
    return {"params": train_state.params}, metrics
|
||||||
119
engine/lib/callbacks.py
Normal file
119
engine/lib/callbacks.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""Training callbacks for W&B/TensorBoard logging - reads from info dict."""
|
||||||
|
|
||||||
|
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
HAS_WANDB = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_WANDB = False
|
||||||
|
|
||||||
|
|
||||||
|
class MetricsCallback(BaseCallback):
    """Logs per-step economics from ``info['economics']`` to W&B.

    Scalars are logged on every step that carries an ``economics`` entry;
    price/demand histograms are logged when ``num_timesteps`` is a multiple of
    ``log_freq``; revenue aggregates are logged at the end of each rollout.
    Nothing is logged when wandb is missing or no run is active.
    """

    def __init__(
        self, log_histograms: bool = True, log_freq: int = 100, verbose: int = 0
    ):
        super().__init__(verbose)
        self.log_histograms = log_histograms
        self.log_freq = log_freq
        self._episode_revenues: list[float] = []

    def _on_step(self) -> bool:
        """Log scalars (and periodically histograms); always continue training."""
        if not HAS_WANDB or wandb.run is None:
            return True

        # Optional economics fields and the W&B names they are logged under.
        optional_fields = (
            ("coi_mix", "coi/mix"),
            ("coi_base", "coi/base"),
            ("coi_leakage", "coi/leakage"),
            ("coi_penalty", "coi/penalty"),
        )

        for info in self.locals.get("infos", []):
            if "economics" not in info:
                continue

            econ = info["economics"]
            step = self.num_timesteps
            payload = {
                "economics/revenue": econ["revenue"],
                "economics/margin": econ["margin"],
                "coi/level": econ["coi_level"],
                "economics/regret": econ["regret"],
            }
            for source_key, metric_name in optional_fields:
                if source_key in econ:
                    payload[metric_name] = econ[source_key]
            wandb.log(payload, step=step)

            self._episode_revenues.append(econ["revenue"])

        # histograms at log_freq intervals
        histogram_due = (
            self.log_histograms and self.num_timesteps % self.log_freq == 0
        )
        if histogram_due:
            histogram_fields = (
                ("prices", "distributions/prices"),
                ("demand", "distributions/demand"),
            )
            for info in self.locals.get("infos", []):
                for source_key, metric_name in histogram_fields:
                    if source_key in info:
                        wandb.log(
                            {metric_name: wandb.Histogram(info[source_key])},
                            step=self.num_timesteps,
                        )

        return True

    def _on_rollout_end(self) -> None:
        """Log mean/total revenue collected since the last rollout, then reset."""
        logging_off = not HAS_WANDB or wandb.run is None
        if logging_off or not self._episode_revenues:
            return
        revenues = self._episode_revenues
        wandb.log(
            {
                "episode/mean_revenue": np.mean(revenues),
                "episode/total_revenue": np.sum(revenues),
            },
            step=self.num_timesteps,
        )
        self._episode_revenues = []
|
||||||
|
|
||||||
|
|
||||||
|
class EvalMetricsCallback(EvalCallback):
    """Evaluation callback that also reports eval reward and revenue to W&B."""

    def __init__(
        self, eval_env, eval_freq: int = 1000, n_eval_episodes: int = 5, **kwargs
    ):
        super().__init__(
            eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes, **kwargs
        )
        self._eval_revenues: list[float] = []

    def _on_step(self) -> bool:
        """Run the parent step, then log eval metrics on evaluation steps."""
        keep_training = super()._on_step()

        if not HAS_WANDB or wandb.run is None:
            return keep_training

        # log eval metrics after evaluation runs
        evaluated_now = self.n_calls % self.eval_freq == 0
        if evaluated_now and hasattr(self, "last_mean_reward"):
            if self._eval_revenues:
                mean_revenue = np.mean(self._eval_revenues)
            else:
                mean_revenue = 0
            wandb.log(
                {
                    "eval/mean_reward": self.last_mean_reward,
                    "eval/mean_revenue": mean_revenue,
                },
                step=self.num_timesteps,
            )
            self._eval_revenues = []

        return keep_training

    def _log_success_callback(self, locals_: dict, globals_: dict) -> None:
        """Collect revenue from the ``info`` entry of the evaluation locals."""
        # called after each eval episode
        info = locals_.get("info", {})
        if "economics" not in info:
            return
        self._eval_revenues.append(info["economics"]["revenue"])
|
||||||
76
engine/lib/coi.py
Normal file
76
engine/lib/coi.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
import numpy as np
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
def compute_agent_probability(
    trajectory: list, human_transitions: Dict, agent_transitions: Dict
) -> float:
    """estimate agent probability via KL divergence between trajectory transitions and reference models

    compares empirical trajectory transition distribution to human/agent prototypes

    args:
        trajectory: list of state/event strings from session
        human_transitions: reference transition dict from human MDP (event->event->prob)
        agent_transitions: reference transition dict from agent MDP (event->event->prob)

    returns:
        agent probability in [0, 1] via softmax over KL divergences
    """
    if len(trajectory) < 2:
        return 0.0  # insufficient data, assume human

    # build empirical transition distribution from trajectory
    trans_counts = {}
    for s, s_next in zip(trajectory[:-1], trajectory[1:]):
        if s not in trans_counts:
            trans_counts[s] = {}
        trans_counts[s][s_next] = trans_counts[s].get(s_next, 0) + 1

    # normalize to probabilities
    empirical = {}
    for s, nxt in trans_counts.items():
        total = sum(nxt.values())
        empirical[s] = {s_n: cnt / total for s_n, cnt in nxt.items()}

    # compute KL divergence to each prototype
    def kl_div(p_dist: Dict, q_dist: Dict) -> float:
        eps = 1e-10
        # aggregate over all source states in empirical dist
        kl = 0.0
        for s in p_dist:
            if s not in q_dist:
                continue  # skip states not in reference
            p_trans, q_trans = p_dist[s], q_dist[s]
            for k in p_trans:
                p_val = p_trans[k] + eps
                q_val = q_trans.get(k, 0.0) + eps
                kl += p_val * np.log(p_val / q_val)
        return kl

    kl_human = kl_div(empirical, human_transitions)
    kl_agent = kl_div(empirical, agent_transitions)

    # convert to probability via softmax (lower KL = higher prob)
    # agent_prob = exp(-kl_agent) / (exp(-kl_human) + exp(-kl_agent))
    #
    # Shift both divergences by the smaller one first. A single transition the
    # reference does not support already costs about log(1e10) ~= 23, so both
    # exp(-kl) terms can be far below the 1e-10 guard; unshifted, the guard
    # then dominates the denominator and the result collapses to ~0 no matter
    # which model fits better. The shift leaves the softmax unchanged.
    shift = min(kl_human, kl_agent)
    exp_h = np.exp(-(kl_human - shift))
    exp_a = np.exp(-(kl_agent - shift))
    # Denominator is >= 1 after the shift; guard kept for parity with the
    # unshifted formula when the smaller divergence is 0.
    return float(exp_a / (exp_h + exp_a + 1e-10))
|
||||||
|
|
||||||
|
|
||||||
|
def extract_purchases(trajectories: list) -> Dict[int, int]:
    """
    count completed purchases per product over a batch of trajectories

    a trajectory counts as a purchase when its final event contains both
    "checkout" and "_product"; the product id is the integer following the
    last "_product" marker of that event (e.g. "checkout_product3" -> 3)

    args:
        trajectories: list of event-string sequences; empty sequences are ignored

    returns:
        dict mapping product id -> number of trajectories ending in a checkout
        of that product
    """
    purchases: Dict[int, int] = {}
    for traj in trajectories:
        if not traj:
            continue
        last_event = traj[-1]
        if "checkout" not in last_event or "_product" not in last_event:
            continue
        id_text = last_event.rsplit("_product", 1)[1]
        try:
            prod_id = int(id_text)
        except ValueError:
            # malformed id (e.g. "checkout_product" or "checkout_product3_done"):
            # drop this one trajectory rather than aborting the whole batch
            continue
        purchases[prod_id] = purchases.get(prod_id, 0) + 1
    return purchases
|
||||||
|
|
||||||
|
|
||||||
|
def compute_uplift_coi(
    prices: np.ndarray, purchases: Dict[int, int], baseline_prices: np.ndarray
) -> float:
    """
    total price uplift over baseline, summed across all recorded purchases

    args:
        prices: charged prices, indexed by product id
        purchases: product id -> purchase count
        baseline_prices: reference prices, indexed by product id

    returns:
        sum over purchased products of count * max(0, price - baseline price);
        products priced at or below baseline contribute nothing
    """
    # TODO: consider view-weighted fractional purchase for denser signal
    total = 0
    for prod_id, count in purchases.items():
        uplift = max(0.0, prices[prod_id] - baseline_prices[prod_id])
        total = total + uplift * count
    return float(total)
|
||||||
70
engine/lib/discrete.py
Normal file
70
engine/lib/discrete.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
import gymnasium as gym
|
||||||
|
from gymnasium import spaces
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class DiscretePriceActionWrapper(gym.ActionWrapper):
    """Expose a discrete action space in which each action is one price multiplier.

    Action i selects self.scales[i], one of n_levels evenly spaced values between
    min_scale and max_scale. The chosen multiplier is applied to every current
    price of the unwrapped env and the result is clipped to its price bounds.
    """

    def __init__(
        self,
        env: gym.Env,
        n_levels: int = 9,
        min_scale: float = 0.8,
        max_scale: float = 1.2,
    ):
        super().__init__(env)
        # lookup table: discrete action index -> multiplicative scale
        self.scales = np.linspace(min_scale, max_scale, n_levels, dtype=np.float32)
        self.action_space = spaces.Discrete(n_levels)

    def action(self, action: int):
        """Translate a discrete action index into the price vector passed to the wrapped env."""
        multiplier = float(self.scales[int(action)])
        base_env = self.env.unwrapped
        current = np.asarray(base_env._prices, dtype=np.float32)
        lower, upper = base_env.price_bounds
        scaled = np.clip(current * multiplier, lower, upper)
        return scaled.astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class EventQTable:
    """Tabular Q-learning agent over a coarse discretization of the observation.

    The observation is read as `n_products` demand values followed by
    `n_products` price values. It is reduced to a 3-tuple state
    (bin of mean demand, bin of demand std, bin of mean price) and a Q-vector of
    length `n_actions` is kept per state, created lazily as zeros.

    args:
        n_actions: number of discrete actions
        n_products: number of products (length of each half of the observation)
        price_bounds: (lo, hi) pair used to place the price bin edges
        lr: Q-learning step size
        gamma: discount factor
        n_bins: number of bins per state dimension
        demand_max: upper edge of the demand binning range (previously the
            hard-coded constant 100.0, which remains the default)
    """

    def __init__(
        self,
        n_actions: int,
        n_products: int,
        price_bounds: tuple,
        lr: float = 0.1,
        gamma: float = 0.99,
        n_bins: int = 6,
        demand_max: float = 100.0,
    ):
        self.n_actions = int(n_actions)
        self.n_products = int(n_products)
        self.lr = float(lr)
        self.gamma = float(gamma)
        # unseen states start with an all-zero Q-vector
        self.q = defaultdict(lambda: np.zeros(self.n_actions, dtype=np.float32))
        lo, hi = price_bounds
        # interior edges only: np.digitize then yields indices 0 .. n_bins - 1
        self.demand_bins = np.linspace(0.0, float(demand_max), n_bins + 1)[1:-1]
        self.price_bins = np.linspace(lo, hi, n_bins + 1)[1:-1]

    def encode(self, obs: np.ndarray) -> tuple:
        """Map a raw observation to its discrete (demand mean, demand std, price mean) state."""
        obs = np.asarray(obs, dtype=np.float32)
        d = obs[: self.n_products]
        p = obs[self.n_products : 2 * self.n_products]
        # empty slices fall back to 0.0 instead of producing NaN
        d_mean = float(np.mean(d)) if d.size else 0.0
        d_std = float(np.std(d)) if d.size else 0.0
        p_mean = float(np.mean(p)) if p.size else 0.0
        return (
            int(np.digitize(d_mean, self.demand_bins)),
            # the std is binned with the same edges as the mean
            int(np.digitize(d_std, self.demand_bins)),
            int(np.digitize(p_mean, self.price_bins)),
        )

    def act(self, obs: np.ndarray, eps: float = 0.0) -> tuple[int, tuple]:
        """Epsilon-greedy action selection; returns (action, encoded state)."""
        s = self.encode(obs)
        if np.random.random() < eps:
            return int(np.random.randint(self.n_actions)), s
        return int(np.argmax(self.q[s])), s

    def update(self, s: tuple, a: int, r: float, s2: tuple, done: bool):
        """One-step Q-learning update; terminal transitions do not bootstrap from s2."""
        target = r + (0.0 if done else self.gamma * float(np.max(self.q[s2])))
        self.q[s][a] += self.lr * (target - self.q[s][a])

    def predict(self, obs: np.ndarray, deterministic: bool = True):
        """Return (action, None); uses a fixed 5% exploration rate when not deterministic."""
        a, _ = self.act(obs, 0.0 if deterministic else 0.05)
        return a, None
|
||||||
182
engine/lib/providers.py
Normal file
182
engine/lib/providers.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
"""Provider benchmarking - compare pricing strategies across contamination levels."""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Callable, Any
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
HAS_WANDB = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_WANDB = False
|
||||||
|
|
||||||
|
|
||||||
|
class RandomBaseline:
    """lower-bound baseline: every action is drawn uniformly from 0 .. n_actions - 1"""

    def __init__(self, n_actions: int):
        self.n = n_actions

    def __call__(self, obs):
        # the observation is ignored; the draw does not depend on state
        sampled = np.random.randint(self.n)
        return int(sampled)

    def predict(self, obs, **kw):
        # returns an (action, None) pair; extra keyword arguments are accepted and ignored
        action = self.__call__(obs)
        return action, None
|
||||||
|
|
||||||
|
|
||||||
|
class SurgeBaseline:
    """threshold-based surge pricing heuristic.

    raises the price level when mean demand is at or above high_threshold, lowers it
    when mean demand is at or below low_threshold, and holds otherwise.
    matches the naive pricing rule from thesis Section 3.3.2"""

    def __init__(
        self, n_actions: int, high_threshold: float = 60.0, low_threshold: float = 30.0
    ):
        self.n = n_actions
        self.mid = n_actions // 2  # identity action (scale ~1.0)
        self.high_t = high_threshold
        self.low_t = low_threshold

    def __call__(self, obs):
        flat = np.asarray(obs, dtype=np.float32)
        # the first half of the observation is read as per-product demand
        half = len(flat) // 2
        avg_demand = float(np.mean(flat[:half])) if half > 0 else 0.0
        if avg_demand >= self.high_t:
            # surge: two levels above identity, capped at the last action
            return min(self.n - 1, self.mid + 2)
        if avg_demand <= self.low_t:
            # discount: two levels below identity, floored at action 0
            return max(0, self.mid - 2)
        # hold
        return self.mid

    def predict(self, obs, **kw):
        # returns an (action, None) pair; extra keyword arguments are accepted and ignored
        action = self.__call__(obs)
        return action, None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ProviderResult:
    """Benchmark outcome for a single (provider, alpha) pair."""

    # provider key and the alpha value the environment was built with
    name: str
    alpha: float

    # revenue summed over all episodes, and its per-episode mean
    total_revenue: float
    mean_revenue: float

    # mean COI level, and that level as a percentage of the same
    # provider's alpha=0 baseline
    coi_level: float
    coi_preserved_pct: float  # vs alpha=0 baseline

    margin_integrity: float
    regret: float

    # number of episodes the figures were computed from
    episodes: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BenchmarkConfig:
    """Settings that control a provider benchmark run."""

    # episodes played per (provider, alpha) pair
    n_episodes: int = 100
    # alpha values to evaluate; built per instance so configs never share one list
    alpha_range: list[float] = field(default_factory=lambda: [0.0, 0.1, 0.3, 0.5])
    baseline_name: str = "fixed"
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderBenchmark:
    """Compare pricing providers to prove margin preservation across contamination levels.

    Usage:
        def env_factory(alpha):
            return EconomicMetricsWrapper(PHANTOM(alpha=alpha))

        providers = {
            "fixed": lambda obs: np.ones(10) * 50,
            "learned": model.predict,
        }

        benchmark = ProviderBenchmark(env_factory, providers)
        results = benchmark.run()
        print(benchmark.summary_table())
    """

    def __init__(
        self,
        env_factory: Callable[[float], Any],
        providers: dict[str, Callable],
        config: BenchmarkConfig | None = None,
    ):
        self.env_factory = env_factory  # fn(alpha) -> wrapped env
        self.providers = providers  # {name: fn(obs) -> action}
        self.config = config or BenchmarkConfig()
        self.results: list[ProviderResult] = []

    def _rollout(self, env, policy_fn) -> tuple[list, list, list]:
        """Play config.n_episodes episodes of one provider on one env.

        Returns (per-episode revenues, per-step COI levels, per-step margins),
        all read from info["economics"]; missing keys count as 0.
        """
        revenues, coi_levels, margins = [], [], []
        for _ in range(self.config.n_episodes):
            obs, _ = env.reset()
            episode_revenue = 0.0
            done = False
            while not done:
                action = policy_fn(obs)
                # handle sb3 model.predict returning tuple
                if isinstance(action, tuple):
                    action = action[0]
                obs, _reward, term, trunc, info = env.step(action)
                done = term or trunc

                econ = info.get("economics", {})
                episode_revenue += econ.get("revenue", 0)
                coi_levels.append(econ.get("coi_level", 0))
                margins.append(econ.get("margin", 0))
            revenues.append(episode_revenue)
        return revenues, coi_levels, margins

    def run(self) -> list[ProviderResult]:
        """Run benchmark across all providers and alpha levels.

        One environment is created per alpha and closed once every provider has
        been evaluated on it. Results are appended to self.results (repeated calls
        accumulate) and logged to wandb when a run is active.

        COI preservation is computed only after every alpha has been evaluated, so
        the alpha=0 baseline is applied regardless of where 0.0 sits in alpha_range.
        A provider with no alpha=0 measurement is reported as 100% preserved.
        """
        baseline_coi: dict[str, float] = {}  # {provider: coi at alpha=0}
        measured = []  # (alpha, name, revenues, mean_coi, margins) in evaluation order

        for alpha in self.config.alpha_range:
            env = self.env_factory(alpha)
            try:
                for name, policy_fn in self.providers.items():
                    revenues, coi_levels, margins = self._rollout(env, policy_fn)
                    mean_coi = float(np.mean(coi_levels)) if coi_levels else 0.0
                    if alpha == 0.0:
                        baseline_coi[name] = mean_coi
                    measured.append((alpha, name, revenues, mean_coi, margins))
            finally:
                # release the env even when a provider raises; env is typed Any,
                # so tolerate factories whose objects have no close()
                close = getattr(env, "close", None)
                if callable(close):
                    close()

        for alpha, name, revenues, mean_coi, margins in measured:
            base = baseline_coi.get(name, mean_coi)
            coi_preserved = mean_coi / base if base > 0 else 1.0

            result = ProviderResult(
                name=name,
                alpha=alpha,
                total_revenue=float(np.sum(revenues)),
                # guard: np.mean([]) would yield NaN when n_episodes is 0
                mean_revenue=float(np.mean(revenues)) if revenues else 0.0,
                coi_level=mean_coi,
                coi_preserved_pct=coi_preserved * 100,
                margin_integrity=float(np.mean(margins)) if margins else 0.0,
                regret=0.0,  # compute vs optimal if known
                episodes=self.config.n_episodes,
            )
            self.results.append(result)

            # log to wandb if available
            if HAS_WANDB and wandb.run is not None:
                wandb.log(
                    {
                        f"benchmark/{name}/revenue": result.mean_revenue,
                        f"benchmark/{name}/coi_preserved": result.coi_preserved_pct,
                        f"benchmark/{name}/margin": result.margin_integrity,
                        "benchmark/alpha": alpha,
                    }
                )

        return self.results

    def to_dataframe(self) -> pd.DataFrame:
        """Convert results to pandas DataFrame."""
        return pd.DataFrame([r.__dict__ for r in self.results])

    def summary_table(self) -> pd.DataFrame:
        """Pivot table: providers x alpha with revenue/COI metrics.

        Returns an empty DataFrame when no results have been recorded yet.
        """
        df = self.to_dataframe()
        if df.empty:
            # pivot_table would fail on a frame that lacks the pivot columns
            return df
        return df.pivot_table(
            index="name",
            columns="alpha",
            values=["mean_revenue", "coi_preserved_pct", "margin_integrity"],
            aggfunc="mean",
        )
|
||||||
77
engine/lib/wrappers.py
Normal file
77
engine/lib/wrappers.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
"""Economic metrics wrapper - calculates thesis-aligned KPIs and injects into info dict."""
|
||||||
|
|
||||||
|
import gymnasium as gym
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
class EconomicMetricsWrapper(gym.Wrapper):
    """Calculates thesis-aligned economic metrics per step, injects into info.

    Metrics follow thesis definitions:
    - COI level: E[P] - p_min (Definition 1)
    - Margin: (avg_price - p_min) / avg_price
    - Regret: 1 - (revenue / baseline_revenue)

    Each step adds info["economics"] (revenue, margin, coi_level, regret, plus any
    of coi_mix / coi_base / coi_leakage / coi_penalty the wrapped env reported) and
    copies of the current prices and demand under info["prices"] / info["demand"].

    The unwrapped env must expose `_prices` (array-like) and `_demand` (a mapping
    from product index to demand).
    """

    def __init__(
        self, env: gym.Env, p_min: float = 10.0, baseline_revenue: float | None = None
    ):
        super().__init__(env)
        self.p_min = p_min
        self.baseline_revenue = baseline_revenue
        # per-episode histories, cleared on reset()
        self._price_history: list[np.ndarray] = []
        self._revenue_history: list[float] = []

    def reset(self, **kwargs):
        """Reset the wrapped env and clear the per-episode histories."""
        obs, info = self.env.reset(**kwargs)
        self._price_history = []
        self._revenue_history = []
        return obs, info

    def step(self, action):
        """Step the wrapped env, then attach the economic metrics to info."""
        obs, reward, terminated, truncated, info = self.env.step(action)

        # extract from unwrapped env
        base_env = self.env.unwrapped
        # coerce so the arithmetic and .copy() below also work for list-backed prices
        prices = np.asarray(base_env._prices)
        demand_dict = base_env._demand
        # products missing from the demand mapping count as zero demand
        demand = np.array([demand_dict.get(i, 0.0) for i in range(len(prices))])

        # core calculations
        revenue = float(np.sum(prices * demand))
        avg_price = float(np.mean(prices))
        # denominator is floored to avoid division by zero at avg_price == 0
        margin = (avg_price - self.p_min) / max(avg_price, 1e-6)
        coi_level = avg_price - self.p_min  # E[P] - p_min per thesis Def 1

        self._price_history.append(prices.copy())
        self._revenue_history.append(revenue)

        # regret vs baseline (golden path)
        # NOTE(review): this compares one step's revenue with baseline_revenue;
        # confirm the baseline is a per-step figure
        regret = 0.0
        if self.baseline_revenue and self.baseline_revenue > 0:
            regret = 1.0 - (revenue / self.baseline_revenue)

        # inject structured metrics into info
        info["economics"] = {
            "revenue": revenue,
            "margin": margin,
            "coi_level": coi_level,
            "regret": regret,
        }
        # forward optional COI diagnostics when the wrapped env reported them
        for key in ("coi_mix", "coi_base", "coi_leakage", "coi_penalty"):
            if key in info:
                info["economics"][key] = info[key]
        info["prices"] = prices.copy()
        info["demand"] = demand.copy()

        return obs, reward, terminated, truncated, info

    @property
    def episode_revenue(self) -> float:
        """Revenue accumulated since the last reset (0.0 before the first step)."""
        return float(sum(self._revenue_history))

    @property
    def episode_mean_price(self) -> float:
        """Mean of the per-step average prices since the last reset (0.0 if no steps)."""
        if not self._price_history:
            return 0.0
        return float(np.mean([np.mean(p) for p in self._price_history]))
|
||||||
84
engine/sweeps/model_mix.yaml
Normal file
84
engine/sweeps/model_mix.yaml
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
# Sweep definition: random search across the algorithms listed under `algo`,
# maximizing the sweep/score metric.
method: random
metric: {name: sweep/score, goal: maximize}
command:
  - ${env}
  - python
  - -m
  - engine.train
parameters:
  # run selection
  algo: {values: [ppo, a2c, dqn, qtable]}
  total_timesteps: {values: [30000, 50000, 80000]}
  seed: {values: [13, 42, 77]}

  # environment / objective
  n_products: {values: [8, 10, 12]}
  alpha: {distribution: uniform, min: 0.1, max: 0.6}
  lambda_coi: {distribution: uniform, min: 0.05, max: 0.6}
  robust_radius: {distribution: uniform, min: 0.0, max: 0.3}
  robust_points: {values: [3, 5, 7]}
  info_value: {distribution: uniform, min: 0.5, max: 2.0}
  revenue_weight: {values: [0.005, 0.01, 0.02]}

  # learner hyperparameters
  learning_rate: {distribution: log_uniform_values, min: 1.0e-5, max: 1.0e-3}
  gamma: {values: [0.97, 0.99, 0.995]}
  buffer_size: {values: [20000, 50000, 100000]}
  batch_size: {values: [128, 256, 512]}
  tau: {values: [0.002, 0.005, 0.01]}
  train_freq: {values: [1, 4, 8]}
  learning_starts: {values: [500, 1000, 3000]}
  n_steps: {values: [512, 1024, 2048]}
  n_epochs: {values: [5, 10, 20]}
  gae_lambda: {values: [0.9, 0.95, 0.98]}
  clip_range: {values: [0.1, 0.2, 0.3]}
  ent_coef: {values: [0.0, 0.005, 0.01]}
  target_update_interval: {values: [500, 1000, 2000]}
  exploration_fraction: {values: [0.1, 0.2, 0.3]}
  exploration_final_eps: {values: [0.01, 0.03, 0.05]}

  # discrete action grid
  action_levels: {values: [7, 9, 11]}
  action_scale_low: {values: [0.75, 0.8, 0.85]}
  action_scale_high: {values: [1.15, 1.2, 1.25]}

  # q-table learning rate and epsilon schedule
  q_lr: {values: [0.03, 0.05, 0.1, 0.2]}
  eps_start: {value: 1.0}
  eps_end: {values: [0.02, 0.05, 0.1]}
  eps_decay: {values: [0.999, 0.9995, 0.9999]}
|
||||||
85
engine/sweeps/models_only.yaml
Normal file
85
engine/sweeps/models_only.yaml
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# Sweep definition: grid over `algo` only. Every other parameter is pinned to a
# single value, so the grid has one run per algorithm (run_cap matches that count).
method: grid
metric: {name: sweep/score, goal: maximize}
run_cap: 4
command:
  - ${env}
  - python
  - -m
  - engine.train
parameters:
  # the only swept dimension
  algo: {values: [ppo, a2c, dqn, qtable]}

  # run budget, evaluation and logging cadence
  seed: {value: 42}
  total_timesteps: {value: 12000}
  eval_episodes: {value: 3}
  eval_freq: {value: 500}
  log_freq: {value: 100}

  # environment / objective
  revenue_weight: {value: 0.01}
  n_products: {value: 8}
  N: {value: 80}
  alpha: {value: 0.3}
  lambda_coi: {value: 0.2}
  robust_radius: {value: 0.0}
  robust_points: {value: 1}
  info_value: {value: 1.0}

  # learner hyperparameters
  learning_rate: {value: 0.0003}
  gamma: {value: 0.99}
  buffer_size: {value: 20000}
  batch_size: {value: 128}
  tau: {value: 0.005}
  train_freq: {value: 1}
  learning_starts: {value: 500}
  n_steps: {value: 512}
  n_epochs: {value: 10}
  gae_lambda: {value: 0.95}
  clip_range: {value: 0.2}
  ent_coef: {value: 0.0}
  target_update_interval: {value: 500}
  exploration_fraction: {value: 0.2}
  exploration_final_eps: {value: 0.05}

  # discrete action grid
  action_levels: {value: 7}
  action_scale_low: {value: 0.9}
  action_scale_high: {value: 1.1}

  # q-table settings and epsilon schedule
  q_lr: {value: 0.1}
  q_bins: {value: 6}
  eps_start: {value: 1.0}
  eps_end: {value: 0.05}
  eps_decay: {value: 0.9995}
|
||||||
54
engine/sweeps/sac_tune.yaml
Normal file
54
engine/sweeps/sac_tune.yaml
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
# Sweep definition: Bayesian search with the algorithm fixed to sac,
# maximizing the sweep/score metric.
method: bayes
metric: {name: sweep/score, goal: maximize}
command:
  - ${env}
  - python
  - -m
  - engine.train
parameters:
  # run selection
  algo: {value: sac}
  total_timesteps: {values: [50000, 80000, 120000]}
  seed: {values: [13, 42, 77]}

  # environment / objective
  alpha: {distribution: uniform, min: 0.15, max: 0.55}
  n_products: {values: [8, 10, 12]}
  lambda_coi: {distribution: uniform, min: 0.05, max: 0.5}
  robust_radius: {distribution: uniform, min: 0.05, max: 0.3}
  robust_points: {values: [3, 5, 7]}
  info_value: {distribution: uniform, min: 0.5, max: 2.0}
  revenue_weight: {values: [0.005, 0.01, 0.02]}

  # learner hyperparameters
  learning_rate: {distribution: log_uniform_values, min: 3.0e-5, max: 1.0e-3}
  gamma: {values: [0.98, 0.99, 0.995]}
  buffer_size: {values: [50000, 100000, 200000]}
  batch_size: {values: [128, 256, 512]}
  tau: {values: [0.002, 0.005, 0.01]}
  train_freq: {values: [1, 4, 8]}
  learning_starts: {values: [1000, 3000, 5000]}
|
||||||
86
engine/sweeps/small_arch_compare.yaml
Normal file
86
engine/sweeps/small_arch_compare.yaml
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Sweep definition: random search that additionally varies network size (`arch`)
# and activation, maximizing the sweep/score metric.
method: random
metric: {name: sweep/score, goal: maximize}
command:
  - ${env}
  - python
  - -m
  - engine.train
parameters:
  # run selection and architecture
  algo: {values: [ppo, a2c, dqn, qtable]}
  arch: {values: [tiny, small, medium]}
  activation: {values: [relu, tanh]}
  total_timesteps: {values: [8000, 12000, 20000]}
  seed: {values: [13, 42, 77]}

  # environment / objective
  n_products: {values: [6, 8, 10]}
  alpha: {distribution: uniform, min: 0.1, max: 0.5}
  lambda_coi: {distribution: uniform, min: 0.05, max: 0.4}
  robust_radius: {values: [0.0, 0.1, 0.2]}
  robust_points: {values: [3, 5]}
  info_value: {values: [0.75, 1.0, 1.5]}
  revenue_weight: {values: [0.005, 0.01, 0.02]}

  # learner hyperparameters
  learning_rate: {distribution: log_uniform_values, min: 1.0e-5, max: 5.0e-4}
  gamma: {values: [0.98, 0.99]}
  buffer_size: {values: [10000, 30000, 50000]}
  batch_size: {values: [64, 128, 256]}
  tau: {values: [0.002, 0.005, 0.01]}
  train_freq: {values: [1, 4]}
  learning_starts: {values: [500, 1000, 2000]}
  n_steps: {values: [256, 512, 1024]}
  n_epochs: {values: [5, 10]}
  gae_lambda: {values: [0.9, 0.95]}
  clip_range: {values: [0.1, 0.2]}
  ent_coef: {values: [0.0, 0.005]}
  target_update_interval: {values: [500, 1000]}
  exploration_fraction: {values: [0.1, 0.2]}
  exploration_final_eps: {values: [0.02, 0.05]}

  # discrete action grid
  action_levels: {values: [5, 7, 9]}
  action_scale_low: {values: [0.85, 0.9]}
  action_scale_high: {values: [1.1, 1.15]}

  # q-table settings and epsilon schedule
  q_lr: {values: [0.05, 0.1, 0.2]}
  q_bins: {values: [4, 6, 8]}
  eps_start: {value: 1.0}
  eps_end: {values: [0.02, 0.05]}
  eps_decay: {values: [0.999, 0.9995]}
|
||||||
@@ -1,8 +1,10 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from gymnasium.wrappers import FlattenObservation
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import wandb
|
import wandb
|
||||||
@@ -20,9 +22,7 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
HAS_SB3 = False
|
HAS_SB3 = False
|
||||||
|
|
||||||
from .wrapper import PHANTOM
|
from .jax import JAX_AVAILABLE
|
||||||
from .lib import EconomicMetricsWrapper, MetricsCallback
|
|
||||||
from .lib.discrete import EventQTable
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CFG = {
|
DEFAULT_CFG = {
|
||||||
@@ -69,14 +69,34 @@ DEFAULT_CFG = {
|
|||||||
"arch": "small",
|
"arch": "small",
|
||||||
"activation": "relu",
|
"activation": "relu",
|
||||||
"q_bins": 6,
|
"q_bins": 6,
|
||||||
|
"max_steps": 100,
|
||||||
|
"margin_floor": 0.05,
|
||||||
|
"margin_floor_patience": 5,
|
||||||
|
"use_jax": False,
|
||||||
|
"jax_num_envs": 16,
|
||||||
|
"jax_num_steps": 128,
|
||||||
|
"jax_num_minibatches": 4,
|
||||||
|
"jax_update_epochs": 4,
|
||||||
|
"jax_anneal_lr": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _truthy(value: str | bool | None) -> bool:
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||||
|
|
||||||
|
|
||||||
def _cfg(raw: dict | None = None) -> dict:
    """Build the effective training config from the defaults plus overrides.

    Starts from a copy of DEFAULT_CFG and applies every entry of `raw` whose
    value is not None. `algo` is lower-cased, and `use_jax` becomes True when
    either the config value or the PHANTOM_USE_JAX environment variable is truthy.
    """
    merged = dict(DEFAULT_CFG)
    if raw:
        for key, val in raw.items():
            # None means "not provided": keep the default
            if val is not None:
                merged[key] = val
    merged["algo"] = str(merged["algo"]).lower()
    env_flag = os.environ.get("PHANTOM_USE_JAX")
    merged["use_jax"] = _truthy(merged.get("use_jax")) or _truthy(env_flag)
    return merged
|
||||||
|
|
||||||
|
|
||||||
@@ -89,6 +109,11 @@ def _wandb_cfg_dict() -> dict:
|
|||||||
|
|
||||||
|
|
||||||
def make_env(cfg: dict):
|
def make_env(cfg: dict):
|
||||||
|
from gymnasium.wrappers import FlattenObservation
|
||||||
|
|
||||||
|
from .wrapper import PHANTOM
|
||||||
|
from .lib.wrappers import EconomicMetricsWrapper
|
||||||
|
|
||||||
env = PHANTOM(
|
env = PHANTOM(
|
||||||
n_products=int(cfg["n_products"]),
|
n_products=int(cfg["n_products"]),
|
||||||
alpha=float(cfg["alpha"]),
|
alpha=float(cfg["alpha"]),
|
||||||
@@ -101,6 +126,9 @@ def make_env(cfg: dict):
|
|||||||
action_levels=int(cfg["action_levels"]),
|
action_levels=int(cfg["action_levels"]),
|
||||||
action_scale_low=float(cfg["action_scale_low"]),
|
action_scale_low=float(cfg["action_scale_low"]),
|
||||||
action_scale_high=float(cfg["action_scale_high"]),
|
action_scale_high=float(cfg["action_scale_high"]),
|
||||||
|
max_steps=int(cfg.get("max_steps", 100)),
|
||||||
|
margin_floor=float(cfg.get("margin_floor", 0.05)),
|
||||||
|
margin_floor_patience=int(cfg.get("margin_floor_patience", 5)),
|
||||||
render_mode=None,
|
render_mode=None,
|
||||||
)
|
)
|
||||||
env = EconomicMetricsWrapper(env)
|
env = EconomicMetricsWrapper(env)
|
||||||
@@ -235,6 +263,8 @@ def build_model(cfg: dict, env):
|
|||||||
|
|
||||||
|
|
||||||
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
||||||
|
from .lib.discrete import EventQTable
|
||||||
|
|
||||||
np.random.seed(int(cfg["seed"]))
|
np.random.seed(int(cfg["seed"]))
|
||||||
env = make_env(cfg)
|
env = make_env(cfg)
|
||||||
eval_env = make_env(cfg)
|
eval_env = make_env(cfg)
|
||||||
@@ -275,6 +305,8 @@ def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
|||||||
def train_sb3(cfg: dict) -> tuple[object, dict]:
|
def train_sb3(cfg: dict) -> tuple[object, dict]:
|
||||||
if not HAS_SB3:
|
if not HAS_SB3:
|
||||||
raise ImportError("stable-baselines3 is required for SB3 models")
|
raise ImportError("stable-baselines3 is required for SB3 models")
|
||||||
|
from .lib.callbacks import MetricsCallback
|
||||||
|
|
||||||
env = make_env(cfg)
|
env = make_env(cfg)
|
||||||
eval_env = make_env(cfg)
|
eval_env = make_env(cfg)
|
||||||
env = Monitor(env)
|
env = Monitor(env)
|
||||||
@@ -303,7 +335,20 @@ def train_sb3(cfg: dict) -> tuple[object, dict]:
|
|||||||
|
|
||||||
def train_once(cfg: dict) -> dict:
|
def train_once(cfg: dict) -> dict:
|
||||||
algo = cfg["algo"]
|
algo = cfg["algo"]
|
||||||
if algo == "qtable":
|
if cfg.get("use_jax"):
|
||||||
|
if not JAX_AVAILABLE:
|
||||||
|
raise ImportError(
|
||||||
|
"JAX backend requested but JAX is not installed. "
|
||||||
|
"Install engine/jax/requirements.txt and jax[tpu] for TPU runs."
|
||||||
|
)
|
||||||
|
if algo == "qtable":
|
||||||
|
raise ValueError("qtable is not supported in JAX backend")
|
||||||
|
try:
|
||||||
|
from .jax.train import train_jax
|
||||||
|
except Exception as exc: # pragma: no cover
|
||||||
|
raise ImportError(f"Failed to import JAX trainer: {exc}") from exc
|
||||||
|
_, metrics = train_jax(cfg)
|
||||||
|
elif algo == "qtable":
|
||||||
_, metrics = train_qtable(cfg)
|
_, metrics = train_qtable(cfg)
|
||||||
else:
|
else:
|
||||||
_, metrics = train_sb3(cfg)
|
_, metrics = train_sb3(cfg)
|
||||||
@@ -357,8 +402,17 @@ def main():
|
|||||||
p.add_argument("--learning-rate", type=float)
|
p.add_argument("--learning-rate", type=float)
|
||||||
p.add_argument("--gamma", type=float)
|
p.add_argument("--gamma", type=float)
|
||||||
p.add_argument("--revenue-weight", type=float)
|
p.add_argument("--revenue-weight", type=float)
|
||||||
|
p.add_argument("--max-steps", type=int)
|
||||||
|
p.add_argument("--margin-floor", type=float)
|
||||||
|
p.add_argument("--margin-floor-patience", type=int)
|
||||||
p.add_argument("--arch", type=str)
|
p.add_argument("--arch", type=str)
|
||||||
p.add_argument("--activation", type=str)
|
p.add_argument("--activation", type=str)
|
||||||
|
p.add_argument("--jax", action="store_true")
|
||||||
|
p.add_argument("--jax-num-envs", type=int)
|
||||||
|
p.add_argument("--jax-num-steps", type=int)
|
||||||
|
p.add_argument("--jax-num-minibatches", type=int)
|
||||||
|
p.add_argument("--jax-update-epochs", type=int)
|
||||||
|
p.add_argument("--jax-anneal-lr", type=str)
|
||||||
p.add_argument("--sweep-agent", action="store_true")
|
p.add_argument("--sweep-agent", action="store_true")
|
||||||
p.add_argument("--sweep-id", type=str)
|
p.add_argument("--sweep-id", type=str)
|
||||||
p.add_argument("--count", type=int, default=0)
|
p.add_argument("--count", type=int, default=0)
|
||||||
@@ -377,8 +431,19 @@ def main():
|
|||||||
"learning_rate": args.learning_rate,
|
"learning_rate": args.learning_rate,
|
||||||
"gamma": args.gamma,
|
"gamma": args.gamma,
|
||||||
"revenue_weight": args.revenue_weight,
|
"revenue_weight": args.revenue_weight,
|
||||||
|
"max_steps": args.max_steps,
|
||||||
|
"margin_floor": args.margin_floor,
|
||||||
|
"margin_floor_patience": args.margin_floor_patience,
|
||||||
"arch": args.arch,
|
"arch": args.arch,
|
||||||
"activation": args.activation,
|
"activation": args.activation,
|
||||||
|
"use_jax": args.jax,
|
||||||
|
"jax_num_envs": args.jax_num_envs,
|
||||||
|
"jax_num_steps": args.jax_num_steps,
|
||||||
|
"jax_num_minibatches": args.jax_num_minibatches,
|
||||||
|
"jax_update_epochs": args.jax_update_epochs,
|
||||||
|
"jax_anneal_lr": _truthy(args.jax_anneal_lr)
|
||||||
|
if args.jax_anneal_lr is not None
|
||||||
|
else None,
|
||||||
}
|
}
|
||||||
overrides = {k: v for k, v in overrides.items() if v is not None}
|
overrides = {k: v for k, v in overrides.items() if v is not None}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user