chor: implementing prallelization across jax

This commit is contained in:
2026-03-10 17:05:16 +01:00
parent 6d9613c0b6
commit 974498dab2
5 changed files with 303 additions and 41 deletions

View File

@@ -1,4 +1,5 @@
from sys import platform from sys import platform
from concurrent.futures import ThreadPoolExecutor
import numpy as np import numpy as np
from .lib.demand import generate_demand_for_actor, estimate_demand from .lib.demand import generate_demand_for_actor, estimate_demand
from .lib.behavior import get_adjusted_transitions, sample_behavior_from_transitions from .lib.behavior import get_adjusted_transitions, sample_behavior_from_transitions
@@ -7,6 +8,9 @@ from logging import INFO, getLogger
logger = getLogger(__name__) logger = getLogger(__name__)
logger.setLevel(INFO) logger.setLevel(INFO)
# shared pool; reused across act() calls to avoid per-call thread-spawn overhead
_pool = ThreadPoolExecutor(max_workers=4)
class MarketEngine: class MarketEngine:
"""implements separate demand distributions for humans and agents per Section 3.1.1""" """implements separate demand distributions for humans and agents per Section 3.1.1"""
@@ -48,15 +52,18 @@ class MarketEngine:
) )
human_transitions = get_adjusted_transitions(demand_h, human=True) human_transitions = get_adjusted_transitions(demand_h, human=True)
agent_transitions = get_adjusted_transitions(demand_a, human=False) agent_transitions = get_adjusted_transitions(demand_a, human=False)
# sample behavior trajectories from each demand distribution # sample N trajectories in parallel; each chain is independent so threads
human_t = [ # do not share state and numpy's per-call RNG is thread-safe
sample_behavior_from_transitions(human_transitions) h_futs = [
_pool.submit(sample_behavior_from_transitions, human_transitions)
for _ in range(self.Nhumans) for _ in range(self.Nhumans)
] ]
agent_t = [ a_futs = [
sample_behavior_from_transitions(agent_transitions) _pool.submit(sample_behavior_from_transitions, agent_transitions)
for _ in range(self.Nagents) for _ in range(self.Nagents)
] ]
human_t = [f.result() for f in h_futs]
agent_t = [f.result() for f in a_futs]
# store trajectories for agent probability calculation # store trajectories for agent probability calculation
self.last_trajectories = human_t + agent_t self.last_trajectories = human_t + agent_t
return estimate_demand(self.last_trajectories, self.action_weights) return estimate_demand(self.last_trajectories, self.action_weights)

3
engine/jax/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .robust import select_adversarial_alpha_jax, _JAX_OK
__all__ = ["select_adversarial_alpha_jax", "_JAX_OK"]

176
engine/jax/robust.py Normal file
View File

