mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: fixing discretization of actions
This commit is contained in:
@@ -5,3 +5,4 @@ from .wrappers import EconomicMetricsWrapper
|
||||
from .callbacks import MetricsCallback, EvalMetricsCallback
|
||||
from .providers import ProviderBenchmark, ProviderResult, BenchmarkConfig
|
||||
from .coi import compute_uplift_coi, extract_purchases, compute_agent_probability
|
||||
from .discrete import EventQTable
|
||||
|
||||
@@ -70,7 +70,14 @@ def trajectory_to_events(trajectory: list) -> list:
|
||||
|
||||
def adjust_behavior_to_condition(condition, transition_matrix):
|
||||
# expand NxN transition matrix to (N*P)x(N*P) weighted by demand condition
|
||||
cond_norm = condition / np.sum(condition)
|
||||
condition = np.asarray(condition, dtype=float)
|
||||
condition = np.nan_to_num(condition, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
condition = np.clip(condition, 0.0, None)
|
||||
s = float(np.sum(condition))
|
||||
if not np.isfinite(s) or s <= 0:
|
||||
cond_norm = np.full(len(condition), 1.0 / max(len(condition), 1), dtype=float)
|
||||
else:
|
||||
cond_norm = condition / s
|
||||
n_products = len(condition)
|
||||
base_vals = transition_matrix.values
|
||||
base_cols, base_rows = (
|
||||
@@ -91,10 +98,12 @@ def sample_behavior(condition, human=True, max_len=40):
|
||||
|
||||
trajectory = [np.random.choice(adjusted_transitions.index)]
|
||||
while len(trajectory) < max_len and "checkout" not in trajectory[-1]:
|
||||
probs = adjusted_transitions.loc[trajectory[-1]].values
|
||||
probs = np.asarray(adjusted_transitions.loc[trajectory[-1]].values, dtype=float)
|
||||
probs = np.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
probs = np.clip(probs, 0.0, None)
|
||||
s = float(np.sum(probs))
|
||||
sample = np.random.choice(
|
||||
adjusted_transitions.columns,
|
||||
p=probs / np.sum(probs) if np.sum(probs) > 0 else None,
|
||||
adjusted_transitions.columns, p=(probs / s) if s > 0 else None
|
||||
)
|
||||
trajectory.append(sample)
|
||||
return trajectory
|
||||
|
||||
Reference in New Issue
Block a user