mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: implementing parallelization across jax
This commit is contained in:
@@ -10,6 +10,7 @@ from .lib.coi import (
|
||||
)
|
||||
from .lib.behavior import get_transition_models, trajectory_to_events
|
||||
from .lib.wrappers import EconomicMetricsWrapper
|
||||
from .jax.robust import select_adversarial_alpha_jax, _JAX_OK
|
||||
|
||||
|
||||
class _ActionPricingEngine(PricingEngine):
|
||||
@@ -121,6 +122,7 @@ class PHANTOM(gym.Env):
|
||||
self._prices = None
|
||||
self._demand = None
|
||||
self._step_count = 0
|
||||
self._global_step = 0 # monotonic; used as JAX RNG seed across resets
|
||||
self._demand_history = []
|
||||
self._price_history = []
|
||||
self._revenue_history = []
|
||||
@@ -261,8 +263,37 @@ class PHANTOM(gym.Env):
|
||||
return float(np.mean(rewards)) if rewards else 0.0
|
||||
|
||||
def _select_adversarial_alpha(self, prices: np.ndarray) -> float:
|
||||
"""inner robust step: evaluate candidates and pick worst-case alpha"""
|
||||
"""inner robust step: pick worst-case alpha from the ambiguity interval.
|
||||
|
||||
when JAX is available and robust_rollouts==1 we use a vmapped pass over
|
||||
all K candidates in a single call (no Python loop, no market.act overhead).
|
||||
the JAX path approximates demand as the mixed closed-form d(p;theta) signal
|
||||
rather than running full trajectory sampling, which is accurate for the
|
||||
alpha-selection decision while being dramatically cheaper.
|
||||
|
||||
when robust_rollouts>1 or JAX is unavailable we fall back to the sequential
|
||||
market.act() loop so behavior is identical to the original implementation.
|
||||
"""
|
||||
candidates = self._alpha_candidates()
|
||||
if len(candidates) == 1:
|
||||
return float(candidates[0])
|
||||
|
||||
if _JAX_OK and self.robust_rollouts == 1:
|
||||
best_alpha, _ = select_adversarial_alpha_jax(
|
||||
candidates=candidates,
|
||||
prices=prices,
|
||||
human_params=self.market.human_params,
|
||||
agent_params=self.market.agent_params,
|
||||
noise_std=self.market.noise_std,
|
||||
baseline_prices=self.baseline_prices,
|
||||
lambda_coi=self.lambda_coi,
|
||||
info_value=self.info_value,
|
||||
reward_profit_weight=self.reward_profit_weight,
|
||||
rng_seed=self._global_step,
|
||||
)
|
||||
return best_alpha
|
||||
|
||||
# fallback: full trajectory-based sequential evaluation
|
||||
evaluations = [
|
||||
(float(alpha), self._evaluate_candidate(float(alpha), prices))
|
||||
for alpha in candidates
|
||||
@@ -299,6 +330,7 @@ class PHANTOM(gym.Env):
|
||||
def step(self, action):
|
||||
self._prices = self._decode_action(action)
|
||||
alpha_adv = self._select_adversarial_alpha(self._prices)
|
||||
self._global_step += 1 # always increment; JAX path may have already done so
|
||||
self._set_market_mix(alpha_adv)
|
||||
self._platform_stub.set_prices(self._prices)
|
||||
self._step_count += 1
|
||||
|
||||
Reference in New Issue
Block a user