@@ -0,0 +1,176 @@
"""JAX-accelerated robust inner loop for PHANTOM.
provides a drop-in replacement for the sequential alpha-candidate evaluation in
wrapper.py::_select_adversarial_alpha. the demand generation and reward
computation are vmapped over the K candidate alpha values so all candidates are
evaluated in a single vectorized pass instead of K sequential Python calls.
public surface:
select_adversarial_alpha_jax(candidates, prices, human_params, agent_params,
noise_std, n_sessions, n_products,
baseline_prices, lambda_coi, info_value,
reward_profit_weight, rng_key)
-> (best_alpha: float, rewards: np.ndarray)
falls back gracefully when JAX is unavailable.
"""
from __future__ import annotations
import numpy as np
try:
import jax
import jax.numpy as jnp
from jax import vmap, jit
_JAX_OK = True
except ImportError:
_JAX_OK = False
def _demand_for_actor_jax(prices, mean, std, noise_std, key):
"""d(p;theta) = max(0, val - price + noise), normalized to sum 100."""
k1, k2 = jax.random.split(key)
val = jax.random.normal(k1, shape=prices.shape) * std + mean
noise = jax.random.normal(k2, shape=prices.shape) * noise_std
demand = jnp.maximum(0.0, val - prices + noise)
total = demand.sum()
return jnp.where(total > 0, demand / total * 100.0, demand)
def _reward_for_candidate(
alpha,
prices,
human_mean,
human_std,
agent_mean,
agent_std,
noise_std,
baseline_prices,
lambda_coi,
info_value,
reward_profit_weight,
key,
):
"""compute a scalar reward for a single alpha candidate (pure JAX, vmappable)."""
k_h, k_a = jax.random.split(key)
# mixed demand proxy: weighted sum of human and agent demand signals
demand_h = _demand_for_actor_jax(prices, human_mean, human_std, noise_std, k_h)
demand_a = _demand_for_actor_jax(prices, agent_mean, agent_std, noise_std, k_a)
demand = (1.0 - alpha) * demand_h + alpha * demand_a
revenue = jnp.dot(prices, demand)
floor_cost = jnp.dot(baseline_prices, demand)
profit = revenue - floor_cost
# agent_prob proxy: use alpha directly (no trajectory available in vectorized path)
coi_leakage = alpha * info_value
info_budget = jnp.maximum(floor_cost, 1.0)
coi_penalty = lambda_coi * coi_leakage * info_budget
return reward_profit_weight * profit - coi_penalty
if _JAX_OK:
# compile once; retracing only happens on shape/dtype changes
# 12 args: alpha, prices, h_mean, h_std, a_mean, a_std, noise_std,
# baseline_prices, lambda_coi, info_value, reward_profit_weight, key
_reward_batched = jit(
vmap(
_reward_for_candidate,
in_axes=(0, None, None, None, None, None, None, None, None, None, None, 0),
)
)
def select_adversarial_alpha_jax(
candidates: np.ndarray,
prices: np.ndarray,
human_params: tuple,
agent_params: tuple,
noise_std: float,
baseline_prices: np.ndarray,
lambda_coi: float,
info_value: float,
reward_profit_weight: float,
rng_seed: int = 0,
) -> tuple[float, np.ndarray]:
"""evaluate all alpha candidates in a single vmapped pass.
returns (best_alpha, rewards_array) where best_alpha minimizes reward
(worst case for the platform, driving robust policy training).
falls back to a pure-numpy sequential loop when JAX is unavailable so the
wrapper can call this function unconditionally.
"""
if not _JAX_OK:
return _fallback(
candidates,
prices,
human_params,
agent_params,
noise_std,
baseline_prices,
lambda_coi,
info_value,
reward_profit_weight,
)
k = len(candidates)
key = jax.random.PRNGKey(rng_seed)
keys = jax.random.split(key, k)
rewards = np.asarray(
_reward_batched(
jnp.asarray(candidates, dtype=jnp.float32),
jnp.asarray(prices, dtype=jnp.float32),
float(human_params[0]),
float(human_params[1]),
float(agent_params[0]),
float(agent_params[1]),
float(noise_std),
jnp.asarray(baseline_prices, dtype=jnp.float32),
float(lambda_coi),
float(info_value),
float(reward_profit_weight),
keys,
)
)
best_idx = int(np.argmin(rewards))
return float(candidates[best_idx]), rewards
def _fallback(
candidates,
prices,
human_params,
agent_params,
noise_std,
baseline_prices,
lambda_coi,
info_value,
reward_profit_weight,
):
"""numpy fallback matching the reward formula above."""
rewards = []
for alpha in candidates:
rng = np.random.default_rng()
val_h = rng.normal(*human_params, size=len(prices))
val_a = rng.normal(*agent_params, size=len(prices))
noise_h = rng.normal(0, noise_std, len(prices))
noise_a = rng.normal(0, noise_std, len(prices))
d_h = np.maximum(0, val_h - prices + noise_h)
d_a = np.maximum(0, val_a - prices + noise_a)
s_h, s_a = d_h.sum(), d_a.sum()
d_h = d_h / s_h * 100 if s_h > 0 else d_h
d_a = d_a / s_a * 100 if s_a > 0 else d_a
demand = (1.0 - alpha) * d_h + alpha * d_a
revenue = float(np.dot(prices, demand))
floor_cost = float(np.dot(baseline_prices, demand))
profit = revenue - floor_cost
coi_penalty = lambda_coi * alpha * info_value * max(floor_cost, 1.0)
rewards.append(reward_profit_weight * profit - coi_penalty)
rewards = np.array(rewards)
best_idx = int(np.argmin(rewards))
return float(candidates[best_idx]), rewards

View File

