chor: implementing prallelization across jax

This commit is contained in:
2026-03-10 17:05:16 +01:00
parent 6d9613c0b6
commit 974498dab2
5 changed files with 303 additions and 41 deletions

View File

@@ -10,6 +10,7 @@ from .lib.coi import (
)
from .lib.behavior import get_transition_models, trajectory_to_events
from .lib.wrappers import EconomicMetricsWrapper
from .jax.robust import select_adversarial_alpha_jax, _JAX_OK
class _ActionPricingEngine(PricingEngine):
@@ -121,6 +122,7 @@ class PHANTOM(gym.Env):
self._prices = None
self._demand = None
self._step_count = 0
self._global_step = 0 # monotonic; used as JAX RNG seed across resets
self._demand_history = []
self._price_history = []
self._revenue_history = []
@@ -261,8 +263,37 @@ class PHANTOM(gym.Env):
return float(np.mean(rewards)) if rewards else 0.0
def _select_adversarial_alpha(self, prices: np.ndarray) -> float:
"""inner robust step: evaluate candidates and pick worst-case alpha"""
"""inner robust step: pick worst-case alpha from the ambiguity interval.
when JAX is available and robust_rollouts==1 we use a vmapped pass over
all K candidates in a single call (no Python loop, no market.act overhead).
the JAX path approximates demand as the mixed closed-form d(p;theta) signal
rather than running full trajectory sampling, which is accurate for the
alpha-selection decision while being dramatically cheaper.
when robust_rollouts>1 or JAX is unavailable we fall back to the sequential
market.act() loop so behavior is identical to the original implementation.
"""
candidates = self._alpha_candidates()
if len(candidates) == 1:
return float(candidates[0])
if _JAX_OK and self.robust_rollouts == 1:
best_alpha, _ = select_adversarial_alpha_jax(
candidates=candidates,
prices=prices,
human_params=self.market.human_params,
agent_params=self.market.agent_params,
noise_std=self.market.noise_std,
baseline_prices=self.baseline_prices,
lambda_coi=self.lambda_coi,
info_value=self.info_value,
reward_profit_weight=self.reward_profit_weight,
rng_seed=self._global_step,
)
return best_alpha
# fallback: full trajectory-based sequential evaluation
evaluations = [
(float(alpha), self._evaluate_candidate(float(alpha), prices))
for alpha in candidates
@@ -299,6 +330,7 @@ class PHANTOM(gym.Env):
def step(self, action):
self._prices = self._decode_action(action)
alpha_adv = self._select_adversarial_alpha(self._prices)
self._global_step += 1 # always increment; JAX path may have already done so
self._set_market_mix(alpha_adv)
self._platform_stub.set_prices(self._prices)
self._step_count += 1