adding naive jax and libraries and make adjustments

This commit is contained in:
2026-02-17 14:48:18 +01:00
parent 66c4a0cd1d
commit 802f31b4a1
17 changed files with 2331 additions and 6 deletions

287
engine/jax/env.py Normal file
View File

@@ -0,0 +1,287 @@
"""JAX-native PHANTOM environment with robust contamination step."""
from __future__ import annotations
from typing import NamedTuple
try:
import jax
import jax.numpy as jnp
except ImportError as exc: # pragma: no cover
raise ImportError("engine.jax.env requires JAX") from exc
from .primitives import (
_sample_sessions_jax,
agent_probability_from_kl,
batch_kl,
compute_session_transitions,
load_transition_data,
purchase_flags,
reward_with_coi_penalty,
revenue_from_demand,
weighted_demand,
)
class EnvParams(NamedTuple):
n_products: int
n_sessions: int
max_episode_steps: int
max_session_steps: int
price_low: float
price_high: float
lambda_coi: float
info_value: float
robust_radius: float
margin_floor: float
margin_floor_patience: int
action_scales: jax.Array
alpha_nominal: float
alpha_candidates: jax.Array
human_T: jax.Array
agent_T: jax.Array
terminal_mask: jax.Array
purchase_mask: jax.Array
event_weights: jax.Array
start_idx: int
term_idx: int
class EnvState(NamedTuple):
prices: jax.Array
demand: jax.Array
step_count: jax.Array
low_margin_streak: jax.Array
last_agent_prob: jax.Array
last_alpha_adv: jax.Array
class CandidateEval(NamedTuple):
reward: jax.Array
revenue: jax.Array
demand: jax.Array
agent_prob: jax.Array
leakage: jax.Array
discount: jax.Array
n_purchases: jax.Array
n_agents: jax.Array
def make_env_params(
*,
n_products: int,
alpha: float,
n_sessions: int,
lambda_coi: float,
robust_radius: float,
robust_points: int,
info_value: float,
action_levels: int,
action_scale_low: float,
action_scale_high: float,
price_low: float,
price_high: float,
max_episode_steps: int,
max_session_steps: int = 40,
margin_floor: float = 0.05,
margin_floor_patience: int = 5,
prefer_behavior_data: bool = True,
) -> EnvParams:
transition = load_transition_data(prefer_data=prefer_behavior_data).to_jax()
if robust_radius <= 0.0 or robust_points <= 1:
alpha_candidates = jnp.asarray([float(alpha)], dtype=jnp.float32)
else:
lo = max(0.0, float(alpha) - float(robust_radius))
hi = min(1.0, float(alpha) + float(robust_radius))
alpha_candidates = jnp.linspace(lo, hi, int(robust_points), dtype=jnp.float32)
action_scales = jnp.linspace(
float(action_scale_low),
float(action_scale_high),
int(action_levels),
dtype=jnp.float32,
)
return EnvParams(
n_products=int(n_products),
n_sessions=int(n_sessions),
max_episode_steps=int(max_episode_steps),
max_session_steps=int(max_session_steps),
price_low=float(price_low),
price_high=float(price_high),
lambda_coi=float(lambda_coi),
info_value=float(info_value),
robust_radius=float(robust_radius),
margin_floor=float(margin_floor),
margin_floor_patience=int(margin_floor_patience),
action_scales=action_scales,
alpha_nominal=float(alpha),
alpha_candidates=alpha_candidates,
human_T=jnp.asarray(transition.human_T),
agent_T=jnp.asarray(transition.agent_T),
terminal_mask=jnp.asarray(transition.terminal_mask),
purchase_mask=jnp.asarray(transition.purchase_mask),
event_weights=jnp.asarray(transition.event_weights),
start_idx=int(transition.start_idx),
term_idx=int(transition.term_idx),
)
def _flatten_obs(demand: jax.Array, prices: jax.Array) -> jax.Array:
return jnp.concatenate([demand.astype(jnp.float32), prices.astype(jnp.float32)])
def _decode_action(
prices: jax.Array, action: jax.Array, params: EnvParams
) -> jax.Array:
idx = jnp.clip(action.astype(jnp.int32), 0, params.action_scales.shape[0] - 1)
scale = params.action_scales[idx]
next_prices = prices * scale
return jnp.clip(next_prices, params.price_low, params.price_high)
def _evaluate_candidate(
key: jax.Array,
alpha_candidate: jax.Array,
prices: jax.Array,
params: EnvParams,
) -> CandidateEval:
states, products, actors, lengths = _sample_sessions_jax(
key,
params.human_T,
params.agent_T,
params.terminal_mask,
params.start_idx,
params.term_idx,
alpha_candidate,
params.n_products,
params.n_sessions,
params.max_session_steps,
int(params.human_T.shape[0]),
)
session_trans = compute_session_transitions(
states, lengths, int(params.human_T.shape[0])
)
delta_h, delta_a = batch_kl(session_trans, params.human_T, params.agent_T)
agent_probs = agent_probability_from_kl(delta_h, delta_a)
agent_prob = jnp.mean(agent_probs)
demand = weighted_demand(states, products, params.n_products, params.event_weights)
revenue = revenue_from_demand(prices, demand)
reward, leakage, discount = reward_with_coi_penalty(
revenue,
agent_prob,
params.lambda_coi,
params.info_value,
)
purchases = purchase_flags(states, params.purchase_mask)
return CandidateEval(
reward=reward,
revenue=revenue,
demand=demand,
agent_prob=agent_prob,
leakage=leakage,
discount=discount,
n_purchases=jnp.sum(purchases.astype(jnp.float32)),
n_agents=jnp.sum(actors.astype(jnp.float32)),
)
def reset_env(key: jax.Array, params: EnvParams) -> tuple[jax.Array, EnvState]:
prices = jax.random.uniform(
key,
shape=(params.n_products,),
minval=params.price_low,
maxval=params.price_high,
)
demand = jnp.zeros((params.n_products,), dtype=jnp.float32)
state = EnvState(
prices=prices,
demand=demand,
step_count=jnp.asarray(0, dtype=jnp.int32),
low_margin_streak=jnp.asarray(0, dtype=jnp.int32),
last_agent_prob=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
last_alpha_adv=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
)
return _flatten_obs(demand, prices), state
def step_env(
key: jax.Array,
state: EnvState,
action: jax.Array,
params: EnvParams,
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
prices = _decode_action(state.prices, action, params)
n_candidates = params.alpha_candidates.shape[0]
cand_keys = jax.random.split(key, n_candidates)
evals = jax.vmap(
lambda k, a: _evaluate_candidate(k, a, prices, params),
in_axes=(0, 0),
)(cand_keys, params.alpha_candidates)
idx = jnp.argmin(evals.reward)
demand = evals.demand[idx]
reward = evals.reward[idx]
revenue = evals.revenue[idx]
agent_prob = evals.agent_prob[idx]
leakage = evals.leakage[idx]
discount = evals.discount[idx]
n_purchases = evals.n_purchases[idx]
n_agents = evals.n_agents[idx]
alpha_adv = params.alpha_candidates[idx]
step_count = state.step_count + 1
avg_price = jnp.maximum(jnp.mean(prices), 1e-6)
avg_margin = (avg_price - params.price_low) / avg_price
next_streak = jnp.where(
avg_margin < params.margin_floor, state.low_margin_streak + 1, 0
)
margin_collapsed = next_streak >= params.margin_floor_patience
done = (step_count >= params.max_episode_steps) | margin_collapsed
next_state = EnvState(
prices=prices,
demand=demand,
step_count=step_count,
low_margin_streak=next_streak,
last_agent_prob=agent_prob,
last_alpha_adv=alpha_adv,
)
obs = _flatten_obs(demand, prices)
info = {
"revenue": revenue,
"agent_prob": agent_prob,
"alpha_adv": alpha_adv,
"coi_leakage": leakage,
"coi_discount": discount,
"n_purchases": n_purchases,
"n_agents": n_agents,
"avg_margin": avg_margin,
}
return obs, next_state, reward, done, info
class PHANTOMJAXEnv:
def __init__(self, params: EnvParams):
self.params = params
def reset(self, key: jax.Array, params: EnvParams | None = None):
return reset_env(key, self.params if params is None else params)
def step(
self,
key: jax.Array,
state: EnvState,
action: jax.Array,
params: EnvParams | None = None,
):
return step_env(key, state, action, self.params if params is None else params)
def action_space_n(self, params: EnvParams | None = None) -> int:
p = self.params if params is None else params
return int(p.action_scales.shape[0])
def observation_dim(self, params: EnvParams | None = None) -> int:
p = self.params if params is None else params
return int(p.n_products * 2)