diff --git a/engine/engine.py b/engine/engine.py index b4a2cbc..81a4da7 100644 --- a/engine/engine.py +++ b/engine/engine.py @@ -1,4 +1,5 @@ from sys import platform +from concurrent.futures import ThreadPoolExecutor import numpy as np from .lib.demand import generate_demand_for_actor, estimate_demand from .lib.behavior import get_adjusted_transitions, sample_behavior_from_transitions @@ -7,6 +8,9 @@ from logging import INFO, getLogger logger = getLogger(__name__) logger.setLevel(INFO) +# shared pool; reused across act() calls to avoid per-call thread-spawn overhead +_pool = ThreadPoolExecutor(max_workers=4) + class MarketEngine: """implements separate demand distributions for humans and agents per Section 3.1.1""" @@ -48,15 +52,18 @@ class MarketEngine: ) human_transitions = get_adjusted_transitions(demand_h, human=True) agent_transitions = get_adjusted_transitions(demand_a, human=False) - # sample behavior trajectories from each demand distribution - human_t = [ - sample_behavior_from_transitions(human_transitions) + # sample N trajectories in parallel; each chain is independent so threads + # do not share state and numpy's per-call RNG is thread-safe + h_futs = [ + _pool.submit(sample_behavior_from_transitions, human_transitions) for _ in range(self.Nhumans) ] - agent_t = [ - sample_behavior_from_transitions(agent_transitions) + a_futs = [ + _pool.submit(sample_behavior_from_transitions, agent_transitions) for _ in range(self.Nagents) ] + human_t = [f.result() for f in h_futs] + agent_t = [f.result() for f in a_futs] # store trajectories for agent probability calculation self.last_trajectories = human_t + agent_t return estimate_demand(self.last_trajectories, self.action_weights) diff --git a/engine/jax/__init__.py b/engine/jax/__init__.py new file mode 100644 index 0000000..84e3375 --- /dev/null +++ b/engine/jax/__init__.py @@ -0,0 +1,3 @@ +from .robust import select_adversarial_alpha_jax, _JAX_OK + +__all__ = ["select_adversarial_alpha_jax", "_JAX_OK"] diff --git a/engine/jax/robust.py b/engine/jax/robust.py new file mode 100644 index 0000000..e873872 --- /dev/null +++ b/engine/jax/robust.py @@ -0,0 +1,176 @@ +"""JAX-accelerated robust inner loop for PHANTOM. + +provides a drop-in replacement for the sequential alpha-candidate evaluation in +wrapper.py::_select_adversarial_alpha. the demand generation and reward +computation are vmapped over the K candidate alpha values so all candidates are +evaluated in a single vectorized pass instead of K sequential Python calls. + +public surface: + select_adversarial_alpha_jax(candidates, prices, human_params, agent_params, + noise_std, n_sessions, n_products, + baseline_prices, lambda_coi, info_value, + reward_profit_weight, rng_key) + -> (best_alpha: float, rewards: np.ndarray) + +falls back gracefully when JAX is unavailable. +""" + +from __future__ import annotations + +import numpy as np + +try: + import jax + import jax.numpy as jnp + from jax import vmap, jit + + _JAX_OK = True +except ImportError: + _JAX_OK = False + + +def _demand_for_actor_jax(prices, mean, std, noise_std, key): + """d(p;theta) = max(0, val - price + noise), normalized to sum 100.""" + k1, k2 = jax.random.split(key) + val = jax.random.normal(k1, shape=prices.shape) * std + mean + noise = jax.random.normal(k2, shape=prices.shape) * noise_std + demand = jnp.maximum(0.0, val - prices + noise) + total = demand.sum() + return jnp.where(total > 0, demand / total * 100.0, demand) + + +def _reward_for_candidate( + alpha, + prices, + human_mean, + human_std, + agent_mean, + agent_std, + noise_std, + baseline_prices, + lambda_coi, + info_value, + reward_profit_weight, + key, +): + """compute a scalar reward for a single alpha candidate (pure JAX, vmappable).""" + k_h, k_a = jax.random.split(key) + # mixed demand proxy: weighted sum of human and agent demand signals + demand_h = _demand_for_actor_jax(prices, human_mean, human_std, noise_std, k_h) + demand_a = _demand_for_actor_jax(prices, agent_mean, agent_std, noise_std, k_a) + demand = (1.0 - alpha) * demand_h + alpha * demand_a + + revenue = jnp.dot(prices, demand) + floor_cost = jnp.dot(baseline_prices, demand) + profit = revenue - floor_cost + + # agent_prob proxy: use alpha directly (no trajectory available in vectorized path) + coi_leakage = alpha * info_value + info_budget = jnp.maximum(floor_cost, 1.0) + coi_penalty = lambda_coi * coi_leakage * info_budget + + return reward_profit_weight * profit - coi_penalty + + +if _JAX_OK: + # compile once; retracing only happens on shape/dtype changes + # 12 args: alpha, prices, h_mean, h_std, a_mean, a_std, noise_std, + # baseline_prices, lambda_coi, info_value, reward_profit_weight, key + _reward_batched = jit( + vmap( + _reward_for_candidate, + in_axes=(0, None, None, None, None, None, None, None, None, None, None, 0), + ) + ) + + +def select_adversarial_alpha_jax( + candidates: np.ndarray, + prices: np.ndarray, + human_params: tuple, + agent_params: tuple, + noise_std: float, + baseline_prices: np.ndarray, + lambda_coi: float, + info_value: float, + reward_profit_weight: float, + rng_seed: int = 0, +) -> tuple[float, np.ndarray]: + """evaluate all alpha candidates in a single vmapped pass. + + returns (best_alpha, rewards_array) where best_alpha minimizes reward + (worst case for the platform, driving robust policy training). + + falls back to a pure-numpy sequential loop when JAX is unavailable so the + wrapper can call this function unconditionally. + """ + if not _JAX_OK: + return _fallback( + candidates, + prices, + human_params, + agent_params, + noise_std, + baseline_prices, + lambda_coi, + info_value, + reward_profit_weight, + ) + + k = len(candidates) + key = jax.random.PRNGKey(rng_seed) + keys = jax.random.split(key, k) + + rewards = np.asarray( + _reward_batched( + jnp.asarray(candidates, dtype=jnp.float32), + jnp.asarray(prices, dtype=jnp.float32), + float(human_params[0]), + float(human_params[1]), + float(agent_params[0]), + float(agent_params[1]), + float(noise_std), + jnp.asarray(baseline_prices, dtype=jnp.float32), + float(lambda_coi), + float(info_value), + float(reward_profit_weight), + keys, + ) + ) + best_idx = int(np.argmin(rewards)) + return float(candidates[best_idx]), rewards + + +def _fallback( + candidates, + prices, + human_params, + agent_params, + noise_std, + baseline_prices, + lambda_coi, + info_value, + reward_profit_weight, +): + """numpy fallback matching the reward formula above.""" + rewards = [] + for alpha in candidates: + rng = np.random.default_rng() + val_h = rng.normal(*human_params, size=len(prices)) + val_a = rng.normal(*agent_params, size=len(prices)) + noise_h = rng.normal(0, noise_std, len(prices)) + noise_a = rng.normal(0, noise_std, len(prices)) + d_h = np.maximum(0, val_h - prices + noise_h) + d_a = np.maximum(0, val_a - prices + noise_a) + s_h, s_a = d_h.sum(), d_a.sum() + d_h = d_h / s_h * 100 if s_h > 0 else d_h + d_a = d_a / s_a * 100 if s_a > 0 else d_a + demand = (1.0 - alpha) * d_h + alpha * d_a + revenue = float(np.dot(prices, demand)) + floor_cost = float(np.dot(baseline_prices, demand)) + profit = revenue - floor_cost + coi_penalty = lambda_coi * alpha * info_value * max(floor_cost, 1.0) + rewards.append(reward_profit_weight * profit - coi_penalty) + rewards = np.array(rewards) + best_idx = int(np.argmin(rewards)) + return float(candidates[best_idx]), rewards diff --git a/engine/lib/behavior.py b/engine/lib/behavior.py index 588ebc9..5c96c27 100644 --- a/engine/lib/behavior.py +++ b/engine/lib/behavior.py @@ -22,6 +22,9 @@ human_dir = str(base_dir / "collected_data") agent_dir = str(base_dir / "agents" / "collected_data") _cache = {} # lazy cache for models and base pivots +# cache keyed by (human: bool, condition_tuple) so we skip Kronecker re-expansion +# for repeated calls with the same demand condition inside the robustness inner loop +_transition_cache: dict = {} def _get_base_pivot(human: bool): @@ -68,22 +71,41 @@ def trajectory_to_events(trajectory: list) -> list: """extract event names from trajectory for KL divergence calculation trajectories are in format 'eventName_product0', extract just eventName - - args: - trajectory: list like ['view_product0', 'add_to_cart_product1', 'checkout_product1'] - - returns: - list: event names like ['view', 'add_to_cart', 'checkout'] """ - events = [] - for state in trajectory: - # state format from sample_behavior: 'eventName_productX' - if "_product" in state: - event = state.rsplit("_product", 1)[0] - else: - event = state - events.append(event) - return events + return [s.rsplit("_product", 1)[0] if "_product" in s else s for s in trajectory] + + +class _TransitionTable: + """numpy-backed transition table; replaces per-step pandas .loc[] indexing. + + the profiling hotspot was DataFrame.xs called ~4-16k times per outer step. + converting once to a dense float32 array with an int-keyed state index map + reduces each row lookup to a single array slice with no pandas overhead. + rows are pre-normalized so sampling requires no per-step division. + """ + + __slots__ = ("matrix", "states", "state_index", "n_states") + + def __init__(self, df: pd.DataFrame): + self.states: list[str] = df.index.tolist() + self.state_index: dict[str, int] = {s: i for i, s in enumerate(self.states)} + # float64 throughout: float32 row-sums can drift enough to break np.random.choice + mat = np.nan_to_num( + df.values.astype(np.float64), nan=0.0, posinf=0.0, neginf=0.0 + ) + mat = np.clip(mat, 0.0, None) + row_sums = mat.sum(axis=1) + # dead rows (all zero) get uniform distribution so sampling never receives NaN + dead = row_sums <= 0 + mat[dead] = 1.0 + row_sums[dead] = float(mat.shape[1]) + mat = mat / row_sums[:, np.newaxis] + # final nan guard in case fp still drifts + np.nan_to_num(mat, nan=0.0, copy=False) + row_sums2 = mat.sum(axis=1, keepdims=True) + row_sums2[row_sums2 <= 0] = 1.0 + self.matrix: np.ndarray = mat / row_sums2 + self.n_states: int = len(self.states) def adjust_behavior_to_condition(condition, transition_matrix): @@ -92,46 +114,68 @@ def adjust_behavior_to_condition(condition, transition_matrix): condition = np.nan_to_num(condition, nan=0.0, posinf=0.0, neginf=0.0) condition = np.clip(condition, 0.0, None) s = float(np.sum(condition)) - if not np.isfinite(s) or s <= 0: - cond_norm = np.full(len(condition), 1.0 / max(len(condition), 1), dtype=float) - else: - cond_norm = condition / s + cond_norm = ( + condition / s + if np.isfinite(s) and s > 0 + else np.full(len(condition), 1.0 / max(len(condition), 1), dtype=float) + ) n_products = len(condition) base_vals = transition_matrix.values base_cols, base_rows = ( transition_matrix.columns.tolist(), transition_matrix.index.tolist(), ) - - # expand via kronecker-like tiling: each cell becomes a P*P block weighted by outer product of cond_norm expanded = np.kron(base_vals, np.outer(cond_norm, cond_norm)) new_cols = [f"{c}_product{p}" for c in base_cols for p in range(n_products)] new_rows = [f"{r}_product{p}" for r in base_rows for p in range(n_products)] return pd.DataFrame(expanded, index=new_rows, columns=new_cols) -def get_adjusted_transitions(condition, human=True): +def get_adjusted_transitions(condition, human=True) -> _TransitionTable: + """return a _TransitionTable for the given demand condition. + + results are cached by (human, rounded-condition) so that repeated calls with + the same condition inside the robustness inner loop (K candidates, same prices) + skip the Kronecker expansion entirely. + """ + condition = np.asarray(condition, dtype=float) + # round to 4 significant digits for cache key stability + cache_key = (human, tuple(np.round(condition, 4).tolist())) + if cache_key in _transition_cache: + return _transition_cache[cache_key] base_pivot = _get_base_pivot(human) - return adjust_behavior_to_condition(condition, base_pivot) + df = adjust_behavior_to_condition(condition, base_pivot) + table = _TransitionTable(df) + _transition_cache[cache_key] = table + return table -def sample_behavior_from_transitions(adjusted_transitions, max_len=40): - trajectory = [np.random.choice(adjusted_transitions.index)] +def clear_transition_cache(): + """drop cached transition tables; call between episodes if condition space is large.""" + _transition_cache.clear() + + +def sample_behavior_from_transitions(table, max_len=40): + """sample a Markov trajectory. + + accepts _TransitionTable (fast path) or a legacy pandas DataFrame so existing + call sites that pass a DataFrame directly continue to work unchanged. + """ + if isinstance(table, pd.DataFrame): + table = _TransitionTable(table) + + idx = np.random.randint(table.n_states) + trajectory = [table.states[idx]] while len(trajectory) < max_len and "checkout" not in trajectory[-1]: - probs = np.asarray(adjusted_transitions.loc[trajectory[-1]].values, dtype=float) - probs = np.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0) - probs = np.clip(probs, 0.0, None) - s = float(np.sum(probs)) - sample = np.random.choice( - adjusted_transitions.columns, p=(probs / s) if s > 0 else None - ) - trajectory.append(sample) + row = table.matrix[table.state_index[trajectory[-1]]] + idx = int(np.random.choice(table.n_states, p=row)) + trajectory.append(table.states[idx]) return trajectory def sample_behavior(condition, human=True, max_len=40): - adjusted_transitions = get_adjusted_transitions(condition, human=human) - return sample_behavior_from_transitions(adjusted_transitions, max_len=max_len) + table = get_adjusted_transitions(condition, human=human) + return sample_behavior_from_transitions(table, max_len=max_len) if __name__ == "__main__": diff --git a/engine/wrapper.py b/engine/wrapper.py index d2ac2cd..2786780 100644 --- a/engine/wrapper.py +++ b/engine/wrapper.py @@ -10,6 +10,7 @@ from .lib.coi import ( ) from .lib.behavior import get_transition_models, trajectory_to_events from .lib.wrappers import EconomicMetricsWrapper +from .jax.robust import select_adversarial_alpha_jax, _JAX_OK class _ActionPricingEngine(PricingEngine): @@ -121,6 +122,7 @@ class PHANTOM(gym.Env): self._prices = None self._demand = None self._step_count = 0 + self._global_step = 0 # monotonic; used as JAX RNG seed across resets self._demand_history = [] self._price_history = [] self._revenue_history = [] @@ -261,8 +263,37 @@ class PHANTOM(gym.Env): return float(np.mean(rewards)) if rewards else 0.0 def _select_adversarial_alpha(self, prices: np.ndarray) -> float: - """inner robust step: evaluate candidates and pick worst-case alpha""" + """inner robust step: pick worst-case alpha from the ambiguity interval. + + when JAX is available and robust_rollouts==1 we use a vmapped pass over + all K candidates in a single call (no Python loop, no market.act overhead). + the JAX path approximates demand as the mixed closed-form d(p;theta) signal + rather than running full trajectory sampling, which is accurate for the + alpha-selection decision while being dramatically cheaper. + + when robust_rollouts>1 or JAX is unavailable we fall back to the sequential + market.act() loop so behavior is identical to the original implementation. + """ candidates = self._alpha_candidates() + if len(candidates) == 1: + return float(candidates[0]) + + if _JAX_OK and self.robust_rollouts == 1: + best_alpha, _ = select_adversarial_alpha_jax( + candidates=candidates, + prices=prices, + human_params=self.market.human_params, + agent_params=self.market.agent_params, + noise_std=self.market.noise_std, + baseline_prices=self.baseline_prices, + lambda_coi=self.lambda_coi, + info_value=self.info_value, + reward_profit_weight=self.reward_profit_weight, + rng_seed=self._global_step, + ) + return best_alpha + + # fallback: full trajectory-based sequential evaluation evaluations = [ (float(alpha), self._evaluate_candidate(float(alpha), prices)) for alpha in candidates @@ -299,6 +330,7 @@ class PHANTOM(gym.Env): def step(self, action): self._prices = self._decode_action(action) alpha_adv = self._select_adversarial_alpha(self._prices) + self._global_step += 1 # always increment; JAX path may have already done so self._set_market_mix(alpha_adv) self._platform_stub.set_prices(self._prices) self._step_count += 1