mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
fix: coi better defined and aligned and sac improved
This commit is contained in:
@@ -14,7 +14,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Dict, List, Tuple
|
from typing import Dict, List, Tuple
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from .coi import COIWindow, compute_coi_window, coi_erosion
|
from .coi import COIWindow, compute_coi_window
|
||||||
from .separability import TRANS_H, TRANS_A, kl_div, build_kernel, compute_divergence, estimate_alpha
|
from .separability import TRANS_H, TRANS_A, kl_div, build_kernel, compute_divergence, estimate_alpha
|
||||||
|
|
||||||
ACTION_WEIGHTS = {"add_to_cart": 0.8, "checkout": 0.9, "purchase": 1.0, "view": 0.15, "detail": 0.25, "hover": 0.3, "start": 0.05, "end": 0.0}
|
ACTION_WEIGHTS = {"add_to_cart": 0.8, "checkout": 0.9, "purchase": 1.0, "view": 0.15, "detail": 0.25, "hover": 0.3, "start": 0.05, "end": 0.0}
|
||||||
@@ -209,7 +209,8 @@ if __name__ == "__main__":
|
|||||||
print(f'sessions: {len(sessions)}, agents: {sum(1 for s in sessions if s.actor=="A")}')
|
print(f'sessions: {len(sessions)}, agents: {sum(1 for s in sessions if s.actor=="A")}')
|
||||||
|
|
||||||
for n in [1, 5, 10, 50, 100]:
|
for n in [1, 5, 10, 50, 100]:
|
||||||
print(f'N={n:3d} agents -> COI erosion: {coi_erosion(n, price_std=5.0):.3f}')
|
# theoretical: erosion = 1 - 2/(N+1) for uniform order statistic
|
||||||
|
print(f'N={n:3d} agents -> COI erosion: {1.0 - 2.0/(n+1):.3f}')
|
||||||
|
|
||||||
events = [Event('view', 0, 20.0, 0.1), Event('detail', 0, 20.0, 0.5), Event('cart', 0, 20.0, 1.0), Event('purchase', 0, 20.0, 2.0)]
|
events = [Event('view', 0, 20.0, 0.1), Event('detail', 0, 20.0, 0.5), Event('cart', 0, 20.0, 1.0), Event('purchase', 0, 20.0, 2.0)]
|
||||||
print(f'human-like session alpha_hat: {estimate_alpha(Session(sid="test", events=events, actor="H")):.3f}')
|
print(f'human-like session alpha_hat: {estimate_alpha(Session(sid="test", events=events, actor="H")):.3f}')
|
||||||
|
|||||||
@@ -157,7 +157,7 @@ class PricingEnv(gym.Env if HAS_GYM else object):
|
|||||||
"n_purchases": int(np.sum(purchases)),
|
"n_purchases": int(np.sum(purchases)),
|
||||||
"avg_margin": float(np.mean((prices - self._sys.costs) / self._sys.costs)),
|
"avg_margin": float(np.mean((prices - self._sys.costs) / self._sys.costs)),
|
||||||
"n_sessions": len(demand), "n_agents": n_agents, "price_std": float(np.std(prices)),
|
"n_sessions": len(demand), "n_agents": n_agents, "price_std": float(np.std(prices)),
|
||||||
"coi_erosion": coi_erosion(max(1, n_agents), float(np.std(prices))),
|
"coi_erosion": coi_erosion(coi.policy, coi.agent),
|
||||||
"coi_policy": float(coi.policy), "coi_agent": float(coi.agent),
|
"coi_policy": float(coi.policy), "coi_agent": float(coi.agent),
|
||||||
"coi_leakage": float(coi.leak), "coi_survival": float(coi.survival_ratio),
|
"coi_leakage": float(coi.leak), "coi_survival": float(coi.survival_ratio),
|
||||||
"cumulative_reward": sum(self._episode_rewards), "step": self._t,
|
"cumulative_reward": sum(self._episode_rewards), "step": self._t,
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ except ImportError:
|
|||||||
HAS_TB = False
|
HAS_TB = False
|
||||||
|
|
||||||
from .simplified_env import PricingEnv, EnvConfig, make_env, adaptive_policy, fixed_price_policy, random_policy
|
from .simplified_env import PricingEnv, EnvConfig, make_env, adaptive_policy, fixed_price_policy, random_policy
|
||||||
from .coi import coi_erosion
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -214,7 +213,7 @@ def train(cfg: ExperimentConfig) -> Dict[str, Any]:
|
|||||||
common = dict(verbose=1, seed=cfg.seed, tensorboard_log=str(log_path), device="auto")
|
common = dict(verbose=1, seed=cfg.seed, tensorboard_log=str(log_path), device="auto")
|
||||||
model = {
|
model = {
|
||||||
"ppo": lambda: PPO("MlpPolicy", train_env, learning_rate=3e-4, n_steps=2048, batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, ent_coef=0.01, **common),
|
"ppo": lambda: PPO("MlpPolicy", train_env, learning_rate=3e-4, n_steps=2048, batch_size=64, n_epochs=10, gamma=0.99, gae_lambda=0.95, clip_range=0.2, ent_coef=0.01, **common),
|
||||||
"sac": lambda: SAC("MlpPolicy", train_env, learning_rate=3e-4, buffer_size=100_000, batch_size=256, tau=0.005, gamma=0.99, **common),
|
"sac": lambda: SAC("MlpPolicy", train_env, learning_rate=1e-4, buffer_size=50_000, batch_size=512, tau=0.02, gamma=0.99, learning_starts=1000, ent_coef="auto_0.1", train_freq=4, **common),
|
||||||
"a2c": lambda: A2C("MlpPolicy", train_env, learning_rate=7e-4, n_steps=5, gamma=0.99, **common),
|
"a2c": lambda: A2C("MlpPolicy", train_env, learning_rate=7e-4, n_steps=5, gamma=0.99, **common),
|
||||||
}[cfg.algo.lower()]()
|
}[cfg.algo.lower()]()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user