mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
adding naive jax and libraries and make adjustments
This commit is contained in:
287
engine/jax/env.py
Normal file
287
engine/jax/env.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""JAX-native PHANTOM environment with robust contamination step."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NamedTuple
|
||||
|
||||
try:
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
except ImportError as exc: # pragma: no cover
|
||||
raise ImportError("engine.jax.env requires JAX") from exc
|
||||
|
||||
from .primitives import (
|
||||
_sample_sessions_jax,
|
||||
agent_probability_from_kl,
|
||||
batch_kl,
|
||||
compute_session_transitions,
|
||||
load_transition_data,
|
||||
purchase_flags,
|
||||
reward_with_coi_penalty,
|
||||
revenue_from_demand,
|
||||
weighted_demand,
|
||||
)
|
||||
|
||||
|
||||
class EnvParams(NamedTuple):
|
||||
n_products: int
|
||||
n_sessions: int
|
||||
max_episode_steps: int
|
||||
max_session_steps: int
|
||||
price_low: float
|
||||
price_high: float
|
||||
lambda_coi: float
|
||||
info_value: float
|
||||
robust_radius: float
|
||||
margin_floor: float
|
||||
margin_floor_patience: int
|
||||
action_scales: jax.Array
|
||||
alpha_nominal: float
|
||||
alpha_candidates: jax.Array
|
||||
human_T: jax.Array
|
||||
agent_T: jax.Array
|
||||
terminal_mask: jax.Array
|
||||
purchase_mask: jax.Array
|
||||
event_weights: jax.Array
|
||||
start_idx: int
|
||||
term_idx: int
|
||||
|
||||
|
||||
class EnvState(NamedTuple):
|
||||
prices: jax.Array
|
||||
demand: jax.Array
|
||||
step_count: jax.Array
|
||||
low_margin_streak: jax.Array
|
||||
last_agent_prob: jax.Array
|
||||
last_alpha_adv: jax.Array
|
||||
|
||||
|
||||
class CandidateEval(NamedTuple):
|
||||
reward: jax.Array
|
||||
revenue: jax.Array
|
||||
demand: jax.Array
|
||||
agent_prob: jax.Array
|
||||
leakage: jax.Array
|
||||
discount: jax.Array
|
||||
n_purchases: jax.Array
|
||||
n_agents: jax.Array
|
||||
|
||||
|
||||
def make_env_params(
|
||||
*,
|
||||
n_products: int,
|
||||
alpha: float,
|
||||
n_sessions: int,
|
||||
lambda_coi: float,
|
||||
robust_radius: float,
|
||||
robust_points: int,
|
||||
info_value: float,
|
||||
action_levels: int,
|
||||
action_scale_low: float,
|
||||
action_scale_high: float,
|
||||
price_low: float,
|
||||
price_high: float,
|
||||
max_episode_steps: int,
|
||||
max_session_steps: int = 40,
|
||||
margin_floor: float = 0.05,
|
||||
margin_floor_patience: int = 5,
|
||||
prefer_behavior_data: bool = True,
|
||||
) -> EnvParams:
|
||||
transition = load_transition_data(prefer_data=prefer_behavior_data).to_jax()
|
||||
if robust_radius <= 0.0 or robust_points <= 1:
|
||||
alpha_candidates = jnp.asarray([float(alpha)], dtype=jnp.float32)
|
||||
else:
|
||||
lo = max(0.0, float(alpha) - float(robust_radius))
|
||||
hi = min(1.0, float(alpha) + float(robust_radius))
|
||||
alpha_candidates = jnp.linspace(lo, hi, int(robust_points), dtype=jnp.float32)
|
||||
|
||||
action_scales = jnp.linspace(
|
||||
float(action_scale_low),
|
||||
float(action_scale_high),
|
||||
int(action_levels),
|
||||
dtype=jnp.float32,
|
||||
)
|
||||
return EnvParams(
|
||||
n_products=int(n_products),
|
||||
n_sessions=int(n_sessions),
|
||||
max_episode_steps=int(max_episode_steps),
|
||||
max_session_steps=int(max_session_steps),
|
||||
price_low=float(price_low),
|
||||
price_high=float(price_high),
|
||||
lambda_coi=float(lambda_coi),
|
||||
info_value=float(info_value),
|
||||
robust_radius=float(robust_radius),
|
||||
margin_floor=float(margin_floor),
|
||||
margin_floor_patience=int(margin_floor_patience),
|
||||
action_scales=action_scales,
|
||||
alpha_nominal=float(alpha),
|
||||
alpha_candidates=alpha_candidates,
|
||||
human_T=jnp.asarray(transition.human_T),
|
||||
agent_T=jnp.asarray(transition.agent_T),
|
||||
terminal_mask=jnp.asarray(transition.terminal_mask),
|
||||
purchase_mask=jnp.asarray(transition.purchase_mask),
|
||||
event_weights=jnp.asarray(transition.event_weights),
|
||||
start_idx=int(transition.start_idx),
|
||||
term_idx=int(transition.term_idx),
|
||||
)
|
||||
|
||||
|
||||
def _flatten_obs(demand: jax.Array, prices: jax.Array) -> jax.Array:
|
||||
return jnp.concatenate([demand.astype(jnp.float32), prices.astype(jnp.float32)])
|
||||
|
||||
|
||||
def _decode_action(
|
||||
prices: jax.Array, action: jax.Array, params: EnvParams
|
||||
) -> jax.Array:
|
||||
idx = jnp.clip(action.astype(jnp.int32), 0, params.action_scales.shape[0] - 1)
|
||||
scale = params.action_scales[idx]
|
||||
next_prices = prices * scale
|
||||
return jnp.clip(next_prices, params.price_low, params.price_high)
|
||||
|
||||
|
||||
def _evaluate_candidate(
|
||||
key: jax.Array,
|
||||
alpha_candidate: jax.Array,
|
||||
prices: jax.Array,
|
||||
params: EnvParams,
|
||||
) -> CandidateEval:
|
||||
states, products, actors, lengths = _sample_sessions_jax(
|
||||
key,
|
||||
params.human_T,
|
||||
params.agent_T,
|
||||
params.terminal_mask,
|
||||
params.start_idx,
|
||||
params.term_idx,
|
||||
alpha_candidate,
|
||||
params.n_products,
|
||||
params.n_sessions,
|
||||
params.max_session_steps,
|
||||
int(params.human_T.shape[0]),
|
||||
)
|
||||
session_trans = compute_session_transitions(
|
||||
states, lengths, int(params.human_T.shape[0])
|
||||
)
|
||||
delta_h, delta_a = batch_kl(session_trans, params.human_T, params.agent_T)
|
||||
agent_probs = agent_probability_from_kl(delta_h, delta_a)
|
||||
agent_prob = jnp.mean(agent_probs)
|
||||
|
||||
demand = weighted_demand(states, products, params.n_products, params.event_weights)
|
||||
revenue = revenue_from_demand(prices, demand)
|
||||
reward, leakage, discount = reward_with_coi_penalty(
|
||||
revenue,
|
||||
agent_prob,
|
||||
params.lambda_coi,
|
||||
params.info_value,
|
||||
)
|
||||
purchases = purchase_flags(states, params.purchase_mask)
|
||||
return CandidateEval(
|
||||
reward=reward,
|
||||
revenue=revenue,
|
||||
demand=demand,
|
||||
agent_prob=agent_prob,
|
||||
leakage=leakage,
|
||||
discount=discount,
|
||||
n_purchases=jnp.sum(purchases.astype(jnp.float32)),
|
||||
n_agents=jnp.sum(actors.astype(jnp.float32)),
|
||||
)
|
||||
|
||||
|
||||
def reset_env(key: jax.Array, params: EnvParams) -> tuple[jax.Array, EnvState]:
|
||||
prices = jax.random.uniform(
|
||||
key,
|
||||
shape=(params.n_products,),
|
||||
minval=params.price_low,
|
||||
maxval=params.price_high,
|
||||
)
|
||||
demand = jnp.zeros((params.n_products,), dtype=jnp.float32)
|
||||
state = EnvState(
|
||||
prices=prices,
|
||||
demand=demand,
|
||||
step_count=jnp.asarray(0, dtype=jnp.int32),
|
||||
low_margin_streak=jnp.asarray(0, dtype=jnp.int32),
|
||||
last_agent_prob=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
|
||||
last_alpha_adv=jnp.asarray(params.alpha_nominal, dtype=jnp.float32),
|
||||
)
|
||||
return _flatten_obs(demand, prices), state
|
||||
|
||||
|
||||
def step_env(
|
||||
key: jax.Array,
|
||||
state: EnvState,
|
||||
action: jax.Array,
|
||||
params: EnvParams,
|
||||
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
|
||||
prices = _decode_action(state.prices, action, params)
|
||||
n_candidates = params.alpha_candidates.shape[0]
|
||||
cand_keys = jax.random.split(key, n_candidates)
|
||||
evals = jax.vmap(
|
||||
lambda k, a: _evaluate_candidate(k, a, prices, params),
|
||||
in_axes=(0, 0),
|
||||
)(cand_keys, params.alpha_candidates)
|
||||
idx = jnp.argmin(evals.reward)
|
||||
|
||||
demand = evals.demand[idx]
|
||||
reward = evals.reward[idx]
|
||||
revenue = evals.revenue[idx]
|
||||
agent_prob = evals.agent_prob[idx]
|
||||
leakage = evals.leakage[idx]
|
||||
discount = evals.discount[idx]
|
||||
n_purchases = evals.n_purchases[idx]
|
||||
n_agents = evals.n_agents[idx]
|
||||
alpha_adv = params.alpha_candidates[idx]
|
||||
|
||||
step_count = state.step_count + 1
|
||||
avg_price = jnp.maximum(jnp.mean(prices), 1e-6)
|
||||
avg_margin = (avg_price - params.price_low) / avg_price
|
||||
next_streak = jnp.where(
|
||||
avg_margin < params.margin_floor, state.low_margin_streak + 1, 0
|
||||
)
|
||||
|
||||
margin_collapsed = next_streak >= params.margin_floor_patience
|
||||
done = (step_count >= params.max_episode_steps) | margin_collapsed
|
||||
|
||||
next_state = EnvState(
|
||||
prices=prices,
|
||||
demand=demand,
|
||||
step_count=step_count,
|
||||
low_margin_streak=next_streak,
|
||||
last_agent_prob=agent_prob,
|
||||
last_alpha_adv=alpha_adv,
|
||||
)
|
||||
obs = _flatten_obs(demand, prices)
|
||||
info = {
|
||||
"revenue": revenue,
|
||||
"agent_prob": agent_prob,
|
||||
"alpha_adv": alpha_adv,
|
||||
"coi_leakage": leakage,
|
||||
"coi_discount": discount,
|
||||
"n_purchases": n_purchases,
|
||||
"n_agents": n_agents,
|
||||
"avg_margin": avg_margin,
|
||||
}
|
||||
return obs, next_state, reward, done, info
|
||||
|
||||
|
||||
class PHANTOMJAXEnv:
|
||||
def __init__(self, params: EnvParams):
|
||||
self.params = params
|
||||
|
||||
def reset(self, key: jax.Array, params: EnvParams | None = None):
|
||||
return reset_env(key, self.params if params is None else params)
|
||||
|
||||
def step(
|
||||
self,
|
||||
key: jax.Array,
|
||||
state: EnvState,
|
||||
action: jax.Array,
|
||||
params: EnvParams | None = None,
|
||||
):
|
||||
return step_env(key, state, action, self.params if params is None else params)
|
||||
|
||||
def action_space_n(self, params: EnvParams | None = None) -> int:
|
||||
p = self.params if params is None else params
|
||||
return int(p.action_scales.shape[0])
|
||||
|
||||
def observation_dim(self, params: EnvParams | None = None) -> int:
|
||||
p = self.params if params is None else params
|
||||
return int(p.n_products * 2)
|
||||
Reference in New Issue
Block a user