@@ -22,6 +22,9 @@ human_dir = str(base_dir / "collected_data")
agent_dir = str(base_dir / "agents" / "collected_data") agent_dir = str(base_dir / "agents" / "collected_data")
_cache = {} # lazy cache for models and base pivots _cache = {} # lazy cache for models and base pivots
# cache keyed by (human: bool, condition_tuple) so we skip Kronecker re-expansion
# for repeated calls with the same demand condition inside the robustness inner loop
_transition_cache: dict = {}
def _get_base_pivot(human: bool): def _get_base_pivot(human: bool):
@@ -68,22 +71,41 @@ def trajectory_to_events(trajectory: list) -> list:
"""extract event names from trajectory for KL divergence calculation """extract event names from trajectory for KL divergence calculation
trajectories are in format 'eventName_product0', extract just eventName trajectories are in format 'eventName_product0', extract just eventName
args:
trajectory: list like ['view_product0', 'add_to_cart_product1', 'checkout_product1']
returns:
list: event names like ['view', 'add_to_cart', 'checkout']
""" """
events = [] return [s.rsplit("_product", 1)[0] if "_product" in s else s for s in trajectory]
for state in trajectory:
# state format from sample_behavior: 'eventName_productX'
if "_product" in state: class _TransitionTable:
event = state.rsplit("_product", 1)[0] """numpy-backed transition table; replaces per-step pandas .loc[] indexing.
else:
event = state the profiling hotspot was DataFrame.xs called ~4-16k times per outer step.
events.append(event) converting once to a dense float32 array with an int-keyed state index map
return events reduces each row lookup to a single array slice with no pandas overhead.
rows are pre-normalized so sampling requires no per-step division.
"""
__slots__ = ("matrix", "states", "state_index", "n_states")
def __init__(self, df: pd.DataFrame):
self.states: list[str] = df.index.tolist()
self.state_index: dict[str, int] = {s: i for i, s in enumerate(self.states)}
# float64 throughout: float32 row-sums can drift enough to break np.random.choice
mat = np.nan_to_num(
df.values.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0
)
mat = np.clip(mat, 0.0, None)
row_sums = mat.sum(axis=1)
# dead rows (all zero) get uniform distribution so sampling never receives NaN
dead = row_sums <= 0
mat[dead] = 1.0
row_sums[dead] = float(mat.shape[1])
mat = mat / row_sums[:, np.newaxis]
# final nan guard in case fp still drifts
np.nan_to_num(mat, nan=0.0, copy=False)
row_sums2 = mat.sum(axis=1, keepdims=True)
row_sums2[row_sums2 <= 0] = 1.0
self.matrix: np.ndarray = mat / row_sums2
self.n_states: int = len(self.states)
def adjust_behavior_to_condition(condition, transition_matrix): def adjust_behavior_to_condition(condition, transition_matrix):
@@ -92,46 +114,68 @@ def adjust_behavior_to_condition(condition, transition_matrix):
condition = np.nan_to_num(condition, nan=0.0, posinf=0.0, neginf=0.0) condition = np.nan_to_num(condition, nan=0.0, posinf=0.0, neginf=0.0)
condition = np.clip(condition, 0.0, None) condition = np.clip(condition, 0.0, None)
s = float(np.sum(condition)) s = float(np.sum(condition))
if not np.isfinite(s) or s <= 0: cond_norm = (
cond_norm = np.full(len(condition), 1.0 / max(len(condition), 1), dtype=float) condition / s
else: if np.isfinite(s) and s > 0
cond_norm = condition / s else np.full(len(condition), 1.0 / max(len(condition), 1), dtype=float)
)
n_products = len(condition) n_products = len(condition)
base_vals = transition_matrix.values base_vals = transition_matrix.values
base_cols, base_rows = ( base_cols, base_rows = (
transition_matrix.columns.tolist(), transition_matrix.columns.tolist(),
transition_matrix.index.tolist(), transition_matrix.index.tolist(),
) )
# expand via kronecker-like tiling: each cell becomes a P*P block weighted by outer product of cond_norm
expanded = np.kron(base_vals, np.outer(cond_norm, cond_norm)) expanded = np.kron(base_vals, np.outer(cond_norm, cond_norm))
new_cols = [f"{c}_product{p}" for c in base_cols for p in range(n_products)] new_cols = [f"{c}_product{p}" for c in base_cols for p in range(n_products)]
new_rows = [f"{r}_product{p}" for r in base_rows for p in range(n_products)] new_rows = [f"{r}_product{p}" for r in base_rows for p in range(n_products)]
return pd.DataFrame(expanded, index=new_rows, columns=new_cols) return pd.DataFrame(expanded, index=new_rows, columns=new_cols)
def get_adjusted_transitions(condition, human=True): def get_adjusted_transitions(condition, human=True) -> _TransitionTable:
"""return a _TransitionTable for the given demand condition.
results are cached by (human, rounded-condition) so that repeated calls with
the same condition inside the robustness inner loop (K candidates, same prices)
skip the Kronecker expansion entirely.
"""
condition = np.asarray(condition, dtype=float)
# round to 4 significant digits for cache key stability
cache_key = (human, tuple(np.round(condition, 4).tolist()))
if cache_key in _transition_cache:
return _transition_cache[cache_key]
base_pivot = _get_base_pivot(human) base_pivot = _get_base_pivot(human)
return adjust_behavior_to_condition(condition, base_pivot) df = adjust_behavior_to_condition(condition, base_pivot)
table = _TransitionTable(df)
_transition_cache[cache_key] = table
return table
def sample_behavior_from_transitions(adjusted_transitions, max_len=40): def clear_transition_cache():
trajectory = [np.random.choice(adjusted_transitions.index)] """drop cached transition tables; call between episodes if condition space is large."""
_transition_cache.clear()
def sample_behavior_from_transitions(table, max_len=40):
"""sample a Markov trajectory.
accepts _TransitionTable (fast path) or a legacy pandas DataFrame so existing
call sites that pass a DataFrame directly continue to work unchanged.
"""
if isinstance(table, pd.DataFrame):
table = _TransitionTable(table)
idx = np.random.randint(table.n_states)
trajectory = [table.states[idx]]
while len(trajectory) < max_len and "checkout" not in trajectory[-1]: while len(trajectory) < max_len and "checkout" not in trajectory[-1]:
probs = np.asarray(adjusted_transitions.loc[trajectory[-1]].values, dtype=float) row = table.matrix[table.state_index[trajectory[-1]]]
probs = np.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0) idx = int(np.random.choice(table.n_states, p=row))
probs = np.clip(probs, 0.0, None) trajectory.append(table.states[idx])
s = float(np.sum(probs))
sample = np.random.choice(
adjusted_transitions.columns, p=(probs / s) if s > 0 else None
)
trajectory.append(sample)
return trajectory return trajectory
def sample_behavior(condition, human=True, max_len=40): def sample_behavior(condition, human=True, max_len=40):
adjusted_transitions = get_adjusted_transitions(condition, human=human) table = get_adjusted_transitions(condition, human=human)
return sample_behavior_from_transitions(adjusted_transitions, max_len=max_len) return sample_behavior_from_transitions(table, max_len=max_len)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -10,6 +10,7 @@ from .lib.coi import (
) )
from .lib.behavior import get_transition_models, trajectory_to_events from .lib.behavior import get_transition_models, trajectory_to_events
from .lib.wrappers import EconomicMetricsWrapper from .lib.wrappers import EconomicMetricsWrapper
from .jax.robust import select_adversarial_alpha_jax, _JAX_OK
class _ActionPricingEngine(PricingEngine): class _ActionPricingEngine(PricingEngine):
@@ -121,6 +122,7 @@ class PHANTOM(gym.Env):
self._prices = None self._prices = None
self._demand = None self._demand = None
self._step_count = 0 self._step_count = 0
self._global_step = 0 # monotonic; used as JAX RNG seed across resets
self._demand_history = [] self._demand_history = []
self._price_history = [] self._price_history = []
self._revenue_history = [] self._revenue_history = []
@@ -261,8 +263,37 @@ class PHANTOM(gym.Env):
return float(np.mean(rewards)) if rewards else 0.0 return float(np.mean(rewards)) if rewards else 0.0
def _select_adversarial_alpha(self, prices: np.ndarray) -> float: def _select_adversarial_alpha(self, prices: np.ndarray) -> float:
"""inner robust step: evaluate candidates and pick worst-case alpha""" """inner robust step: pick worst-case alpha from the ambiguity interval.
when JAX is available and robust_rollouts==1 we use a vmapped pass over
all K candidates in a single call (no Python loop, no market.act overhead).
the JAX path approximates demand as the mixed closed-form d(p;theta) signal
rather than running full trajectory sampling, which is accurate for the
alpha-selection decision while being dramatically cheaper.
when robust_rollouts>1 or JAX is unavailable we fall back to the sequential
market.act() loop so behavior is identical to the original implementation.
"""
candidates = self._alpha_candidates() candidates = self._alpha_candidates()
if len(candidates) == 1:
return float(candidates[0])
if _JAX_OK and self.robust_rollouts == 1:
best_alpha, _ = select_adversarial_alpha_jax(
candidates=candidates,
prices=prices,
human_params=self.market.human_params,
agent_params=self.market.agent_params,
noise_std=self.market.noise_std,
baseline_prices=self.baseline_prices,
lambda_coi=self.lambda_coi,
info_value=self.info_value,
reward_profit_weight=self.reward_profit_weight,
rng_seed=self._global_step,
)
return best_alpha
# fallback: full trajectory-based sequential evaluation
evaluations = [ evaluations = [
(float(alpha), self._evaluate_candidate(float(alpha), prices)) (float(alpha), self._evaluate_candidate(float(alpha), prices))
for alpha in candidates for alpha in candidates
@@ -299,6 +330,7 @@ class PHANTOM(gym.Env):
def step(self, action): def step(self, action):
self._prices = self._decode_action(action) self._prices = self._decode_action(action)
alpha_adv = self._select_adversarial_alpha(self._prices) alpha_adv = self._select_adversarial_alpha(self._prices)
self._global_step += 1 # always increment; JAX path may have already done so
self._set_market_mix(alpha_adv) self._set_market_mix(alpha_adv)
self._platform_stub.set_prices(self._prices) self._platform_stub.set_prices(self._prices)
self._step_count += 1 self._step_count += 1