mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: cleaning some code
This commit is contained in:
@@ -32,6 +32,7 @@ class EnvParams(NamedTuple):
|
|||||||
price_high: float
|
price_high: float
|
||||||
lambda_coi: float
|
lambda_coi: float
|
||||||
info_value: float
|
info_value: float
|
||||||
|
eta_ux: float
|
||||||
robust_radius: float
|
robust_radius: float
|
||||||
margin_floor: float
|
margin_floor: float
|
||||||
margin_floor_patience: int
|
margin_floor_patience: int
|
||||||
@@ -63,6 +64,7 @@ class CandidateEval(NamedTuple):
|
|||||||
agent_prob: jax.Array
|
agent_prob: jax.Array
|
||||||
leakage: jax.Array
|
leakage: jax.Array
|
||||||
discount: jax.Array
|
discount: jax.Array
|
||||||
|
ux_penalty: jax.Array
|
||||||
n_purchases: jax.Array
|
n_purchases: jax.Array
|
||||||
n_agents: jax.Array
|
n_agents: jax.Array
|
||||||
|
|
||||||
@@ -76,6 +78,7 @@ def make_env_params(
|
|||||||
robust_radius: float,
|
robust_radius: float,
|
||||||
robust_points: int,
|
robust_points: int,
|
||||||
info_value: float,
|
info_value: float,
|
||||||
|
eta_ux: float = 0.5,
|
||||||
action_levels: int,
|
action_levels: int,
|
||||||
action_scale_low: float,
|
action_scale_low: float,
|
||||||
action_scale_high: float,
|
action_scale_high: float,
|
||||||
@@ -110,6 +113,7 @@ def make_env_params(
|
|||||||
price_high=float(price_high),
|
price_high=float(price_high),
|
||||||
lambda_coi=float(lambda_coi),
|
lambda_coi=float(lambda_coi),
|
||||||
info_value=float(info_value),
|
info_value=float(info_value),
|
||||||
|
eta_ux=float(eta_ux),
|
||||||
robust_radius=float(robust_radius),
|
robust_radius=float(robust_radius),
|
||||||
margin_floor=float(margin_floor),
|
margin_floor=float(margin_floor),
|
||||||
margin_floor_patience=int(margin_floor_patience),
|
margin_floor_patience=int(margin_floor_patience),
|
||||||
@@ -143,6 +147,7 @@ def _evaluate_candidate(
|
|||||||
key: jax.Array,
|
key: jax.Array,
|
||||||
alpha_candidate: jax.Array,
|
alpha_candidate: jax.Array,
|
||||||
prices: jax.Array,
|
prices: jax.Array,
|
||||||
|
ux_volatility: jax.Array,
|
||||||
params: EnvParams,
|
params: EnvParams,
|
||||||
) -> CandidateEval:
|
) -> CandidateEval:
|
||||||
states, products, actors, lengths = _sample_sessions_jax(
|
states, products, actors, lengths = _sample_sessions_jax(
|
||||||
@@ -167,11 +172,13 @@ def _evaluate_candidate(
|
|||||||
|
|
||||||
demand = weighted_demand(states, products, params.n_products, params.event_weights)
|
demand = weighted_demand(states, products, params.n_products, params.event_weights)
|
||||||
revenue = revenue_from_demand(prices, demand)
|
revenue = revenue_from_demand(prices, demand)
|
||||||
reward, leakage, discount = reward_with_coi_penalty(
|
reward, leakage, discount, ux_penalty = reward_with_coi_penalty(
|
||||||
revenue,
|
revenue,
|
||||||
agent_prob,
|
agent_prob,
|
||||||
params.lambda_coi,
|
params.lambda_coi,
|
||||||
params.info_value,
|
params.info_value,
|
||||||
|
params.eta_ux,
|
||||||
|
ux_volatility,
|
||||||
)
|
)
|
||||||
purchases = purchase_flags(states, params.purchase_mask)
|
purchases = purchase_flags(states, params.purchase_mask)
|
||||||
return CandidateEval(
|
return CandidateEval(
|
||||||
@@ -181,6 +188,7 @@ def _evaluate_candidate(
|
|||||||
agent_prob=agent_prob,
|
agent_prob=agent_prob,
|
||||||
leakage=leakage,
|
leakage=leakage,
|
||||||
discount=discount,
|
discount=discount,
|
||||||
|
ux_penalty=ux_penalty,
|
||||||
n_purchases=jnp.sum(purchases.astype(jnp.float32)),
|
n_purchases=jnp.sum(purchases.astype(jnp.float32)),
|
||||||
n_agents=jnp.sum(actors.astype(jnp.float32)),
|
n_agents=jnp.sum(actors.astype(jnp.float32)),
|
||||||
)
|
)
|
||||||
@@ -212,10 +220,16 @@ def step_env(
|
|||||||
params: EnvParams,
|
params: EnvParams,
|
||||||
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
|
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
|
||||||
prices = _decode_action(state.prices, action, params)
|
prices = _decode_action(state.prices, action, params)
|
||||||
|
|
||||||
|
baseline = jnp.maximum(state.prices, 1.0)
|
||||||
|
ux_volatility = jnp.where(
|
||||||
|
state.step_count > 0, jnp.mean(jnp.abs(prices - state.prices) / baseline), 0.0
|
||||||
|
)
|
||||||
|
|
||||||
n_candidates = params.alpha_candidates.shape[0]
|
n_candidates = params.alpha_candidates.shape[0]
|
||||||
cand_keys = jax.random.split(key, n_candidates)
|
cand_keys = jax.random.split(key, n_candidates)
|
||||||
evals = jax.vmap(
|
evals = jax.vmap(
|
||||||
lambda k, a: _evaluate_candidate(k, a, prices, params),
|
lambda k, a: _evaluate_candidate(k, a, prices, ux_volatility, params),
|
||||||
in_axes=(0, 0),
|
in_axes=(0, 0),
|
||||||
)(cand_keys, params.alpha_candidates)
|
)(cand_keys, params.alpha_candidates)
|
||||||
idx = jnp.argmin(evals.reward)
|
idx = jnp.argmin(evals.reward)
|
||||||
@@ -226,6 +240,7 @@ def step_env(
|
|||||||
agent_prob = evals.agent_prob[idx]
|
agent_prob = evals.agent_prob[idx]
|
||||||
leakage = evals.leakage[idx]
|
leakage = evals.leakage[idx]
|
||||||
discount = evals.discount[idx]
|
discount = evals.discount[idx]
|
||||||
|
ux_penalty = evals.ux_penalty[idx]
|
||||||
n_purchases = evals.n_purchases[idx]
|
n_purchases = evals.n_purchases[idx]
|
||||||
n_agents = evals.n_agents[idx]
|
n_agents = evals.n_agents[idx]
|
||||||
alpha_adv = params.alpha_candidates[idx]
|
alpha_adv = params.alpha_candidates[idx]
|
||||||
@@ -255,6 +270,8 @@ def step_env(
|
|||||||
"alpha_adv": alpha_adv,
|
"alpha_adv": alpha_adv,
|
||||||
"coi_leakage": leakage,
|
"coi_leakage": leakage,
|
||||||
"coi_discount": discount,
|
"coi_discount": discount,
|
||||||
|
"ux_penalty": ux_penalty,
|
||||||
|
"volatility": ux_volatility,
|
||||||
"n_purchases": n_purchases,
|
"n_purchases": n_purchases,
|
||||||
"n_agents": n_agents,
|
"n_agents": n_agents,
|
||||||
"avg_margin": avg_margin,
|
"avg_margin": avg_margin,
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Mapping, Sequence
|
from typing import Mapping
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@@ -484,11 +484,17 @@ if JAX_AVAILABLE:
|
|||||||
|
|
||||||
|
|
||||||
def reward_with_coi_penalty(
|
def reward_with_coi_penalty(
|
||||||
revenue, agent_prob: float, lambda_coi: float, info_value: float
|
revenue,
|
||||||
|
agent_prob: float,
|
||||||
|
lambda_coi: float,
|
||||||
|
info_value: float,
|
||||||
|
eta_ux: float = 0.0,
|
||||||
|
ux_volatility: float = 0.0,
|
||||||
):
|
):
|
||||||
leakage = agent_prob * info_value
|
leakage = agent_prob * info_value
|
||||||
discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0)
|
discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0)
|
||||||
return revenue * discount, leakage, discount
|
ux_penalty = eta_ux * revenue * ux_volatility
|
||||||
|
return revenue * discount - ux_penalty, leakage, discount, ux_penalty
|
||||||
|
|
||||||
|
|
||||||
if JAX_AVAILABLE:
|
if JAX_AVAILABLE:
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ from typing import Dict
|
|||||||
|
|
||||||
|
|
||||||
def compute_agent_probability(
|
def compute_agent_probability(
|
||||||
trajectory: list, human_transitions: Dict, agent_transitions: Dict
|
trajectory: list,
|
||||||
|
human_transitions: Dict,
|
||||||
|
agent_transitions: Dict,
|
||||||
|
temperature: float = 1.0,
|
||||||
) -> float:
|
) -> float:
|
||||||
"""estimate agent probability via KL divergence between trajectory transitions and reference models
|
"""estimate agent probability via KL divergence between trajectory transitions and reference models
|
||||||
|
|
||||||
@@ -52,9 +55,9 @@ def compute_agent_probability(
|
|||||||
kl_agent = kl_div(empirical, agent_transitions)
|
kl_agent = kl_div(empirical, agent_transitions)
|
||||||
|
|
||||||
# convert to probability via softmax (lower KL = higher prob)
|
# convert to probability via softmax (lower KL = higher prob)
|
||||||
# agent_prob = exp(-kl_agent) / (exp(-kl_human) + exp(-kl_agent))
|
t = float(max(temperature, 1e-6))
|
||||||
exp_h = np.exp(-kl_human)
|
exp_h = np.exp(-kl_human / t)
|
||||||
exp_a = np.exp(-kl_agent)
|
exp_a = np.exp(-kl_agent / t)
|
||||||
return float(exp_a / (exp_h + exp_a + 1e-10))
|
return float(exp_a / (exp_h + exp_a + 1e-10))
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
"""shared factor definitions for experimental designs"""
|
"""shared factor definitions for experimental designs"""
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from typing import Callable, Any
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Factor:
|
class Factor:
|
||||||
|
|||||||
@@ -287,7 +287,7 @@ def _sb3_model_cls(algo: str):
|
|||||||
raise ValueError(f"unsupported algo '{algo}'")
|
raise ValueError(f"unsupported algo '{algo}'")
|
||||||
|
|
||||||
|
|
||||||
def train_qtable(cfg: dict) -> tuple[EventQTable, dict]:
|
def train_qtable(cfg: dict) -> tuple["EventQTable", dict]:
|
||||||
from .lib.discrete import EventQTable
|
from .lib.discrete import EventQTable
|
||||||
|
|
||||||
np.random.seed(int(cfg["seed"]))
|
np.random.seed(int(cfg["seed"]))
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class PHANTOM(gym.Env):
|
|||||||
robust_radius: float = 0.0,
|
robust_radius: float = 0.0,
|
||||||
robust_points: int = 5,
|
robust_points: int = 5,
|
||||||
info_value: float = 1.0,
|
info_value: float = 1.0,
|
||||||
|
eta_ux: float = 0.5,
|
||||||
action_levels: int = 9,
|
action_levels: int = 9,
|
||||||
action_scale_low: float = 0.9,
|
action_scale_low: float = 0.9,
|
||||||
action_scale_high: float = 1.1,
|
action_scale_high: float = 1.1,
|
||||||
@@ -75,6 +76,7 @@ class PHANTOM(gym.Env):
|
|||||||
self.robust_radius = max(0.0, float(robust_radius))
|
self.robust_radius = max(0.0, float(robust_radius))
|
||||||
self.robust_points = max(1, int(robust_points))
|
self.robust_points = max(1, int(robust_points))
|
||||||
self.info_value = float(info_value)
|
self.info_value = float(info_value)
|
||||||
|
self.eta_ux = float(eta_ux)
|
||||||
self.action_levels = max(2, int(action_levels))
|
self.action_levels = max(2, int(action_levels))
|
||||||
self._action_scales = np.linspace(
|
self._action_scales = np.linspace(
|
||||||
float(action_scale_low), float(action_scale_high), self.action_levels
|
float(action_scale_low), float(action_scale_high), self.action_levels
|
||||||
@@ -179,11 +181,26 @@ class PHANTOM(gym.Env):
|
|||||||
revenue = float(np.dot(prices, demand_arr))
|
revenue = float(np.dot(prices, demand_arr))
|
||||||
purchases = extract_purchases(trajectories)
|
purchases = extract_purchases(trajectories)
|
||||||
coi_mix = compute_uplift_coi(prices, purchases, self.baseline_prices)
|
coi_mix = compute_uplift_coi(prices, purchases, self.baseline_prices)
|
||||||
|
|
||||||
# multiplicative penalty so COI term scales with revenue magnitude
|
# multiplicative penalty so COI term scales with revenue magnitude
|
||||||
coi_leakage = float(agent_prob * self.info_value)
|
coi_leakage = float(agent_prob * self.info_value)
|
||||||
discount = float(np.clip(1.0 - self.lambda_coi * coi_leakage, 0.0, 1.0))
|
discount = float(np.clip(1.0 - self.lambda_coi * coi_leakage, 0.0, 1.0))
|
||||||
coi_penalty = revenue * (1.0 - discount) # absolute penalty in revenue units
|
coi_penalty = revenue * (1.0 - discount) # absolute penalty in revenue units
|
||||||
reward = revenue * discount
|
|
||||||
|
# calculate UX penalty based on price volatility
|
||||||
|
if len(self._price_history) > 0:
|
||||||
|
volatility = float(
|
||||||
|
np.mean(
|
||||||
|
np.abs(prices - self._price_history[-1])
|
||||||
|
/ np.maximum(self.baseline_prices, 1.0)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
volatility = 0.0
|
||||||
|
ux_penalty = self.eta_ux * revenue * volatility
|
||||||
|
|
||||||
|
reward = revenue * discount - ux_penalty
|
||||||
|
|
||||||
return reward, {
|
return reward, {
|
||||||
"revenue": revenue,
|
"revenue": revenue,
|
||||||
"coi_mix": float(coi_mix),
|
"coi_mix": float(coi_mix),
|
||||||
@@ -191,6 +208,8 @@ class PHANTOM(gym.Env):
|
|||||||
"coi_leakage": coi_leakage,
|
"coi_leakage": coi_leakage,
|
||||||
"coi_penalty": coi_penalty,
|
"coi_penalty": coi_penalty,
|
||||||
"coi_discount": discount,
|
"coi_discount": discount,
|
||||||
|
"ux_penalty": ux_penalty,
|
||||||
|
"volatility": volatility,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _alpha_candidates(self) -> np.ndarray:
|
def _alpha_candidates(self) -> np.ndarray:
|
||||||
@@ -200,27 +219,34 @@ class PHANTOM(gym.Env):
|
|||||||
hi = min(1.0, self.nominal_alpha + self.robust_radius)
|
hi = min(1.0, self.nominal_alpha + self.robust_radius)
|
||||||
return np.linspace(lo, hi, self.robust_points)
|
return np.linspace(lo, hi, self.robust_points)
|
||||||
|
|
||||||
def _select_adversarial_alpha(
|
def _evaluate_candidate(
|
||||||
self, prices: np.ndarray
|
self, alpha: float, prices: np.ndarray
|
||||||
) -> tuple[float, dict, list, float]:
|
) -> tuple[float, dict, list, float]:
|
||||||
"""inner robust step: pick worst-case alpha and return its outcome directly to avoid double-sampling"""
|
self._set_market_mix(alpha)
|
||||||
candidates = self._alpha_candidates()
|
|
||||||
best_alpha, worst_reward = float(candidates[0]), np.inf
|
|
||||||
best_demand, best_trajectories, best_agent_prob = None, [], 0.0
|
|
||||||
for alpha in candidates:
|
|
||||||
self._set_market_mix(float(alpha))
|
|
||||||
demand = self.market.act(prices)
|
demand = self.market.act(prices)
|
||||||
trajectories = list(self.market.last_trajectories)
|
trajectories = list(self.market.last_trajectories)
|
||||||
agent_prob = self._compute_agent_prob(trajectories)
|
agent_prob = self._compute_agent_prob(trajectories)
|
||||||
reward, _ = self._compute_reward(prices, demand, agent_prob, trajectories)
|
reward, _ = self._compute_reward(prices, demand, agent_prob, trajectories)
|
||||||
if reward < worst_reward:
|
return reward, demand, trajectories, agent_prob
|
||||||
worst_reward = reward
|
|
||||||
best_alpha, best_demand, best_trajectories, best_agent_prob = (
|
def _select_adversarial_alpha(
|
||||||
float(alpha),
|
self, prices: np.ndarray
|
||||||
demand,
|
) -> tuple[float, dict, list, float]:
|
||||||
trajectories,
|
"""inner robust step: evaluate candidates and pick worst-case alpha"""
|
||||||
agent_prob,
|
candidates = self._alpha_candidates()
|
||||||
)
|
evaluations = [
|
||||||
|
(alpha, *self._evaluate_candidate(float(alpha), prices))
|
||||||
|
for alpha in candidates
|
||||||
|
]
|
||||||
|
|
||||||
|
# min over alpha in Wasserstein interval
|
||||||
|
best_eval = min(evaluations, key=lambda x: x[1]) # index 1 is reward
|
||||||
|
|
||||||
|
best_alpha = best_eval[0]
|
||||||
|
best_demand = best_eval[2]
|
||||||
|
best_trajectories = best_eval[3]
|
||||||
|
best_agent_prob = best_eval[4]
|
||||||
|
|
||||||
return best_alpha, best_demand, best_trajectories, best_agent_prob
|
return best_alpha, best_demand, best_trajectories, best_agent_prob
|
||||||
|
|
||||||
def _record_history(self):
|
def _record_history(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user