mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chor: implementing prallelization across jax
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from sys import platform
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import numpy as np
|
||||
from .lib.demand import generate_demand_for_actor, estimate_demand
|
||||
from .lib.behavior import get_adjusted_transitions, sample_behavior_from_transitions
|
||||
@@ -7,6 +8,9 @@ from logging import INFO, getLogger
|
||||
logger = getLogger(__name__)
|
||||
logger.setLevel(INFO)
|
||||
|
||||
# shared pool; reused across act() calls to avoid per-call thread-spawn overhead
|
||||
_pool = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
|
||||
class MarketEngine:
|
||||
"""implements separate demand distributions for humans and agents per Section 3.1.1"""
|
||||
@@ -48,15 +52,18 @@ class MarketEngine:
|
||||
)
|
||||
human_transitions = get_adjusted_transitions(demand_h, human=True)
|
||||
agent_transitions = get_adjusted_transitions(demand_a, human=False)
|
||||
# sample behavior trajectories from each demand distribution
|
||||
human_t = [
|
||||
sample_behavior_from_transitions(human_transitions)
|
||||
# sample N trajectories in parallel; each chain is independent so threads
|
||||
# do not share state and numpy's per-call RNG is thread-safe
|
||||
h_futs = [
|
||||
_pool.submit(sample_behavior_from_transitions, human_transitions)
|
||||
for _ in range(self.Nhumans)
|
||||
]
|
||||
agent_t = [
|
||||
sample_behavior_from_transitions(agent_transitions)
|
||||
a_futs = [
|
||||
_pool.submit(sample_behavior_from_transitions, agent_transitions)
|
||||
for _ in range(self.Nagents)
|
||||
]
|
||||
human_t = [f.result() for f in h_futs]
|
||||
agent_t = [f.result() for f in a_futs]
|
||||
# store trajectories for agent probability calculation
|
||||
self.last_trajectories = human_t + agent_t
|
||||
return estimate_demand(self.last_trajectories, self.action_weights)
|
||||
|
||||
3
engine/jax/__init__.py
Normal file
3
engine/jax/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .robust import select_adversarial_alpha_jax, _JAX_OK
|
||||
|
||||
__all__ = ["select_adversarial_alpha_jax", "_JAX_OK"]
|
||||
176
engine/jax/robust.py
Normal file
176
engine/jax/robust.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""JAX-accelerated robust inner loop for PHANTOM.
|
||||
|
||||
provides a drop-in replacement for the sequential alpha-candidate evaluation in
|
||||
wrapper.py::_select_adversarial_alpha. the demand generation and reward
|
||||
computation are vmapped over the K candidate alpha values so all candidates are
|
||||
evaluated in a single vectorized pass instead of K sequential Python calls.
|
||||
|
||||
public surface:
|
||||
select_adversarial_alpha_jax(candidates, prices, human_params, agent_params,
|
||||
noise_std, n_sessions, n_products,
|
||||
baseline_prices, lambda_coi, info_value,
|
||||
reward_profit_weight, rng_key)
|
||||
-> (best_alpha: float, rewards: np.ndarray)
|
||||
|
||||
falls back gracefully when JAX is unavailable.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import vmap, jit
|
||||
|
||||
_JAX_OK = True
|
||||
except ImportError:
|
||||
_JAX_OK = False
|
||||
|
||||
|
||||
def _demand_for_actor_jax(prices, mean, std, noise_std, key):
|
||||
"""d(p;theta) = max(0, val - price + noise), normalized to sum 100."""
|
||||
k1, k2 = jax.random.split(key)
|
||||
val = jax.random.normal(k1, shape=prices.shape) * std + mean
|
||||
noise = jax.random.normal(k2, shape=prices.shape) * noise_std
|
||||
demand = jnp.maximum(0.0, val - prices + noise)
|
||||
total = demand.sum()
|
||||
return jnp.where(total > 0, demand / total * 100.0, demand)
|
||||
|
||||
|
||||
def _reward_for_candidate(
|
||||
alpha,
|
||||
prices,
|
||||
human_mean,
|
||||
human_std,
|
||||
agent_mean,
|
||||
agent_std,
|
||||
noise_std,
|
||||
baseline_prices,
|
||||
lambda_coi,
|
||||
info_value,
|
||||
reward_profit_weight,
|
||||
key,
|
||||
):
|
||||
"""compute a scalar reward for a single alpha candidate (pure JAX, vmappable)."""
|
||||
k_h, k_a = jax.random.split(key)
|
||||
# mixed demand proxy: weighted sum of human and agent demand signals
|
||||
demand_h = _demand_for_actor_jax(prices, human_mean, human_std, noise_std, k_h)
|
||||
demand_a = _demand_for_actor_jax(prices, agent_mean, agent_std, noise_std, k_a)
|
||||
demand = (1.0 - alpha) * demand_h + alpha * demand_a
|
||||
|
||||
revenue = jnp.dot(prices, demand)
|
||||
floor_cost = jnp.dot(baseline_prices, demand)
|
||||
profit = revenue - floor_cost
|
||||
|
||||
# agent_prob proxy: use alpha directly (no trajectory available in vectorized path)
|
||||
coi_leakage = alpha * info_value
|
||||
info_budget = jnp.maximum(floor_cost, 1.0)
|
||||
coi_penalty = lambda_coi * coi_leakage * info_budget
|
||||
|
||||
return reward_profit_weight * profit - coi_penalty
|
||||
|
||||
|
||||
if _JAX_OK:
|
||||
# compile once; retracing only happens on shape/dtype changes
|
||||
# 12 args: alpha, prices, h_mean, h_std, a_mean, a_std, noise_std,
|
||||
# baseline_prices, lambda_coi, info_value, reward_profit_weight, key
|
||||
_reward_batched = jit(
|
||||
vmap(
|
||||
_reward_for_candidate,
|
||||
in_axes=(0, None, None, None, None, None, None, None, None, None, None, 0),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def select_adversarial_alpha_jax(
|
||||
candidates: np.ndarray,
|
||||
prices: np.ndarray,
|
||||
human_params: tuple,
|
||||
agent_params: tuple,
|
||||
noise_std: float,
|
||||
baseline_prices: np.ndarray,
|
||||
lambda_coi: float,
|
||||
info_value: float,
|
||||
reward_profit_weight: float,
|
||||
rng_seed: int = 0,
|
||||
) -> tuple[float, np.ndarray]:
|
||||
"""evaluate all alpha candidates in a single vmapped pass.
|
||||
|
||||
returns (best_alpha, rewards_array) where best_alpha minimizes reward
|
||||
(worst case for the platform, driving robust policy training).
|
||||
|
||||
falls back to a pure-numpy sequential loop when JAX is unavailable so the
|
||||
wrapper can call this function unconditionally.
|
||||
"""
|
||||
if not _JAX_OK:
|
||||
return _fallback(
|
||||
candidates,
|
||||
prices,
|
||||
human_params,
|
||||
agent_params,
|
||||
noise_std,
|
||||
baseline_prices,
|
||||
lambda_coi,
|
||||
info_value,
|
||||
reward_profit_weight,
|
||||
)
|
||||
|
||||
k = len(candidates)
|
||||
key = jax.random.PRNGKey(rng_seed)
|
||||
keys = jax.random.split(key, k)
|
||||
|
||||
rewards = np.asarray(
|
||||
_reward_batched(
|
||||
jnp.asarray(candidates, dtype=jnp.float32),
|
||||
jnp.asarray(prices, dtype=jnp.float32),
|
||||
float(human_params[0]),
|
||||
float(human_params[1]),
|
||||
float(agent_params[0]),
|
||||
float(agent_params[1]),
|
||||
float(noise_std),
|
||||
jnp.asarray(baseline_prices, dtype=jnp.float32),
|
||||
float(lambda_coi),
|
||||
float(info_value),
|
||||
float(reward_profit_weight),
|
||||
keys,
|
||||
)
|
||||
)
|
||||
best_idx = int(np.argmin(rewards))
|
||||
return float(candidates[best_idx]), rewards
|
||||
|
||||
|
||||
def _fallback(
|
||||
candidates,
|
||||
prices,
|
||||
human_params,
|
||||
agent_params,
|
||||
noise_std,
|
||||
baseline_prices,
|
||||
lambda_coi,
|
||||
info_value,
|
||||
reward_profit_weight,
|
||||
):
|
||||
"""numpy fallback matching the reward formula above."""
|
||||
rewards = []
|
||||
for alpha in candidates:
|
||||
rng = np.random.default_rng()
|
||||
val_h = rng.normal(*human_params, size=len(prices))
|
||||
val_a = rng.normal(*agent_params, size=len(prices))
|
||||
noise_h = rng.normal(0, noise_std, len(prices))
|
||||
noise_a = rng.normal(0, noise_std, len(prices))
|
||||
d_h = np.maximum(0, val_h - prices + noise_h)
|
||||
d_a = np.maximum(0, val_a - prices + noise_a)
|
||||
s_h, s_a = d_h.sum(), d_a.sum()
|
||||
d_h = d_h / s_h * 100 if s_h > 0 else d_h
|
||||
d_a = d_a / s_a * 100 if s_a > 0 else d_a
|
||||
demand = (1.0 - alpha) * d_h + alpha * d_a
|
||||
revenue = float(np.dot(prices, demand))
|
||||
floor_cost = float(np.dot(baseline_prices, demand))
|
||||
profit = revenue - floor_cost
|
||||
coi_penalty = lambda_coi * alpha * info_value * max(floor_cost, 1.0)
|
||||
rewards.append(reward_profit_weight * profit - coi_penalty)
|
||||
rewards = np.array(rewards)
|
||||
best_idx = int(np.argmin(rewards))
|
||||
return float(candidates[best_idx]), rewards
|
||||
@@ -22,6 +22,9 @@ human_dir = str(base_dir / "collected_data")
|
||||
agent_dir = str(base_dir / "agents" / "collected_data")
|
||||
|
||||
_cache = {} # lazy cache for models and base pivots
|
||||
# cache keyed by (human: bool, condition_tuple) so we skip Kronecker re-expansion
|
||||
# for repeated calls with the same demand condition inside the robustness inner loop
|
||||
_transition_cache: dict = {}
|
||||
|
||||
|
||||
def _get_base_pivot(human: bool):
|
||||
@@ -68,22 +71,41 @@ def trajectory_to_events(trajectory: list) -> list:
|
||||
"""extract event names from trajectory for KL divergence calculation
|
||||
|
||||
trajectories are in format 'eventName_product0', extract just eventName
|
||||
|
||||
args:
|
||||
trajectory: list like ['view_product0', 'add_to_cart_product1', 'checkout_product1']
|
||||
|
||||
returns:
|
||||
list: event names like ['view', 'add_to_cart', 'checkout']
|
||||
"""
|
||||
events = []
|
||||
for state in trajectory:
|
||||
# state format from sample_behavior: 'eventName_productX'
|
||||
if "_product" in state:
|
||||
event = state.rsplit("_product", 1)[0]
|
||||
else:
|
||||
event = state
|
||||
events.append(event)
|
||||
return events
|
||||
return [s.rsplit("_product", 1)[0] if "_product" in s else s for s in trajectory]
|
||||
|
||||
|
||||
class _TransitionTable:
|
||||
"""numpy-backed transition table; replaces per-step pandas .loc[] indexing.
|
||||
|
||||
the profiling hotspot was DataFrame.xs called ~4-16k times per outer step.
|
||||
converting once to a dense float32 array with an int-keyed state index map
|
||||
reduces each row lookup to a single array slice with no pandas overhead.
|
||||
rows are pre-normalized so sampling requires no per-step division.
|
||||
"""
|
||||
|
||||
__slots__ = ("matrix", "states", "state_index", "n_states")
|
||||
|
||||
def __init__(self, df: pd.DataFrame):
|
||||
self.states: list[str] = df.index.tolist()
|
||||
self.state_index: dict[str, int] = {s: i for i, s in enumerate(self.states)}
|
||||
# float64 throughout: float32 row-sums can drift enough to break np.random.choice
|
||||
mat = np.nan_to_num(
|
||||
df.values.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0
|
||||
)
|
||||
mat = np.clip(mat, 0.0, None)
|
||||
row_sums = mat.sum(axis=1)
|
||||
# dead rows (all zero) get uniform distribution so sampling never receives NaN
|
||||
dead = row_sums <= 0
|
||||
mat[dead] = 1.0
|
||||
row_sums[dead] = float(mat.shape[1])
|
||||
mat = mat / row_sums[:, np.newaxis]
|
||||
# final nan guard in case fp still drifts
|
||||
np.nan_to_num(mat, nan=0.0, copy=False)
|
||||
row_sums2 = mat.sum(axis=1, keepdims=True)
|
||||
row_sums2[row_sums2 <= 0] = 1.0
|
||||
self.matrix: np.ndarray = mat / row_sums2
|
||||
self.n_states: int = len(self.states)
|
||||
|
||||
|
||||
def adjust_behavior_to_condition(condition, transition_matrix):
|
||||
@@ -92,46 +114,68 @@ def adjust_behavior_to_condition(condition, transition_matrix):
|
||||
condition = np.nan_to_num(condition, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
condition = np.clip(condition, 0.0, None)
|
||||
s = float(np.sum(condition))
|
||||
if not np.isfinite(s) or s <= 0:
|
||||
cond_norm = np.full(len(condition), 1.0 / max(len(condition), 1), dtype=float)
|
||||
else:
|
||||
cond_norm = condition / s
|
||||
cond_norm = (
|
||||
condition / s
|
||||
if np.isfinite(s) and s > 0
|
||||
else np.full(len(condition), 1.0 / max(len(condition), 1), dtype=float)
|
||||
)
|
||||
n_products = len(condition)
|
||||
base_vals = transition_matrix.values
|
||||
base_cols, base_rows = (
|
||||
transition_matrix.columns.tolist(),
|
||||
transition_matrix.index.tolist(),
|
||||
)
|
||||
|
||||
# expand via kronecker-like tiling: each cell becomes a P*P block weighted by outer product of cond_norm
|
||||
expanded = np.kron(base_vals, np.outer(cond_norm, cond_norm))
|
||||
new_cols = [f"{c}_product{p}" for c in base_cols for p in range(n_products)]
|
||||
new_rows = [f"{r}_product{p}" for r in base_rows for p in range(n_products)]
|
||||
return pd.DataFrame(expanded, index=new_rows, columns=new_cols)
|
||||
|
||||
|
||||
def get_adjusted_transitions(condition, human=True):
|
||||
def get_adjusted_transitions(condition, human=True) -> _TransitionTable:
|
||||
"""return a _TransitionTable for the given demand condition.
|
||||
|
||||
results are cached by (human, rounded-condition) so that repeated calls with
|
||||
the same condition inside the robustness inner loop (K candidates, same prices)
|
||||
skip the Kronecker expansion entirely.
|
||||
"""
|
||||
condition = np.asarray(condition, dtype=float)
|
||||
# round to 4 significant digits for cache key stability
|
||||
cache_key = (human, tuple(np.round(condition, 4).tolist()))
|
||||
if cache_key in _transition_cache:
|
||||
return _transition_cache[cache_key]
|
||||
base_pivot = _get_base_pivot(human)
|
||||
return adjust_behavior_to_condition(condition, base_pivot)
|
||||
df = adjust_behavior_to_condition(condition, base_pivot)
|
||||
table = _TransitionTable(df)
|
||||
_transition_cache[cache_key] = table
|
||||
return table
|
||||
|
||||
|
||||
def sample_behavior_from_transitions(adjusted_transitions, max_len=40):
|
||||
trajectory = [np.random.choice(adjusted_transitions.index)]
|
||||
def clear_transition_cache():
|
||||
"""drop cached transition tables; call between episodes if condition space is large."""
|
||||
_transition_cache.clear()
|
||||
|
||||
|
||||
def sample_behavior_from_transitions(table, max_len=40):
|
||||
"""sample a Markov trajectory.
|
||||
|
||||
accepts _TransitionTable (fast path) or a legacy pandas DataFrame so existing
|
||||
call sites that pass a DataFrame directly continue to work unchanged.
|
||||
"""
|
||||
if isinstance(table, pd.DataFrame):
|
||||
table = _TransitionTable(table)
|
||||
|
||||
idx = np.random.randint(table.n_states)
|
||||
trajectory = [table.states[idx]]
|
||||
while len(trajectory) < max_len and "checkout" not in trajectory[-1]:
|
||||
probs = np.asarray(adjusted_transitions.loc[trajectory[-1]].values, dtype=float)
|
||||
probs = np.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
probs = np.clip(probs, 0.0, None)
|
||||
s = float(np.sum(probs))
|
||||
sample = np.random.choice(
|
||||
adjusted_transitions.columns, p=(probs / s) if s > 0 else None
|
||||
)
|
||||
trajectory.append(sample)
|
||||
row = table.matrix[table.state_index[trajectory[-1]]]
|
||||
idx = int(np.random.choice(table.n_states, p=row))
|
||||
trajectory.append(table.states[idx])
|
||||
return trajectory
|
||||
|
||||
|
||||
def sample_behavior(condition, human=True, max_len=40):
|
||||
adjusted_transitions = get_adjusted_transitions(condition, human=human)
|
||||
return sample_behavior_from_transitions(adjusted_transitions, max_len=max_len)
|
||||
table = get_adjusted_transitions(condition, human=human)
|
||||
return sample_behavior_from_transitions(table, max_len=max_len)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -10,6 +10,7 @@ from .lib.coi import (
|
||||
)
|
||||
from .lib.behavior import get_transition_models, trajectory_to_events
|
||||
from .lib.wrappers import EconomicMetricsWrapper
|
||||
from .jax.robust import select_adversarial_alpha_jax, _JAX_OK
|
||||
|
||||
|
||||
class _ActionPricingEngine(PricingEngine):
|
||||
@@ -121,6 +122,7 @@ class PHANTOM(gym.Env):
|
||||
self._prices = None
|
||||
self._demand = None
|
||||
self._step_count = 0
|
||||
self._global_step = 0 # monotonic; used as JAX RNG seed across resets
|
||||
self._demand_history = []
|
||||
self._price_history = []
|
||||
self._revenue_history = []
|
||||
@@ -261,8 +263,37 @@ class PHANTOM(gym.Env):
|
||||
return float(np.mean(rewards)) if rewards else 0.0
|
||||
|
||||
def _select_adversarial_alpha(self, prices: np.ndarray) -> float:
|
||||
"""inner robust step: evaluate candidates and pick worst-case alpha"""
|
||||
"""inner robust step: pick worst-case alpha from the ambiguity interval.
|
||||
|
||||
when JAX is available and robust_rollouts==1 we use a vmapped pass over
|
||||
all K candidates in a single call (no Python loop, no market.act overhead).
|
||||
the JAX path approximates demand as the mixed closed-form d(p;theta) signal
|
||||
rather than running full trajectory sampling, which is accurate for the
|
||||
alpha-selection decision while being dramatically cheaper.
|
||||
|
||||
when robust_rollouts>1 or JAX is unavailable we fall back to the sequential
|
||||
market.act() loop so behavior is identical to the original implementation.
|
||||
"""
|
||||
candidates = self._alpha_candidates()
|
||||
if len(candidates) == 1:
|
||||
return float(candidates[0])
|
||||
|
||||
if _JAX_OK and self.robust_rollouts == 1:
|
||||
best_alpha, _ = select_adversarial_alpha_jax(
|
||||
candidates=candidates,
|
||||
prices=prices,
|
||||
human_params=self.market.human_params,
|
||||
agent_params=self.market.agent_params,
|
||||
noise_std=self.market.noise_std,
|
||||
baseline_prices=self.baseline_prices,
|
||||
lambda_coi=self.lambda_coi,
|
||||
info_value=self.info_value,
|
||||
reward_profit_weight=self.reward_profit_weight,
|
||||
rng_seed=self._global_step,
|
||||
)
|
||||
return best_alpha
|
||||
|
||||
# fallback: full trajectory-based sequential evaluation
|
||||
evaluations = [
|
||||
(float(alpha), self._evaluate_candidate(float(alpha), prices))
|
||||
for alpha in candidates
|
||||
@@ -299,6 +330,7 @@ class PHANTOM(gym.Env):
|
||||
def step(self, action):
|
||||
self._prices = self._decode_action(action)
|
||||
alpha_adv = self._select_adversarial_alpha(self._prices)
|
||||
self._global_step += 1 # always increment; JAX path may have already done so
|
||||
self._set_market_mix(alpha_adv)
|
||||
self._platform_stub.set_prices(self._prices)
|
||||
self._step_count += 1
|
||||
|
||||
Reference in New Issue
Block a user