mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: cleaning some code
This commit is contained in:
@@ -32,6 +32,7 @@ class EnvParams(NamedTuple):
|
||||
price_high: float
|
||||
lambda_coi: float
|
||||
info_value: float
|
||||
eta_ux: float
|
||||
robust_radius: float
|
||||
margin_floor: float
|
||||
margin_floor_patience: int
|
||||
@@ -63,6 +64,7 @@ class CandidateEval(NamedTuple):
|
||||
agent_prob: jax.Array
|
||||
leakage: jax.Array
|
||||
discount: jax.Array
|
||||
ux_penalty: jax.Array
|
||||
n_purchases: jax.Array
|
||||
n_agents: jax.Array
|
||||
|
||||
@@ -76,6 +78,7 @@ def make_env_params(
|
||||
robust_radius: float,
|
||||
robust_points: int,
|
||||
info_value: float,
|
||||
eta_ux: float = 0.5,
|
||||
action_levels: int,
|
||||
action_scale_low: float,
|
||||
action_scale_high: float,
|
||||
@@ -110,6 +113,7 @@ def make_env_params(
|
||||
price_high=float(price_high),
|
||||
lambda_coi=float(lambda_coi),
|
||||
info_value=float(info_value),
|
||||
eta_ux=float(eta_ux),
|
||||
robust_radius=float(robust_radius),
|
||||
margin_floor=float(margin_floor),
|
||||
margin_floor_patience=int(margin_floor_patience),
|
||||
@@ -143,6 +147,7 @@ def _evaluate_candidate(
|
||||
key: jax.Array,
|
||||
alpha_candidate: jax.Array,
|
||||
prices: jax.Array,
|
||||
ux_volatility: jax.Array,
|
||||
params: EnvParams,
|
||||
) -> CandidateEval:
|
||||
states, products, actors, lengths = _sample_sessions_jax(
|
||||
@@ -167,11 +172,13 @@ def _evaluate_candidate(
|
||||
|
||||
demand = weighted_demand(states, products, params.n_products, params.event_weights)
|
||||
revenue = revenue_from_demand(prices, demand)
|
||||
reward, leakage, discount = reward_with_coi_penalty(
|
||||
reward, leakage, discount, ux_penalty = reward_with_coi_penalty(
|
||||
revenue,
|
||||
agent_prob,
|
||||
params.lambda_coi,
|
||||
params.info_value,
|
||||
params.eta_ux,
|
||||
ux_volatility,
|
||||
)
|
||||
purchases = purchase_flags(states, params.purchase_mask)
|
||||
return CandidateEval(
|
||||
@@ -181,6 +188,7 @@ def _evaluate_candidate(
|
||||
agent_prob=agent_prob,
|
||||
leakage=leakage,
|
||||
discount=discount,
|
||||
ux_penalty=ux_penalty,
|
||||
n_purchases=jnp.sum(purchases.astype(jnp.float32)),
|
||||
n_agents=jnp.sum(actors.astype(jnp.float32)),
|
||||
)
|
||||
@@ -212,10 +220,16 @@ def step_env(
|
||||
params: EnvParams,
|
||||
) -> tuple[jax.Array, EnvState, jax.Array, jax.Array, dict[str, jax.Array]]:
|
||||
prices = _decode_action(state.prices, action, params)
|
||||
|
||||
baseline = jnp.maximum(state.prices, 1.0)
|
||||
ux_volatility = jnp.where(
|
||||
state.step_count > 0, jnp.mean(jnp.abs(prices - state.prices) / baseline), 0.0
|
||||
)
|
||||
|
||||
n_candidates = params.alpha_candidates.shape[0]
|
||||
cand_keys = jax.random.split(key, n_candidates)
|
||||
evals = jax.vmap(
|
||||
lambda k, a: _evaluate_candidate(k, a, prices, params),
|
||||
lambda k, a: _evaluate_candidate(k, a, prices, ux_volatility, params),
|
||||
in_axes=(0, 0),
|
||||
)(cand_keys, params.alpha_candidates)
|
||||
idx = jnp.argmin(evals.reward)
|
||||
@@ -226,6 +240,7 @@ def step_env(
|
||||
agent_prob = evals.agent_prob[idx]
|
||||
leakage = evals.leakage[idx]
|
||||
discount = evals.discount[idx]
|
||||
ux_penalty = evals.ux_penalty[idx]
|
||||
n_purchases = evals.n_purchases[idx]
|
||||
n_agents = evals.n_agents[idx]
|
||||
alpha_adv = params.alpha_candidates[idx]
|
||||
@@ -255,6 +270,8 @@ def step_env(
|
||||
"alpha_adv": alpha_adv,
|
||||
"coi_leakage": leakage,
|
||||
"coi_discount": discount,
|
||||
"ux_penalty": ux_penalty,
|
||||
"volatility": ux_volatility,
|
||||
"n_purchases": n_purchases,
|
||||
"n_agents": n_agents,
|
||||
"avg_margin": avg_margin,
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Mapping, Sequence
|
||||
from typing import Mapping
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -484,11 +484,17 @@ if JAX_AVAILABLE:
|
||||
|
||||
|
||||
def reward_with_coi_penalty(
|
||||
revenue, agent_prob: float, lambda_coi: float, info_value: float
|
||||
revenue,
|
||||
agent_prob: float,
|
||||
lambda_coi: float,
|
||||
info_value: float,
|
||||
eta_ux: float = 0.0,
|
||||
ux_volatility: float = 0.0,
|
||||
):
|
||||
leakage = agent_prob * info_value
|
||||
discount = jnp.clip(1.0 - lambda_coi * leakage, 0.0, 1.0)
|
||||
return revenue * discount, leakage, discount
|
||||
ux_penalty = eta_ux * revenue * ux_volatility
|
||||
return revenue * discount - ux_penalty, leakage, discount, ux_penalty
|
||||
|
||||
|
||||
if JAX_AVAILABLE:
|
||||
|
||||
Reference in New Issue
Block a user