catchup: rogue scripts

This commit is contained in:
2026-02-27 12:45:46 +01:00
parent e8a9716f69
commit 5444a4ea13
27 changed files with 6908 additions and 2 deletions

17
.dockerignore Normal file
View File

@@ -0,0 +1,17 @@
.git
.venv
.venv-tpu
**/__pycache__
**/*.pyc
**/*.pyo
**/.pytest_cache
**/.mypy_cache
**/.ruff_cache
**/.ipynb_checkpoints
wandb
build
paper/build
paper/build-cais
node_modules
**/node_modules
*.egg-info

18
.env.sweep.example Normal file
View File

@@ -0,0 +1,18 @@
# Copy this file to .env.sweep and fill in values.
# Required for wandb runs and sweep agent workers.
WANDB_API_KEY=
WANDB_ENTITY=
WANDB_PROJECT=phantom-pricing
# Required for private repo bootstrap workers.
GITHUB_TOKEN=
# Optional defaults for bootstrap mode.
# REPO_URL=https://github.com/org/repo.git
# BRANCH=main
# WORKDIR=$HOME/PHANTOM-agent
# SWEEP_ID=entity/project/id
# AGENT_COUNT=0
# AGENT_LOOP=1
# RETRY_SECONDS=20

57
.gitignore vendored
View File

@@ -1,21 +1,50 @@
# environment and secrets
**/.env
.env.*
!.env.*.example
**/.venv
# python build/cache artifacts
**/__pycache__
phantom.egg-info/
*.egg-info/
# notebook artifacts
**/.ipynb_checkpoints/
**/.virtual_documents/
# editor/tool state
**/.pdf-view-restore
.nextstep
.ignore-gitlogue
.cloudflare
# generated svg/graphics
**/session_*.svg
**/*graph.svg
**/auto/*.el
# misc generated
*.old
**/package-lock.json
**/*.parquet
**/_build/
# paper build artifacts
paper/src/bib/auto
**/_build/
paper/src/auto/*
paper/src/bib/auto
paper/template/*
paper/build-cais/
paper/src/main.pdf
paper/src/main-blx.bib
paper/src/svg-inkscape/
paper/src/mirrors/
paper/variations/
paper/src/graphics/test_*.png
thesis-latest.pdf
# experiment run artifacts and logs
docs/goals/*.md
PHANTOM.wiki/
experiments/airflow/logs/*
@@ -23,11 +52,35 @@ experiments/airflow/logs/scheduler/
experiments/airflow/logs/dag_processor_manager/
experiments/collected_data/
experiments/agents/collected_data/
tests/e2e/test-results/
tests/e2e/node_modules/**
# rl/sim run outputs
sim/rl/behavior_loader/*.dot
sim/rl/behavior_loader/*.png
sim/rl/behavior_loader/*.svg
sim/rl/behavior_loader/*.pdf
tests/e2e/node_modules/**
sim/rl/runs/
lab/case/thesis/runs*/
sim/case/thesis_simplified/runs*/
# model binaries
engine/models/*.zip
*.zip
# wandb local state
wandb/
# data directory (large datasets)
data/
# ktem local app data
ktem_app_data/
# generated visualization pdfs
*_mdp_viz.pdf
phantom_env_comparison.png
sim/phantom_env_comparison.png
# web clone
PHANTOM_web/*

1
AGENTS.md Symbolic link
View File

@@ -0,0 +1 @@
CLAUDE.md

View File

@@ -0,0 +1,93 @@
method: bayes
metric:
name: sweep/score
goal: maximize
command:
- ${env}
- python
- -m
- engine.train
parameters:
# fixed: always use JAX backend so TPU chips are actually exercised
use_jax:
value: true
# all four algos have JAX implementations
algo:
values: [ppo, a2c, dqn, qtable]
total_timesteps:
values: [50000, 80000, 120000]
checkpoint_interval:
value: 200000
seed:
values: [13, 42, 77]
n_products:
values: [8, 10, 12]
# COI framework parameters -- primary research variables
alpha:
distribution: uniform
min: 0.1
max: 0.6
lambda_coi:
distribution: uniform
min: 0.05
max: 0.6
robust_radius:
distribution: uniform
min: 0.0
max: 0.3
robust_points:
values: [3, 5, 7]
info_value:
distribution: uniform
min: 0.5
max: 2.0
revenue_weight:
values: [0.005, 0.01, 0.02]
# shared hyperparameters
learning_rate:
distribution: log_uniform_values
min: 1.0e-5
max: 1.0e-3
gamma:
values: [0.97, 0.99, 0.995]
# JAX parallelism -- key lever for TPU throughput
jax_num_envs:
values: [8, 16, 32]
jax_num_steps:
values: [64, 128, 256]
jax_num_minibatches:
values: [2, 4, 8]
jax_update_epochs:
values: [2, 4, 8]
# PPO/A2C specific
gae_lambda:
values: [0.9, 0.95, 0.98]
clip_range:
values: [0.1, 0.2, 0.3]
ent_coef:
values: [0.0, 0.005, 0.01]
# DQN specific
buffer_size:
values: [20000, 50000, 100000]
batch_size:
values: [128, 256, 512]
learning_starts:
values: [500, 1000, 3000]
exploration_fraction:
values: [0.1, 0.2, 0.3]
exploration_final_eps:
values: [0.01, 0.03, 0.05]
# QTable specific
q_lr:
values: [0.03, 0.05, 0.1, 0.2]
eps_end:
values: [0.02, 0.05, 0.1]
eps_decay:
values: [0.999, 0.9995, 0.9999]
# action space
action_levels:
values: [7, 9, 11]
action_scale_low:
values: [0.75, 0.8, 0.85]
action_scale_high:
values: [1.15, 1.2, 1.25]

View File

@@ -0,0 +1,64 @@
method: bayes
metric:
name: sweep/score
goal: maximize
command:
- ${env}
- python
- -m
- engine.train
parameters:
use_jax:
value: true
# pmap requires all workers to compile the same computation graph shape,
# so structural params are fixed -- only research/scalar params are swept
algo:
values: [ppo, a2c]
jax_num_envs:
value: 32
jax_num_steps:
value: 128
jax_num_minibatches:
value: 4
jax_update_epochs:
value: 4
total_timesteps:
value: 100000
checkpoint_interval:
value: 200000
n_products:
value: 10
action_levels:
value: 9
# research parameters -- primary sweep targets
alpha:
distribution: uniform
min: 0.1
max: 0.6
lambda_coi:
distribution: uniform
min: 0.05
max: 0.6
robust_radius:
distribution: uniform
min: 0.0
max: 0.3
info_value:
distribution: uniform
min: 0.5
max: 2.0
revenue_weight:
values: [0.005, 0.01, 0.02]
# training hyperparameters
learning_rate:
distribution: log_uniform_values
min: 1.0e-5
max: 1.0e-3
gamma:
values: [0.97, 0.99, 0.995]
gae_lambda:
values: [0.9, 0.95, 0.98]
clip_range:
values: [0.1, 0.2, 0.3]
ent_coef:
values: [0.0, 0.005, 0.01]

130
engine/wandb_checkpoint.py Normal file
View File

@@ -0,0 +1,130 @@
from __future__ import annotations
import hashlib
import json
import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Mapping
try:
import wandb
from wandb.errors import CommError
HAS_WANDB = True
except ImportError:
HAS_WANDB = False
wandb = None # type: ignore[assignment]
CommError = RuntimeError # type: ignore[assignment]
def _safe_value(value: Any) -> Any:
if isinstance(value, (str, int, float, bool)) or value is None:
return value
if isinstance(value, (list, tuple)):
return [_safe_value(v) for v in value]
if isinstance(value, dict):
return {str(k): _safe_value(value[k]) for k in sorted(value)}
return str(value)
def _safe_scope(scope: str | None) -> str:
raw = "manual" if scope in (None, "") else str(scope)
cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "-", raw).strip("-")
return cleaned or "manual"
def checkpoint_artifact_name(
    cfg: Mapping[str, Any], *, backend: str, sweep_id: str | None = None
) -> str:
    """Build a deterministic checkpoint artifact name for a run configuration.

    The name has the form ``phantom-{backend}-ckpt-{scope}-{digest}`` where
    ``scope`` is the sanitised ``sweep_id`` and ``digest`` is the first 14 hex
    characters of the SHA-1 of a canonical JSON dump of backend, scope and
    the sanitised config. The result is at most 128 characters long.

    Args:
        cfg: Run configuration; values are normalised with ``_safe_value``.
        backend: Backend label embedded in both the name and the digest.
        sweep_id: Optional scope; ``None`` or empty falls back to ``"manual"``.

    Returns:
        The artifact name.
    """
    scope = _safe_scope(sweep_id)
    payload = {key: _safe_value(cfg[key]) for key in sorted(cfg)}
    canonical = json.dumps(
        {"backend": backend, "scope": scope, "cfg": payload},
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:14]
    prefix = f"phantom-{backend}-ckpt-"
    suffix = f"-{digest}"
    # Shorten the scope rather than the tail: a blind [:128] cut would drop
    # the digest for long scopes and make different configs share one name.
    # Names that already fit in 128 characters are unchanged.
    room = max(0, 128 - len(prefix) - len(suffix))
    return f"{prefix}{scope[:room]}{suffix}"[:128]
def _is_missing_artifact_error(exc: Exception) -> bool:
    """Tell whether ``exc`` is a ``CommError`` reporting a missing artifact.

    Only ``CommError`` instances qualify, and only when their lower-cased
    message contains "not found" or "does not exist".
    """
    if not isinstance(exc, CommError):
        return False
    text = str(exc).lower()
    return any(marker in text for marker in ("not found", "does not exist"))
def download_latest_checkpoint(
    artifact_name: str, *, file_name: str
) -> tuple[Path, dict[str, Any]] | None:
    """Fetch the ``latest`` version of a checkpoint artifact for the active run.

    Args:
        artifact_name: Artifact name without alias; ``:latest`` is appended.
        file_name: File expected inside the downloaded artifact directory.

    Returns:
        ``(checkpoint_path, metadata)`` on success. ``None`` when wandb is not
        installed, no run is active, the artifact is reported missing, or the
        download does not contain ``file_name``.

    Raises:
        Exception: Any ``use_artifact`` failure that is not a missing-artifact
            error is re-raised unchanged.
    """
    if not (HAS_WANDB and wandb.run is not None):
        return None
    try:
        artifact = wandb.run.use_artifact(f"{artifact_name}:latest")
    except Exception as exc:
        if not _is_missing_artifact_error(exc):
            raise
        return None
    checkpoint_path = Path(artifact.download()) / file_name
    if checkpoint_path.exists():
        # metadata may be absent or None on the artifact; normalise to a dict
        metadata = dict(getattr(artifact, "metadata", {}) or {})
        return checkpoint_path, metadata
    return None
def _aliases_from_metadata(metadata: dict[str, Any] | None) -> list[str]:
aliases = ["latest"]
if metadata is None:
return aliases
if "step" in metadata:
try:
aliases.append(f"step-{int(metadata['step'])}")
except (TypeError, ValueError):
pass
return aliases
def log_checkpoint_bytes(
    artifact_name: str,
    *,
    file_name: str,
    payload: bytes,
    metadata: dict[str, Any] | None = None,
) -> bool:
    """Log an in-memory checkpoint as a ``checkpoint`` artifact on the active run.

    The bytes are written to ``file_name`` inside a temporary directory, added
    to a new artifact and logged with the aliases derived from ``metadata``.

    Returns:
        ``True`` once the artifact has been handed to wandb, ``False`` when
        wandb is not installed or no run is active.
    """
    if not (HAS_WANDB and wandb.run is not None):
        return False
    aliases = _aliases_from_metadata(metadata)
    with TemporaryDirectory(prefix="phantom-ckpt-") as scratch:
        staged = Path(scratch) / file_name
        staged.write_bytes(payload)
        ckpt = wandb.Artifact(
            name=artifact_name,
            type="checkpoint",
            metadata=metadata or {},
        )
        ckpt.add_file(staged.as_posix(), name=file_name)
        wandb.log_artifact(ckpt, aliases=aliases)
    return True
def log_checkpoint_file(
    artifact_name: str,
    *,
    file_path: str | Path,
    artifact_file_name: str,
    metadata: dict[str, Any] | None = None,
) -> bool:
    """Log an existing checkpoint file as a ``checkpoint`` artifact.

    Args:
        artifact_name: Name of the artifact to create.
        file_path: Checkpoint file on disk.
        artifact_file_name: Name the file gets inside the artifact.
        metadata: Optional artifact metadata; also drives the aliases.

    Returns:
        ``True`` once the artifact has been handed to wandb, ``False`` when
        wandb is not installed, no run is active, or ``file_path`` is missing.
    """
    if not (HAS_WANDB and wandb.run is not None):
        return False
    source = Path(file_path)
    if source.exists():
        ckpt = wandb.Artifact(
            name=artifact_name,
            type="checkpoint",
            metadata=metadata or {},
        )
        ckpt.add_file(source.as_posix(), name=artifact_file_name)
        wandb.log_artifact(ckpt, aliases=_aliases_from_metadata(metadata))
        return True
    return False

View File

@@ -0,0 +1,269 @@
"""
Session-Aware Pricing DAG
THIS implements the core pricing computation (policy layer).
Flow: τ → θ̂ → D → p*
1. Fetch sessions from Kafka (the 10 sessions with the most events by default; configurable via session_limit)
2. Extract features per session (τ → θ̂)
3. Map features to demand proxy (θ̂ → D)
4. Compute optimal prices (D → p*)
5. Write to Redis session:{sessionId}:prices
Scheduled: every 1 minute when enabled
"""
from airflow import DAG
from airflow.operators.python import PythonOperator
from airflow.utils.dates import days_ago
from datetime import timedelta
import pandas as pd
import numpy as np
import logging
import sys
import pickle
sys.path.insert(0, '/opt/airflow')
from procesing.context import PipelineContext
from procesing.providers import SupabaseProvider, BackendAPIProvider
from procesing.steps.session import ExtractSessionFeaturesStep
from procesing.pricers.simple import SimpleSurgePricer, session_features_to_demand
from procesing.pricing import StateSpace
from lib.model_registry import ModelRegistry
# Task defaults handed to the DAG below via ``default_args``:
# one retry, 30 seconds after a failure, and no e-mail notifications.
DEFAULT_ARGS = dict(
    owner='phantom-research',
    depends_on_past=False,
    email_on_failure=False,
    email_on_retry=False,
    retries=1,
    retry_delay=timedelta(seconds=30),
)
class CompositeProvider(SupabaseProvider, BackendAPIProvider):
    """Provider that combines ``SupabaseProvider`` and ``BackendAPIProvider``."""

    def __init__(self):
        # Run each base initialiser explicitly, Supabase first.
        for base in (SupabaseProvider, BackendAPIProvider):
            base.__init__(self)
def _get_context(store_mode: str = 'hotel') -> PipelineContext:
    """Create a ``PipelineContext`` for ``store_mode`` with a fresh ``CompositeProvider``."""
    provider = CompositeProvider()
    return PipelineContext(provider=provider, store_mode=store_mode)
def fetch_recent_sessions(**kwargs):
    """
    Task: Fetch interaction events for the N busiest sessions from Kafka.

    Reads ``store_mode`` (default 'hotel') and ``session_limit`` (default 10)
    from the DAG run conf, fetches the "user-interactions" topic and keeps
    the ``session_limit`` sessions that have the most events.

    XCom pushes:
        sessions_data: pickled DataFrame of the kept events (an empty frame
            when the fetch fails or no usable events exist).
        session_ids: list of the selected session ids (success path only).

    Returns: number of sessions selected (0 when nothing usable was fetched).
    """
    ti = kwargs['ti']
    dag_run = kwargs.get('dag_run')
    # conf can be None when a run carries no payload (publish_to_registry in
    # this file guards for the same case), so fall back to an empty dict.
    dag_conf = (dag_run.conf if dag_run else None) or {}
    store_mode = dag_conf.get('store_mode', 'hotel')
    session_limit = dag_conf.get('session_limit', 10)

    ctx = _get_context(store_mode)
    provider = ctx.provider

    def _push_empty():
        # downstream tasks always unpickle 'sessions_data', so push a frame
        ti.xcom_push(key='sessions_data', value=pickle.dumps(pd.DataFrame()))
        return 0

    # fetch all recent interactions from Kafka
    try:
        interactions_df = provider.fetch_kafka_topic("user-interactions")
    except Exception as e:
        # best-effort task: log and hand an empty frame downstream
        logging.error(f"Failed to fetch interactions: {e}")
        return _push_empty()

    if interactions_df.empty or 'sessionId' not in interactions_df.columns:
        return _push_empty()

    # pick the N sessions with the highest event count
    recent_sessions = (
        interactions_df['sessionId'].value_counts().head(session_limit).index.tolist()
    )
    # filter to only those sessions
    filtered_df = interactions_df[interactions_df['sessionId'].isin(recent_sessions)].copy()

    ti.xcom_push(key='sessions_data', value=pickle.dumps(filtered_df))
    ti.xcom_push(key='session_ids', value=recent_sessions)
    logging.info(f"Fetched {len(filtered_df)} events for {len(recent_sessions)} sessions")
    return len(recent_sessions)
def extract_session_features(**kwargs):
    """
    Task: Extract behavioral features from session trajectories.
    THIS implements τ → θ̂ transformation.

    Pulls the pickled ``sessions_data`` frame from XCom, runs it through
    ``ExtractSessionFeaturesStep`` and pushes the pickled result as
    ``session_features`` (an empty frame when there is no input).

    Returns: number of rows in the feature frame (0 when there is no input).
    """
    ti = kwargs['ti']
    sessions_df = pickle.loads(ti.xcom_pull(key='sessions_data'))
    if sessions_df.empty:
        ti.xcom_push(key='session_features', value=pickle.dumps(pd.DataFrame()))
        return 0

    dag_run = kwargs.get('dag_run')
    # conf can be None when a run carries no payload; treat it as empty
    dag_conf = (dag_run.conf if dag_run else None) or {}
    ctx = _get_context(dag_conf.get('store_mode', 'hotel'))

    # extract features using vectorized pipeline
    feature_extractor = ExtractSessionFeaturesStep(ctx)
    features_df = feature_extractor.transform(sessions_df)
    ti.xcom_push(key='session_features', value=pickle.dumps(features_df))

    logging.info(f"Extracted {len(features_df.columns)} features for {len(features_df)} sessions")
    logging.info(f"Feature columns: {list(features_df.columns)}")
    if not features_df.empty:
        # iloc[0] raises IndexError on an empty frame; only log a sample
        # when the extractor actually produced rows
        logging.info(f"Sample features (first session):\n{features_df.iloc[0].to_dict()}")
    return len(features_df)
def compute_session_prices(**kwargs):
    """
    Task: Compute optimal prices for each session.
    THIS implements θ̂ → D → p* transformation.

    Pulls the pickled ``session_features`` frame from XCom, loads the product
    catalog for the configured ``store_mode``, fits a ``SimpleSurgePricer``
    and prices every product once per session. The pricer thresholds and
    multipliers are read from the DAG run conf with the defaults shown below.

    XCom push:
        price_results: pickled ``{session_id: {product_id: price}}`` dict
            (empty when there are no features or no products).

    Returns: number of sessions that received prices.
    """
    ti = kwargs['ti']
    features_df = pickle.loads(ti.xcom_pull(key='session_features'))
    if features_df.empty:
        ti.xcom_push(key='price_results', value=pickle.dumps({}))
        return 0

    dag_run = kwargs.get('dag_run')
    # conf can be None when a run carries no payload; treat it as empty
    dag_conf = (dag_run.conf if dag_run else None) or {}
    store_mode = dag_conf.get('store_mode', 'hotel')
    ctx = _get_context(store_mode)

    # fetch product catalog for base prices
    products_df = ctx.provider.fetch_products(store_mode)
    if products_df.empty:
        logging.error("No products found in catalog")
        ti.xcom_push(key='price_results', value=pickle.dumps({}))
        return 0
    # base price lives in the metadata dict; fall back to 100.0 when absent
    products_df['base_price'] = products_df['metadata'].apply(
        lambda m: m.get('base_price', 100.0) if isinstance(m, dict) else 100.0
    )

    # initialize pricing model
    pricer = SimpleSurgePricer(
        high_threshold=dag_conf.get('high_threshold', 10),
        low_threshold=dag_conf.get('low_threshold', 2),
        surge_multiplier=dag_conf.get('surge_multiplier', 1.15),
        discount_multiplier=dag_conf.get('discount_multiplier', 0.95)
    )
    pricer.fit(products_df)

    # compute prices per session
    price_results = {}
    n_products = len(products_df)
    # catalog-level values do not change between sessions: compute them once
    base_avg = products_df['base_price'].mean()
    product_ids = [str(pid) for pid in products_df['id']]
    logging.info(f"Starting price computation for {len(features_df)} sessions, {n_products} products")
    logging.info(f"Pricer config: high_thresh={pricer.high_threshold}, low_thresh={pricer.low_threshold}, surge_mult={pricer.surge_multiplier}")

    for _, session_row in features_df.iterrows():
        session_id = session_row.get('sessionId')
        if not session_id:
            continue

        # map features to demand proxy (θ̂ → D)
        session_features_single = pd.DataFrame([session_row])
        demand_proxy = session_features_to_demand(session_features_single)
        logging.info(f"[Session {session_id}] Features → Demand: {demand_proxy:.2f}")
        logging.info(f"[Session {session_id}] Key features: velocity={session_row.get('interaction_velocity', 0):.2f}, cart_ratio={session_row.get('cart_to_view_ratio', 0):.2f}, item_views={session_row.get('item_views', 0)}")

        # build state space
        state_space = StateSpace(
            demand=np.full(n_products, demand_proxy),  # broadcast session demand to all products
            prices=products_df['base_price'].values,
            session_features=session_features_single
        )

        # compute optimal prices (D → p*)
        optimal_prices = pricer.predict(state_space)
        optimal_avg = optimal_prices.mean()
        # guard the percentage against a zero average base price
        price_change_pct = ((optimal_avg - base_avg) / base_avg) * 100 if base_avg else 0.0
        logging.info(f"[Session {session_id}] Price adjustment: base_avg={base_avg:.2f}, optimal_avg={optimal_avg:.2f}, change={price_change_pct:+.1f}%")

        # store as dict {productId: price}
        price_results[session_id] = {
            pid: float(optimal_prices[i]) for i, pid in enumerate(product_ids)
        }

    ti.xcom_push(key='price_results', value=pickle.dumps(price_results))
    logging.info(f"Computed prices for {len(price_results)} sessions, {n_products} products each")
    return len(price_results)
def publish_to_registry(**kwargs):
    """
    Task: Write session prices to Redis registry.
    THIS is the write path: prices → session:{sessionId}:prices

    Pulls the pickled ``price_results`` dict from XCom and stores each
    session's price map through ``ModelRegistry.set_session_prices`` with a
    TTL taken from the DAG run conf key ``ttl`` (1800 when not provided).

    Returns: 0 when there is nothing to publish, otherwise a summary dict
    with ``sessions_published``, ``products_per_session`` and ``status``.
    """
    ti = kwargs['ti']
    price_results = pickle.loads(ti.xcom_pull(key='price_results'))
    if not price_results:
        logging.warning("No prices to publish")
        return 0

    registry = ModelRegistry()
    dag_run = kwargs.get('dag_run')
    if dag_run and dag_run.conf:
        ttl = dag_run.conf.get('ttl', 1800)
    else:
        ttl = 1800

    published_count = 0
    for session_id, price_map in price_results.items():
        registry.set_session_prices(session_id, price_map, ttl=ttl)
        published_count += 1
    logging.info(f"Published prices for {published_count} sessions to registry (TTL={ttl}s)")

    # price_results is non-empty here, so the first price map always exists
    first_price_map = next(iter(price_results.values()))
    summary = {
        'sessions_published': published_count,
        'products_per_session': len(first_price_map),
        'status': 'success'
    }
    return summary
# DAG definition: four PythonOperator tasks chained in a straight line.
with DAG(
    'session_pricing_pipeline',
    default_args=DEFAULT_ARGS,
    description='Session-aware pricing: extract features → compute prices → publish to registry',
    schedule_interval='*/1 * * * *',  # every 1 minute
    start_date=days_ago(1),
    catchup=False,
    max_active_runs=1,
    tags=['pricing', 'session-aware', 'research', 'real-time'],
) as dag:
    # One operator per pipeline stage, created in pipeline order; every task
    # is configured identically apart from its id and callable.
    t_fetch_sessions, t_extract_features, t_compute_prices, t_publish = (
        PythonOperator(
            task_id=task_id,
            python_callable=task_callable,
            provide_context=True,
        )
        for task_id, task_callable in (
            ('fetch_recent_sessions', fetch_recent_sessions),
            ('extract_session_features', extract_session_features),
            ('compute_session_prices', compute_session_prices),
            ('publish_to_registry', publish_to_registry),
        )
    )

    # linear dependency: fetch → extract → compute → publish
    t_fetch_sessions >> t_extract_features >> t_compute_prices >> t_publish

View File

@@ -0,0 +1 @@
from .encoder import Window, extract_windows, build_windows, WindowDataset, PrototypeClassifier, train, loocv

View File

@@ -0,0 +1,210 @@
"""Contrastive encoder via trajectory windowing. Classification by prototype distance."""
import sys
sys.path.insert(0, "/home/velocitatem/Documents/Projects/PHANTOM/sim/rl/behavior_loader")
sys.path.insert(0, "/home/velocitatem/Documents/Projects/PHANTOM/experiments/ml")
from sim.rl.behavior_loader.loader import JointLoader, PayloadModel
from arch import TrajectoryEncoder, featurize_trajectory, nt_xent_loss
from typing import List, Dict, Tuple
from dataclasses import dataclass
from datetime import datetime
import numpy as np, torch, torch.nn.functional as F, random, optuna
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
import os

# Project root used to locate run outputs and collected data. Defaults to the
# original developer checkout; set PHANTOM_ROOT to run from another location.
_PHANTOM_ROOT = (
    os.environ.get("PHANTOM_ROOT") or "/home/velocitatem/Documents/Projects/PHANTOM"
).rstrip("/")
RUNS = f"{_PHANTOM_ROOT}/experiments/ml/runs"  # TensorBoard log root
AGENT_DIR = f"{_PHANTOM_ROOT}/experiments/agents/collected_data/"  # agent trajectories
HUMAN_DIR = f"{_PHANTOM_ROOT}/experiments/collected_data/"  # human trajectories
@dataclass
class Window:
    """A slice of one trajectory together with its source id and class label."""
    events: List[PayloadModel]  # events contained in this window
    traj_id: str  # id of the trajectory the window was cut from
    label: int  # 0=human, 1=agent
def extract_windows(events: List[PayloadModel], traj_id: str, label: int,
                    sizes: "List[int] | Tuple[int, ...]" = (5, 10, 15),
                    stride: int = 2) -> List[Window]:
    """Multi-scale overlapping windows from trajectory.

    For every size in ``sizes`` that fits into the trajectory, a window is
    emitted at offsets ``0, stride, 2*stride, ...``. Trajectories with at
    least 3 events additionally contribute one window holding all events.

    Args:
        events: Events of one trajectory.
        traj_id: Trajectory id stored on every window.
        label: Class label stored on every window (0=human, 1=agent).
        sizes: Window lengths to extract (immutable default, so the default
            can never be altered between calls).
        stride: Step between consecutive window starts.

    Returns:
        The extracted windows; empty when the trajectory has fewer than 3
        events and is shorter than every requested size.
    """
    n = len(events)
    wins = [
        Window(events[i:i + s], traj_id, label)
        for s in sizes if n >= s
        for i in range(0, n - s + 1, stride)
    ]
    if n >= 3:
        wins.append(Window(events, traj_id, label))  # full traj
    return wins
def build_windows(data: Dict[str, List], sizes=(5, 10, 15), stride=2) -> List[Window]:
    """Extract windows from every trajectory in ``data``.

    Trajectories whose id starts with ``'human_'`` are labelled 0, all others
    1. ``sizes`` and ``stride`` are forwarded to ``extract_windows``; the
    default ``sizes`` is a tuple so it cannot be mutated between calls.
    """
    windows: List[Window] = []
    for tid, evts in data.items():
        label = 0 if tid.startswith('human_') else 1
        windows.extend(extract_windows(evts, tid, label, sizes, stride))
    return windows
class WindowDataset(Dataset):
    """Yields (anchor, positive) pairs from same class"""

    def __init__(self, windows: List[Window], dim: int = 64):
        self.wins = windows
        self.dim = dim
        # window indices grouped by class label (only labels 0 and 1 are indexed)
        self.by_label = {
            lbl: [i for i, w in enumerate(windows) if w.label == lbl]
            for lbl in (0, 1)
        }
        # window indices grouped by source trajectory
        self.by_traj = {}
        for i, w in enumerate(windows):
            self.by_traj.setdefault(w.traj_id, []).append(i)

    def __len__(self):
        return len(self.wins)

    def _feat(self, evts):
        """Featurize a list of events at this dataset's dimension."""
        return featurize_trajectory(evts, None, self.dim)

    def _aug(self, evts):
        """Random contiguous crop keeping 70-100% of the events (at least 3)."""
        n = len(evts)
        if n < 4:
            return evts
        keep = max(3, int(n * random.uniform(0.7, 1.0)))
        offset = random.randint(0, n - keep)
        return evts[offset:offset + keep]

    def __getitem__(self, idx):
        anchor = self.wins[idx]
        # positives: same label, different trajectory; fall back to the
        # anchor itself when no such window exists
        candidates = [
            j for j in self.by_label[anchor.label]
            if self.wins[j].traj_id != anchor.traj_id
        ]
        partner = self.wins[random.choice(candidates)] if candidates else anchor
        a = torch.tensor(self._feat(self._aug(anchor.events)), dtype=torch.float32)
        p = torch.tensor(self._feat(self._aug(partner.events)), dtype=torch.float32)
        return a, p, anchor.label
class PrototypeClassifier:
    """Classify by distance to class centroids"""

    def __init__(self, encoder: TrajectoryEncoder, device = 'cuda', dim=64):
        self.enc = encoder
        self.dev = device
        self.dim = dim
        self.centroids = {0: None, 1: None}

    def _embed(self, events):
        """Featurize ``events`` and run them through the encoder without gradients."""
        x = torch.tensor(featurize_trajectory(events, None, self.dim), dtype=torch.float32)
        with torch.no_grad():
            return self.enc(x.unsqueeze(0).unsqueeze(1).to(self.dev))

    def fit(self, windows: List[Window]):
        """Set each class centroid to the mean embedding of its windows; returns self."""
        self.enc.eval()
        per_class = {0: [], 1: []}
        for w in windows:
            per_class[w.label].append(self._embed(w.events))
        centroids = {}
        for label, vecs in per_class.items():
            # a class without any window keeps a None centroid
            centroids[label] = torch.cat(vecs).mean(0, keepdim=True) if vecs else None
        self.centroids = centroids
        return self

    def predict(self, events: List[PayloadModel]) -> Tuple[int, float, Dict]:
        """Returns (pred, confidence, debug). Confidence via softmax over -distances."""
        self.enc.eval()
        z = self._embed(events)
        dists = {}
        for label, centroid in self.centroids.items():
            if centroid is not None:
                dists[label] = torch.norm(z - centroid, dim=1).item()
        if not dists:
            return 0, 0.0, {'d': {}, 'p': [0.5, 0.5]}
        pred = min(dists, key=dists.get)
        # softmax(-d) gives higher prob to closer centroid; a missing
        # centroid is treated as being 1e6 away
        d0 = dists.get(0, 1e6)
        d1 = dists.get(1, 1e6)
        probs = F.softmax(torch.tensor([[-d0, -d1]]), dim=1).squeeze()
        return pred, probs[pred].item(), {'d': dists, 'p': probs.tolist()}
def train(epochs=200, lr=5e-4, batch=16, dim=64, emb=32, temp=0.5,
          sizes=(5, 10, 15), stride=2, name=None, verbose=True):
    """Train a ``TrajectoryEncoder`` with the NT-Xent loss on all windows.

    Loads the human and agent trajectories, builds the windows, and optimises
    the encoder with Adam for ``epochs`` epochs. The mean epoch loss is
    written to TensorBoard under ``{RUNS}/encoder/{name}`` as 'loss-ntxent'.

    Args:
        epochs, lr, batch: Number of epochs, Adam learning rate, batch size.
        dim, emb: Passed to ``TrajectoryEncoder(dim, emb)``; ``dim`` is also
            the featurization dimension.
        temp: Temperature passed to ``nt_xent_loss``.
        sizes, stride: Window extraction settings (see ``build_windows``).
        name: Run name; defaults to ``enc_{dim}_{emb}_{timestamp}``.
        verbose: Print window counts and the loss every 20 epochs.

    Returns:
        ``(encoder, windows, device)``.
    """
    data = JointLoader(HUMAN_DIR, AGENT_DIR).get_data()
    wins = build_windows(data, sizes, stride)
    if verbose:
        print(f"Windows: {len(wins)} ({sum(w.label==0 for w in wins)}h/{sum(w.label==1 for w in wins)}a)")

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    enc = TrajectoryEncoder(dim, emb).to(dev)
    opt = Adam(enc.parameters(), lr=lr)
    loader = DataLoader(WindowDataset(wins, dim), batch_size=batch, shuffle=True, drop_last=True)

    name = name or f"enc_{dim}_{emb}_{datetime.now():%Y%m%d_%H%M%S}"
    writer = SummaryWriter(f"{RUNS}/encoder/{name}")
    try:
        for ep in range(epochs):
            enc.train()
            total, n = 0.0, 0
            for a, p, _ in loader:
                loss = nt_xent_loss(enc(a.unsqueeze(1).to(dev)), enc(p.unsqueeze(1).to(dev)), temp)
                opt.zero_grad()
                loss.backward()
                opt.step()
                total += loss.item()
                n += 1
            avg = total / max(n, 1)  # max() guards an epoch without batches
            writer.add_scalar('loss-ntxent', avg, ep)
            if verbose and (ep + 1) % 20 == 0:
                print(f"Epoch {ep+1}: {avg:.4f}")
    finally:
        # close the event file even when training is interrupted or raises
        writer.close()
    return enc, wins, dev
def loocv(epochs=100, lr=5e-4, dim=64, emb=32, temp=0.5, sizes=(5, 10, 15), stride=2, verbose=True):
    """Leave-one-trajectory-out CV.

    For every trajectory, a fresh encoder is trained on the windows of all
    other trajectories, a ``PrototypeClassifier`` is fitted on those windows,
    and the held-out trajectory is classified as a whole. Folds whose
    training split lacks either a 'human_' or an 'agent_' trajectory are
    skipped.

    Returns:
        ``(accuracy, results)`` where ``results`` is a list of
        ``(pred, actual, confidence)`` tuples; ``(0.0, [])`` when no fold
        could be evaluated.
    """
    data = JointLoader(HUMAN_DIR, AGENT_DIR).get_data()
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    results = []

    for test_id in data:
        train_data = {k: v for k, v in data.items() if k != test_id}
        has_human = any(k.startswith('human_') for k in train_data)
        has_agent = any(k.startswith('agent_') for k in train_data)
        if not has_human or not has_agent:
            continue

        wins = build_windows(train_data, sizes, stride)
        enc = TrajectoryEncoder(dim, emb).to(dev)
        opt = Adam(enc.parameters(), lr=lr)
        # small folds: shrink the batch and keep the last partial batch
        loader = DataLoader(WindowDataset(wins, dim), batch_size=min(16, len(wins)//2 or 1),
                            shuffle=True, drop_last=len(wins)>2)
        for _ in range(epochs):
            enc.train()
            for a, p, _ in loader:
                loss = nt_xent_loss(enc(a.unsqueeze(1).to(dev)), enc(p.unsqueeze(1).to(dev)), temp)
                opt.zero_grad()
                loss.backward()
                opt.step()

        clf = PrototypeClassifier(enc, dev, dim).fit(wins)
        pred, conf, dbg = clf.predict(data[test_id])
        actual = 0 if test_id.startswith('human_') else 1
        results.append((pred, actual, conf))
        if verbose:
            print(f"{test_id[:18]}: pred={pred} conf={conf:.2f} actual={actual} {'OK' if pred==actual else 'MISS'}")

    if not results:
        return 0.0, []
    correct = sum(p == a for p, a, _ in results)
    acc = correct / len(results)
    if verbose:
        print(f"\nAccuracy: {acc:.1%} ({correct}/{len(results)})")
    return acc, results
def hparam_tune(n_trials=50, epochs=60, n_jobs=2, verbose=True):
    """Optuna hyperparameter search maximizing LOOCV accuracy.

    Searches learning rate, feature/embedding dimensions, temperature, stride
    and three window sizes with a seeded TPE sampler.

    Returns:
        ``(best_params, study)``.
    """
    def objective(trial):
        # the suggest calls run in a fixed order: lr, dim, emb, temp, stride, s0..s2
        lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
        dim = trial.suggest_categorical('dim', [32, 64, 128, 256])
        emb = trial.suggest_categorical('emb', [16, 32, 64, 128])
        temp = trial.suggest_float('temp', 0.05, 1.0)
        stride = trial.suggest_int('stride', 1, 4)
        raw_sizes = []
        for i in range(3):
            raw_sizes.append(trial.suggest_int(f's{i}', 3, 20))
        sizes = sorted(set(raw_sizes))  # unique sorted
        accuracy, _ = loocv(epochs, lr, dim, emb, temp, sizes, stride, verbose=False)
        return accuracy

    sampler = optuna.samplers.TPESampler(seed=42)
    study = optuna.create_study(direction='maximize', study_name='encoder_hparam',
                                sampler=sampler)
    study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs, show_progress_bar=verbose)

    best = study.best_params
    if verbose:
        print(f"\nBest accuracy: {study.best_value:.1%}")
        print(f"Best params: {best}")
    return best, study
if __name__ == "__main__":
    # Command-line entry point: --mode selects training, the Optuna search,
    # or (for 'eval') leave-one-trajectory-out cross-validation.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--mode', choices=['train', 'eval', 'hparam'], default='train')
    # remaining options share one shape: (flag, type, default)
    for flag, kind, default in (
        ('--epochs', int, 200),
        ('--lr', float, 5e-4),
        ('--dim', int, 128),
        ('--emb', int, 64),
        ('--temp', float, 0.1),
        ('--sizes', str, '5,10,15'),
        ('--stride', int, 2),
        ('--n_trials', int, 50),
    ):
        cli.add_argument(flag, type=kind, default=default)
    args = cli.parse_args()

    # '--sizes' is a comma-separated list of window lengths
    sizes = [int(x) for x in args.sizes.split(',')]

    if args.mode == 'hparam':
        # the search caps the per-trial epoch count at 60
        best, study = hparam_tune(args.n_trials, min(args.epochs, 60))
    elif args.mode == 'train':
        enc, wins, dev = train(args.epochs, args.lr, 16, args.dim, args.emb, args.temp, sizes, args.stride)
    else:
        loocv(args.epochs, args.lr, args.dim, args.emb, args.temp, sizes, args.stride)

View File

@@ -0,0 +1,957 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 10,
"id": "62eafcd9-5462-4063-8873-0e7fb9add907",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from kafka import KafkaConsumer\n",
"import pandas as pd\n",
"import json\n",
"import numpy as np\n",
"import os\n",
"from dotenv import load_dotenv\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import display, SVG, Image\n",
"load_dotenv()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4af65cb4-e8cf-4877-b2db-13ac19f3838f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<class 'pandas.core.frame.DataFrame'>\n",
"RangeIndex: 73 entries, 0 to 72\n",
"Data columns (total 13 columns):\n",
" # Column Non-Null Count Dtype \n",
"--- ------ -------------- ----- \n",
" 0 sessionId 73 non-null object \n",
" 1 eventName 73 non-null object \n",
" 2 page 73 non-null object \n",
" 3 productId 67 non-null object \n",
" 4 storeMode 73 non-null object \n",
" 5 userAgent 73 non-null object \n",
" 6 ts 73 non-null object \n",
" 7 metadata_referrer 6 non-null object \n",
" 8 metadata_roomType 45 non-null object \n",
" 9 metadata_price 45 non-null float64\n",
" 10 metadata_nights 45 non-null float64\n",
" 11 metadata_elementText 22 non-null object \n",
" 12 metadata_dwellTime 22 non-null float64\n",
"dtypes: float64(3), object(10)\n",
"memory usage: 7.5+ KB\n"
]
}
],
"source": [
"KAFKA_PORT=os.getenv(\"KAFKA_PORT\", 9092)\n",
"topic = \"user-interactions\"\n",
"consumer = KafkaConsumer(\n",
" topic, \n",
" enable_auto_commit=True,\n",
" value_deserializer=lambda x: json.loads(x.decode('utf-8')),\n",
" auto_offset_reset='earliest', \n",
" bootstrap_servers=['localhost:9092'])\n",
"messages=consumer.poll(timeout_ms=1000,max_records=10000)\n",
"df = []\n",
"for m in messages.values():\n",
" for i in m:\n",
" df.append(i.value)\n",
"df = pd.DataFrame(df)\n",
"# explode metadata col json\n",
"df = df.join(pd.json_normalize(df.pop(\"metadata\"), sep=\".\").add_prefix(\"metadata_\"))\n",
"df.info()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f6819a1c-32ab-49c7-845b-5df7bf60f561",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>sessionId</th>\n",
" <th>eventName</th>\n",
" <th>page</th>\n",
" <th>productId</th>\n",
" <th>storeMode</th>\n",
" <th>userAgent</th>\n",
" <th>ts</th>\n",
" <th>metadata_referrer</th>\n",
" <th>metadata_roomType</th>\n",
" <th>metadata_price</th>\n",
" <th>metadata_nights</th>\n",
" <th>metadata_elementText</th>\n",
" <th>metadata_dwellTime</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>d176d7c9-4027-4702-9e31-2a71395cdda0</td>\n",
" <td>page_view</td>\n",
" <td>/products</td>\n",
" <td>None</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
" <td>2025-11-14T13:23:46.270Z</td>\n",
" <td></td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>f0317a5d-e424-44e9-b784-c8f7291ffe31</td>\n",
" <td>page_view</td>\n",
" <td>/</td>\n",
" <td>None</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck...</td>\n",
" <td>2025-11-14T13:26:00.291Z</td>\n",
" <td></td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>f0317a5d-e424-44e9-b784-c8f7291ffe31</td>\n",
" <td>page_view</td>\n",
" <td>/products</td>\n",
" <td>None</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck...</td>\n",
" <td>2025-11-14T13:26:07.769Z</td>\n",
" <td></td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>f0317a5d-e424-44e9-b784-c8f7291ffe31</td>\n",
" <td>view_item_page</td>\n",
" <td>/products</td>\n",
" <td>htl-0</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck...</td>\n",
" <td>2025-11-14T13:26:15.010Z</td>\n",
" <td>NaN</td>\n",
" <td>Premium Room</td>\n",
" <td>269.0</td>\n",
" <td>1.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
" <td>page_view</td>\n",
" <td>/products</td>\n",
" <td>None</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
" <td>2025-11-14T13:27:15.457Z</td>\n",
" <td></td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
" <td>view_item_page</td>\n",
" <td>/products</td>\n",
" <td>htl-0</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
" <td>2025-11-14T13:27:15.591Z</td>\n",
" <td>NaN</td>\n",
" <td>Premium Room</td>\n",
" <td>264.0</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>432</th>\n",
" <td>214d9fad-9b00-40c3-bd0e-7739b6acd654</td>\n",
" <td>click</td>\n",
" <td>1762448192425</td>\n",
" <td>DIV</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>/</td>\n",
" <td>NaN</td>\n",
" <td>1623.0</td>\n",
" <td>493.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
" <td>view_item_page</td>\n",
" <td>/products</td>\n",
" <td>htl-0</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
" <td>2025-11-14T13:27:21.483Z</td>\n",
" <td>NaN</td>\n",
" <td>Premium Room</td>\n",
" <td>264.0</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
" <td>hover_over_title</td>\n",
" <td>/products</td>\n",
" <td>htl-0</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
" <td>2025-11-14T13:27:22.646Z</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>Grand Plaza Hotel</td>\n",
" <td>1200.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
" <td>view_item_page</td>\n",
" <td>/products</td>\n",
" <td>htl-0</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
" <td>2025-11-14T13:27:25.889Z</td>\n",
" <td>NaN</td>\n",
" <td>Premium Room</td>\n",
" <td>264.0</td>\n",
" <td>2.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35</th>\n",
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
" <td>page_view</td>\n",
" <td>/products</td>\n",
" <td>None</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
" <td>2025-11-14T13:53:59.993Z</td>\n",
" <td></td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36</th>\n",
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
" <td>view_item_page</td>\n",
" <td>/products</td>\n",
" <td>htl-0</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
" <td>2025-11-14T13:54:10.705Z</td>\n",
" <td>NaN</td>\n",
" <td>Premium Room</td>\n",
" <td>223.0</td>\n",
" <td>3.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37</th>\n",
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
" <td>hover_over_title</td>\n",
" <td>/products</td>\n",
" <td>htl-0</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
" <td>2025-11-14T13:54:11.771Z</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>416.0</td>\n",
" <td>397.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>Grand Plaza Hotel</td>\n",
" <td>1200.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38</th>\n",
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
" <td>view_item_page</td>\n",
" <td>/products</td>\n",
" <td>htl-1</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
" <td>2025-11-14T13:54:29.772Z</td>\n",
" <td>NaN</td>\n",
" <td>Standard Room</td>\n",
" <td>267.0</td>\n",
" <td>5.0</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>39</th>\n",
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
" <td>hover_over_title</td>\n",
" <td>/products</td>\n",
" <td>htl-1</td>\n",
" <td>hotel</td>\n",
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
" <td>2025-11-14T13:54:30.833Z</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>NaN</td>\n",
" <td>Seaside Resort</td>\n",
" <td>1200.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" sessionId eventName page \\\n",
"0 d176d7c9-4027-4702-9e31-2a71395cdda0 page_view /products \n",
"1 f0317a5d-e424-44e9-b784-c8f7291ffe31 page_view / \n",
"2 f0317a5d-e424-44e9-b784-c8f7291ffe31 page_view /products \n",
"3 f0317a5d-e424-44e9-b784-c8f7291ffe31 view_item_page /products \n",
"4 238dc588-a7ab-4c0e-bccd-6abca5076c66 page_view /products \n",
"5 238dc588-a7ab-4c0e-bccd-6abca5076c66 view_item_page /products \n",
"6 238dc588-a7ab-4c0e-bccd-6abca5076c66 view_item_page /products \n",
"7 238dc588-a7ab-4c0e-bccd-6abca5076c66 hover_over_title /products \n",
"8 238dc588-a7ab-4c0e-bccd-6abca5076c66 view_item_page /products \n",
"35 013fc334-4045-4d5a-8739-dd0a8766a63b page_view /products \n",
"36 013fc334-4045-4d5a-8739-dd0a8766a63b view_item_page /products \n",
"37 013fc334-4045-4d5a-8739-dd0a8766a63b hover_over_title /products \n",
"38 013fc334-4045-4d5a-8739-dd0a8766a63b view_item_page /products \n",
"39 013fc334-4045-4d5a-8739-dd0a8766a63b hover_over_title /products \n",
"\n",
" productId storeMode userAgent \\\n",
"0 None hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
"1 None hotel Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck... \n",
"2 None hotel Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck... \n",
"3 htl-0 hotel Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck... \n",
"4 None hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
"5 htl-0 hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
"6 htl-0 hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
"7 htl-0 hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
"8 htl-0 hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
"35 None hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
"36 htl-0 hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
"37 htl-0 hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
"38 htl-1 hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
"39 htl-1 hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
"\n",
" ts metadata_referrer metadata_roomType \\\n",
"0 2025-11-14T13:23:46.270Z NaN \n",
"1 2025-11-14T13:26:00.291Z NaN \n",
"2 2025-11-14T13:26:07.769Z NaN \n",
"3 2025-11-14T13:26:15.010Z NaN Premium Room \n",
"4 2025-11-14T13:27:15.457Z NaN \n",
"5 2025-11-14T13:27:15.591Z NaN Premium Room \n",
"6 2025-11-14T13:27:21.483Z NaN Premium Room \n",
"7 2025-11-14T13:27:22.646Z NaN NaN \n",
"8 2025-11-14T13:27:25.889Z NaN Premium Room \n",
"35 2025-11-14T13:53:59.993Z NaN \n",
"36 2025-11-14T13:54:10.705Z NaN Premium Room \n",
"37 2025-11-14T13:54:11.771Z NaN NaN \n",
"38 2025-11-14T13:54:29.772Z NaN Standard Room \n",
"39 2025-11-14T13:54:30.833Z NaN NaN \n",
"\n",
" metadata_price metadata_nights metadata_elementText metadata_dwellTime \n",
"0 NaN NaN NaN NaN \n",
"1 NaN NaN NaN NaN \n",
"2 NaN NaN NaN NaN \n",
"3 269.0 1.0 NaN NaN \n",
"4 NaN NaN NaN NaN \n",
"5 264.0 2.0 NaN NaN \n",
"6 264.0 2.0 NaN NaN \n",
"7 NaN NaN Grand Plaza Hotel 1200.0 \n",
"8 264.0 2.0 NaN NaN \n",
"35 NaN NaN NaN NaN \n",
"36 223.0 3.0 NaN NaN \n",
"37 NaN NaN Grand Plaza Hotel 1200.0 \n",
"38 267.0 5.0 NaN NaN \n",
"39 NaN NaN Seaside Resort 1200.0 "
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.groupby('sessionId').head()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "380eca5f-8304-4fb2-be32-e8bcfd312085",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['013fc334-4045-4d5a-8739-dd0a8766a63b',\n",
" '238dc588-a7ab-4c0e-bccd-6abca5076c66',\n",
" 'd176d7c9-4027-4702-9e31-2a71395cdda0',\n",
" 'f0317a5d-e424-44e9-b784-c8f7291ffe31']"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sorted (missing ids dropped) so the session order is reproducible;\n",
"# the iteration order of a set of strings changes between interpreter runs.\n",
"sessions = sorted(df['sessionId'].dropna().unique()); sessions # 238dc588-a7ab-4c0e-bccd-6abca5076c66"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f4ae6f81-dcb8-44be-aee7-30dbc3a6bae1",
"metadata": {},
"outputs": [],
"source": [
"# map sessions to experiments"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "050d90a4-20a9-47f5-b998-c31178a54cb3",
"metadata": {},
"outputs": [],
"source": [
"def build_transition_prob_matrix(df: pd.DataFrame):\n",
"    \"\"\"Build a row-normalised event-to-event transition matrix.\n",
"\n",
"    Rows with a missing ``eventName`` are dropped; every pair of consecutive\n",
"    remaining rows (in the frame's current order) counts as one transition.\n",
"\n",
"    Returns ``(P, labels)``: ``labels`` lists the distinct event names in order\n",
"    of first appearance and ``P[i, j]`` is the share of transitions leaving\n",
"    ``labels[i]`` that go to ``labels[j]``. A state that is never left gets an\n",
"    all-zero row.\n",
"    \"\"\"\n",
"    events = df.dropna(subset=['eventName'])['eventName'].tolist()\n",
"    labels = pd.Index(events).unique().tolist()\n",
"    idx = {e: i for i, e in enumerate(labels)}\n",
"    M = np.zeros((len(labels), len(labels)), dtype=float)\n",
"    for a, b in zip(events, events[1:]):\n",
"        M[idx[a], idx[b]] += 1\n",
"    row_sums = M.sum(axis=1, keepdims=True)\n",
"    # `out=` is required here: with `where=` alone NumPy leaves the masked-out\n",
"    # rows uninitialised, so states without outgoing transitions held garbage.\n",
"    P = np.divide(M, row_sums, out=np.zeros_like(M), where=row_sums > 0)  # row-normalized\n",
"    return P, labels"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e68f9004-82f5-4826-aece-e3dc6e15a18f",
"metadata": {},
"outputs": [],
"source": [
"# https://medium.com/data-science/time-series-data-markov-transition-matrices-7060771e362b\n",
"from graphviz import Digraph\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"def _as_prob_df(matrix, labels=None):\n",
"    \"\"\"Coerce *matrix* into a square DataFrame labelled identically on both axes.\n",
"\n",
"    A DataFrame is returned unchanged after its index and columns are compared\n",
"    element-wise. Any other input is converted to a float array, checked for\n",
"    squareness and wrapped in a DataFrame whose index and columns are *labels*.\n",
"\n",
"    Raises ValueError when *labels* is missing for a non-DataFrame input; the\n",
"    shape and label-length checks are plain ``assert`` statements.\n",
"    \"\"\"\n",
"    if not isinstance(matrix, pd.DataFrame):\n",
"        values = np.asarray(matrix, dtype=float)\n",
"        assert values.shape[0] == values.shape[1], \"Matrix must be square.\"\n",
"        if labels is None:\n",
"            raise ValueError(\"labels are required when matrix is not a DataFrame\")\n",
"        assert len(labels) == values.shape[0], \"labels length must match matrix size.\"\n",
"        return pd.DataFrame(values, index=list(labels), columns=list(labels))\n",
"    # Already a DataFrame: only verify that rows and columns line up.\n",
"    assert (matrix.index == matrix.columns).all(), \"Index/columns must match.\"\n",
"    return matrix\n",
"\n",
"def _df_to_edgelist(P: pd.DataFrame, threshold=0.0, round_digits=2):\n",
"    \"\"\"List (source, target, label) triples for every weight above *threshold*.\n",
"\n",
"    Cells are visited row by row; source and target are stringified and the\n",
"    label is the weight formatted with *round_digits* decimals.\n",
"    \"\"\"\n",
"    return [\n",
"        (str(src), str(dst), f\"{weight:.{round_digits}f}\")\n",
"        for src in P.index\n",
"        for dst in P.columns\n",
"        if (weight := float(P.loc[src, dst])) > threshold\n",
"    ]\n",
"\n",
"def render_graph(fname, matrix, ls_index=None, threshold=0.0, fmt=\"svg\", view=False):\n",
"    \"\"\"Draw a transition-probability matrix as a left-to-right Graphviz digraph.\n",
"\n",
"    fname:     output file stem (no extension)\n",
"    matrix:    NumPy array or pandas DataFrame of transition probabilities\n",
"    ls_index:  ordered state labels (required if matrix is not a DataFrame)\n",
"    threshold: edges with weight <= threshold are not drawn\n",
"    fmt:       Graphviz output format, e.g. 'svg', 'png' or 'pdf'\n",
"    view:      open the rendered file after writing it\n",
"\n",
"    Returns the Digraph after rendering it.\n",
"    \"\"\"\n",
"    prob_df = _as_prob_df(matrix, labels=ls_index)\n",
"\n",
"    graph = Digraph(format=fmt)\n",
"    graph.attr(rankdir=\"LR\", size=\"30\")\n",
"    graph.attr(\"node\", shape=\"circle\")\n",
"\n",
"    # Declare every state first so that states without edges still show up.\n",
"    for state in prob_df.index:\n",
"        graph.node(str(state), width=\"1\", height=\"1\")\n",
"\n",
"    for tail, head, weight in _df_to_edgelist(prob_df, threshold=threshold):\n",
"        graph.edge(tail, head, label=weight)\n",
"\n",
"    graph.render(fname, view=view, cleanup=True)\n",
"    return graph\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "e255a2c1-6454-4e5e-89f6-ef8ac51ab6cc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"013fc334-4045-4d5a-8739-dd0a8766a63b\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.2 (0)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"565pt\" height=\"354pt\"\n",
" viewBox=\"0.00 0.00 565.00 354.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 349.64)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-349.64 561.05,-349.64 561.05,4 -4,4\"/>\n",
"<!-- page_view -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>page_view</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"48.19\" cy=\"-235.83\" rx=\"48.19\" ry=\"48.19\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-231.16\" font-family=\"Times,serif\" font-size=\"14.00\">page_view</text>\n",
"</g>\n",
"<!-- view_item_page -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>view_item_page</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"232.88\" cy=\"-235.83\" rx=\"69.01\" ry=\"69.01\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-231.16\" font-family=\"Times,serif\" font-size=\"14.00\">view_item_page</text>\n",
"</g>\n",
"<!-- page_view&#45;&gt;view_item_page -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>page_view&#45;&gt;view_item_page</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M96.71,-235.83C113.69,-235.83 133.31,-235.83 152.25,-235.83\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"152.1,-239.33 162.1,-235.83 152.1,-232.33 152.1,-239.33\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"130.12\" y=\"-239.78\" font-family=\"Times,serif\" font-size=\"14.00\">1.00</text>\n",
"</g>\n",
"<!-- view_item_page&#45;&gt;view_item_page -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>view_item_page&#45;&gt;view_item_page</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M214.74,-302.59C217.1,-314.51 223.14,-322.84 232.88,-322.84 239.27,-322.84 244.07,-319.26 247.28,-313.42\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"250.57,-314.62 250.52,-304.02 243.95,-312.33 250.57,-314.62\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-326.79\" font-family=\"Times,serif\" font-size=\"14.00\">0.68</text>\n",
"</g>\n",
"<!-- hover_over_title -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>hover_over_title</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"463.22\" cy=\"-275.83\" rx=\"69.81\" ry=\"69.81\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463.22\" y=\"-271.16\" font-family=\"Times,serif\" font-size=\"14.00\">hover_over_title</text>\n",
"</g>\n",
"<!-- view_item_page&#45;&gt;hover_over_title -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>view_item_page&#45;&gt;hover_over_title</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M300.48,-250.14C307.03,-251.43 313.58,-252.69 319.89,-253.83 340.12,-257.51 362.05,-261.1 382.5,-264.27\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"381.77,-267.7 392.19,-265.76 382.83,-260.78 381.77,-267.7\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-263.17\" font-family=\"Times,serif\" font-size=\"14.00\">0.29</text>\n",
"</g>\n",
"<!-- hover_over_paragraph -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>hover_over_paragraph</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"463.22\" cy=\"-93.83\" rx=\"93.83\" ry=\"93.83\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463.22\" y=\"-89.16\" font-family=\"Times,serif\" font-size=\"14.00\">hover_over_paragraph</text>\n",
"</g>\n",
"<!-- view_item_page&#45;&gt;hover_over_paragraph -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>view_item_page&#45;&gt;hover_over_paragraph</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M292.09,-199.63C316.79,-184.27 346.14,-166.02 373.44,-149.04\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"375.08,-152.15 381.72,-143.89 371.38,-146.2 375.08,-152.15\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-185.68\" font-family=\"Times,serif\" font-size=\"14.00\">0.04</text>\n",
"</g>\n",
"<!-- hover_over_title&#45;&gt;view_item_page -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>hover_over_title&#45;&gt;view_item_page</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M399.53,-246.73C384.12,-240.88 367.42,-235.6 351.39,-232.58 339.13,-230.28 326.03,-229.26 313.19,-229.04\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"313.51,-225.54 303.51,-229.04 313.51,-232.54 313.51,-225.54\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-236.53\" font-family=\"Times,serif\" font-size=\"14.00\">1.00</text>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f0779e818b0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[]\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.2 (0)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"8pt\" height=\"8pt\"\n",
" viewBox=\"0.00 0.00 8.00 8.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 4)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-4 4,-4 4,4 -4,4\"/>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f6800fac980>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.00000000e+000 1.00000000e+000 0.00000000e+000 0.00000000e+000]\n",
" [0.00000000e+000 6.78571429e-001 2.85714286e-001 3.57142857e-002]\n",
" [0.00000000e+000 1.00000000e+000 0.00000000e+000 0.00000000e+000]\n",
" [2.05833592e-312 2.29175545e-312 4.94065646e-324 6.92110218e-310]]\n",
"238dc588-a7ab-4c0e-bccd-6abca5076c66\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.2 (0)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"565pt\" height=\"354pt\"\n",
" viewBox=\"0.00 0.00 565.00 354.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 349.64)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-349.64 561.05,-349.64 561.05,4 -4,4\"/>\n",
"<!-- page_view -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>page_view</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"48.19\" cy=\"-109.83\" rx=\"48.19\" ry=\"48.19\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-105.16\" font-family=\"Times,serif\" font-size=\"14.00\">page_view</text>\n",
"</g>\n",
"<!-- view_item_page -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>view_item_page</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"232.88\" cy=\"-197.83\" rx=\"69.01\" ry=\"69.01\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-193.16\" font-family=\"Times,serif\" font-size=\"14.00\">view_item_page</text>\n",
"</g>\n",
"<!-- page_view&#45;&gt;view_item_page -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>page_view&#45;&gt;view_item_page</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M92.02,-130.47C112.32,-140.25 137.13,-152.2 160.18,-163.3\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"158.39,-166.32 168.92,-167.51 161.43,-160.02 158.39,-166.32\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"130.12\" y=\"-157.78\" font-family=\"Times,serif\" font-size=\"14.00\">1.00</text>\n",
"</g>\n",
"<!-- view_item_page&#45;&gt;view_item_page -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>view_item_page&#45;&gt;view_item_page</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M214.74,-264.59C217.1,-276.51 223.14,-284.84 232.88,-284.84 239.27,-284.84 244.07,-281.26 247.28,-275.42\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"250.57,-276.62 250.52,-266.02 243.95,-274.33 250.57,-276.62\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-288.79\" font-family=\"Times,serif\" font-size=\"14.00\">0.19</text>\n",
"</g>\n",
"<!-- hover_over_title -->\n",
"<g id=\"node3\" class=\"node\">\n",
"<title>hover_over_title</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"463.22\" cy=\"-275.83\" rx=\"69.81\" ry=\"69.81\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463.22\" y=\"-271.16\" font-family=\"Times,serif\" font-size=\"14.00\">hover_over_title</text>\n",
"</g>\n",
"<!-- view_item_page&#45;&gt;hover_over_title -->\n",
"<g id=\"edge3\" class=\"edge\">\n",
"<title>view_item_page&#45;&gt;hover_over_title</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M289.6,-237.16C299.36,-242.77 309.67,-247.94 319.89,-251.83 339.45,-259.28 361.4,-264.43 382.1,-267.98\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"381.52,-271.43 391.95,-269.55 382.62,-264.52 381.52,-271.43\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-265.16\" font-family=\"Times,serif\" font-size=\"14.00\">0.38</text>\n",
"</g>\n",
"<!-- hover_over_paragraph -->\n",
"<g id=\"node4\" class=\"node\">\n",
"<title>hover_over_paragraph</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"463.22\" cy=\"-93.83\" rx=\"93.83\" ry=\"93.83\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463.22\" y=\"-89.16\" font-family=\"Times,serif\" font-size=\"14.00\">hover_over_paragraph</text>\n",
"</g>\n",
"<!-- view_item_page&#45;&gt;hover_over_paragraph -->\n",
"<g id=\"edge4\" class=\"edge\">\n",
"<title>view_item_page&#45;&gt;hover_over_paragraph</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M300.22,-180.71C317.22,-175.46 335.24,-169.12 351.39,-161.83 358.97,-158.41 366.67,-154.57 374.29,-150.49\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"375.84,-153.63 382.92,-145.75 372.47,-147.5 375.84,-153.63\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-178.15\" font-family=\"Times,serif\" font-size=\"14.00\">0.44</text>\n",
"</g>\n",
"<!-- hover_over_title&#45;&gt;view_item_page -->\n",
"<g id=\"edge5\" class=\"edge\">\n",
"<title>hover_over_title&#45;&gt;view_item_page</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M398.52,-248.36C383.21,-242.16 366.82,-235.87 351.39,-230.58 338.42,-226.15 324.5,-221.86 310.94,-217.93\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"312.2,-214.65 301.62,-215.28 310.28,-221.39 312.2,-214.65\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-234.53\" font-family=\"Times,serif\" font-size=\"14.00\">1.00</text>\n",
"</g>\n",
"<!-- hover_over_paragraph&#45;&gt;page_view -->\n",
"<g id=\"edge6\" class=\"edge\">\n",
"<title>hover_over_paragraph&#45;&gt;page_view</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M369.13,-95.76C310.26,-97.17 232.59,-99.41 163.87,-102.58 145.72,-103.42 125.98,-104.58 108.06,-105.73\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"107.86,-102.24 98.1,-106.38 108.31,-109.22 107.86,-102.24\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-106.53\" font-family=\"Times,serif\" font-size=\"14.00\">0.14</text>\n",
"</g>\n",
"<!-- hover_over_paragraph&#45;&gt;view_item_page -->\n",
"<g id=\"edge7\" class=\"edge\">\n",
"<title>hover_over_paragraph&#45;&gt;view_item_page</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M372.68,-119.15C354.84,-125.32 336.5,-132.51 319.89,-140.58 312.9,-143.98 305.81,-147.87 298.86,-151.98\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"297.49,-148.71 290.78,-156.91 301.14,-154.69 297.49,-148.71\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-144.53\" font-family=\"Times,serif\" font-size=\"14.00\">0.86</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f6800f97110>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0. 1. 0. 0. ]\n",
" [0. 0.1875 0.375 0.4375 ]\n",
" [0. 1. 0. 0. ]\n",
" [0.14285714 0.85714286 0. 0. ]]\n",
"d176d7c9-4027-4702-9e31-2a71395cdda0\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.2 (0)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"104pt\" height=\"104pt\"\n",
" viewBox=\"0.00 0.00 104.00 104.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 100.37)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-100.37 100.37,-100.37 100.37,4 -4,4\"/>\n",
"<!-- page_view -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>page_view</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"48.19\" cy=\"-48.19\" rx=\"48.19\" ry=\"48.19\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-43.51\" font-family=\"Times,serif\" font-size=\"14.00\">page_view</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f6800f97110>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[0.]]\n",
"f0317a5d-e424-44e9-b784-c8f7291ffe31\n"
]
},
{
"data": {
"image/svg+xml": [
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
"<!-- Generated by graphviz version 13.1.2 (0)\n",
" -->\n",
"<!-- Pages: 1 -->\n",
"<svg width=\"310pt\" height=\"160pt\"\n",
" viewBox=\"0.00 0.00 310.00 160.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 156.44)\">\n",
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-156.44 305.89,-156.44 305.89,4 -4,4\"/>\n",
"<!-- page_view -->\n",
"<g id=\"node1\" class=\"node\">\n",
"<title>page_view</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"48.19\" cy=\"-69.01\" rx=\"48.19\" ry=\"48.19\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-64.33\" font-family=\"Times,serif\" font-size=\"14.00\">page_view</text>\n",
"</g>\n",
"<!-- page_view&#45;&gt;page_view -->\n",
"<g id=\"edge1\" class=\"edge\">\n",
"<title>page_view&#45;&gt;page_view</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M33.03,-115.09C34.09,-126.6 39.14,-135.19 48.19,-135.19 53.98,-135.19 58.13,-131.66 60.65,-126.1\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"64.01,-127.11 62.98,-116.56 57.21,-125.45 64.01,-127.11\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-139.14\" font-family=\"Times,serif\" font-size=\"14.00\">0.50</text>\n",
"</g>\n",
"<!-- view_item_page -->\n",
"<g id=\"node2\" class=\"node\">\n",
"<title>view_item_page</title>\n",
"<ellipse fill=\"none\" stroke=\"black\" cx=\"232.88\" cy=\"-69.01\" rx=\"69.01\" ry=\"69.01\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-64.33\" font-family=\"Times,serif\" font-size=\"14.00\">view_item_page</text>\n",
"</g>\n",
"<!-- page_view&#45;&gt;view_item_page -->\n",
"<g id=\"edge2\" class=\"edge\">\n",
"<title>page_view&#45;&gt;view_item_page</title>\n",
"<path fill=\"none\" stroke=\"black\" d=\"M96.71,-69.01C113.69,-69.01 133.31,-69.01 152.25,-69.01\"/>\n",
"<polygon fill=\"black\" stroke=\"black\" points=\"152.1,-72.51 162.1,-69.01 152.1,-65.51 152.1,-72.51\"/>\n",
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"130.12\" y=\"-72.96\" font-family=\"Times,serif\" font-size=\"14.00\">0.50</text>\n",
"</g>\n",
"</g>\n",
"</svg>\n"
],
"text/plain": [
"<graphviz.graphs.Digraph at 0x7f6800bf50f0>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[5.0e-001 5.0e-001]\n",
" [9.9e-324 1.5e-323]]\n"
]
}
],
"source": [
"def explore_session(session_id: str):\n",
"    \"\"\"Print a session id, display its transition graph and return the matrix.\n",
"\n",
"    Uses the notebook-level ``df``; the graph is rendered to\n",
"    ``session_<session_id>`` as SVG with edges at or below 0.01 hidden.\n",
"    \"\"\"\n",
"    session_rows = df[df['sessionId'] == session_id]\n",
"    print(session_id)\n",
"    P, labels = build_transition_prob_matrix(session_rows)\n",
"    graph = render_graph(\n",
"        f\"session_{session_id}\",\n",
"        P,\n",
"        ls_index=labels,\n",
"        threshold=0.01,\n",
"        fmt=\"svg\",\n",
"        view=False,\n",
"    )\n",
"    display(graph)\n",
"    return P\n",
"\n",
"\n",
"# Walk through every session and print its transition matrix.\n",
"for session in sessions:\n",
"    print(explore_session(session))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (PHANTOM)",
"language": "python",
"name": "phantom"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,165 @@
import pytest
import pandas as pd
import numpy as np
from procesing.steps.session import (
TemporalFeatureStep,
BehavioralFeatureStep,
ProductFeatureStep,
UserAgentFeatureStep,
ExtractSessionFeaturesStep,
JoinLabelsStep,
ValidateDataStep,
)
# TemporalFeatureStep tests
def test_temporal_empty(pipeline_context):
    """An empty frame gives an empty result that still has a sessionId column."""
    step = TemporalFeatureStep(pipeline_context)
    features = step.transform(pd.DataFrame())
    assert 'sessionId' in features.columns
    assert features.empty
def test_temporal_basic(pipeline_context, session_interactions):
    """Check the temporal feature columns and that interaction counts add up."""
    features = TemporalFeatureStep(pipeline_context).transform(session_interactions)
    for column in ('session_duration_sec', 'interaction_velocity', 'max_velocity_5min'):
        assert column in features.columns
    # Every input row must be counted exactly once across all sessions.
    assert features['total_interactions'].sum() == len(session_interactions)
def test_temporal_timeout(pipeline_context):
    """Two events one hour apart with timeout_sec=900 give a zero duration."""
    events = pd.DataFrame({
        'sessionId': ['s1', 's1'],
        'ts': ['2025-01-01T10:00:00Z', '2025-01-01T11:00:00Z'],  # 1 hour gap
    })
    step = TemporalFeatureStep(pipeline_context, timeout_sec=900)
    first_row = step.transform(events).iloc[0]
    assert first_row['session_duration_sec'] == 0  # gap exceeds timeout
# BehavioralFeatureStep tests
def test_behavioral_empty(pipeline_context):
    """An empty frame still produces a result with a sessionId column."""
    empty_frame = pd.DataFrame()
    features = BehavioralFeatureStep(pipeline_context).transform(empty_frame)
    assert 'sessionId' in features.columns
def test_behavioral_counts(pipeline_context, session_interactions):
    """Check the per-event-type count columns and the overall event total."""
    features = BehavioralFeatureStep(pipeline_context).transform(session_interactions)
    for column in ('page_views', 'item_views', 'hover_events'):
        assert column in features.columns
    # Every input row must be counted exactly once across all sessions.
    assert features['total_events'].sum() == len(session_interactions)
def test_behavioral_hover_prefix(pipeline_context):
    """hover_over_custom and hover_over_button both count toward hover_events."""
    events = pd.DataFrame({
        'sessionId': ['s1', 's1'],
        'eventName': ['hover_over_custom', 'hover_over_button'],
        'page': ['/products', '/products'],
    })
    step = BehavioralFeatureStep(pipeline_context)
    first_row = step.transform(events).iloc[0]
    assert first_row['hover_events'] == 2
# ProductFeatureStep tests
def test_product_empty(pipeline_context):
    """An empty frame still produces a result with a sessionId column."""
    empty_frame = pd.DataFrame()
    features = ProductFeatureStep(pipeline_context).transform(empty_frame)
    assert 'sessionId' in features.columns
def test_product_features(pipeline_context, session_interactions):
    """Check the product feature columns and that some product was viewed."""
    features = ProductFeatureStep(pipeline_context).transform(session_interactions)
    for column in ('unique_products_viewed', 'price_range'):
        assert column in features.columns
    assert features['unique_products_viewed'].sum() > 0
# UserAgentFeatureStep tests
def test_ua_empty(pipeline_context):
    """An empty frame still produces a result with a sessionId column."""
    empty_frame = pd.DataFrame()
    features = UserAgentFeatureStep(pipeline_context).transform(empty_frame)
    assert 'sessionId' in features.columns
def test_ua_headless_detection(pipeline_context):
    """A 'HeadlessChrome' user agent is flagged headless; a plain Chrome one is not."""
    sessions = pd.DataFrame({
        'sessionId': ['s1', 's2'],
        'userAgent': ['Mozilla/5.0 Chrome/120', 'HeadlessChrome/120'],
    })
    features = UserAgentFeatureStep(pipeline_context).transform(sessions)
    assert 'is_headless' in features.columns
    flag_by_session = {
        sid: flag for sid, flag in zip(features['sessionId'], features['is_headless'])
    }
    # Compared with == (not `is`): the flags may be NumPy booleans.
    assert flag_by_session['s1'] == False
    assert flag_by_session['s2'] == True
def test_ua_browser_family(pipeline_context):
    """Firefox and Safari agents are recognised; anything else maps to 'Other'."""
    sessions = pd.DataFrame({
        'sessionId': ['s1', 's2', 's3'],
        'userAgent': ['Mozilla/5.0 Firefox/120', 'Safari/605.1.15', 'Unknown'],
    })
    features = UserAgentFeatureStep(pipeline_context).transform(sessions)
    family_by_session = {
        sid: family
        for sid, family in zip(features['sessionId'], features['browser_family'])
    }
    assert family_by_session['s1'] == 'Firefox'
    assert family_by_session['s2'] == 'Safari'
    assert family_by_session['s3'] == 'Other'
def test_ua_automation_detection(pipeline_context):
    """A 'Selenium WebDriver' agent is flagged as automation; a normal one is not."""
    sessions = pd.DataFrame({
        'sessionId': ['s1', 's2'],
        'userAgent': ['Selenium WebDriver', 'Normal Chrome/120'],
    })
    features = UserAgentFeatureStep(pipeline_context).transform(sessions)
    flag_by_session = {
        sid: flag for sid, flag in zip(features['sessionId'], features['is_automation'])
    }
    # Compared with == (not `is`): the flags may be NumPy booleans.
    assert flag_by_session['s1'] == True
    assert flag_by_session['s2'] == False
# ExtractSessionFeaturesStep tests
def test_extract_empty(pipeline_context):
    """Extracting session features from an empty frame returns an empty frame."""
    step = ExtractSessionFeaturesStep(pipeline_context)
    features = step.transform(pd.DataFrame())
    assert features.empty
def test_extract_merges_all(pipeline_context, session_interactions):
    """The combined step exposes columns from each feature family plus ``experimentId``."""
    step = ExtractSessionFeaturesStep(pipeline_context)
    features = step.transform(session_interactions)
    required = (
        'session_duration_sec',
        'total_events',
        'unique_products_viewed',
        'is_headless',
    )
    for column in required:
        assert column in features.columns
    assert 'experimentId' in features.columns
# JoinLabelsStep tests
def test_join_labels_tuple_input(pipeline_context):
    """A session from an ``xp_human_only`` experiment is labelled ``is_agent == False``."""
    feature_rows = pd.DataFrame(
        {'sessionId': ['s1'], 'experimentId': ['exp1'], 'total_events': [5]}
    )
    experiment_rows = pd.DataFrame({'id': ['exp1'], 'xp_human_only': [True]})
    step = JoinLabelsStep(pipeline_context)
    labelled = step.transform((feature_rows, experiment_rows))
    assert 'is_agent' in labelled.columns
    assert labelled.iloc[0]['is_agent'] == False
def test_join_labels_empty_experiments(pipeline_context):
    """With no experiment rows to join, ``is_agent`` is missing for the session."""
    feature_rows = pd.DataFrame({'sessionId': ['s1'], 'experimentId': ['exp1']})
    no_experiments = pd.DataFrame()
    labelled = JoinLabelsStep(pipeline_context).transform((feature_rows, no_experiments))
    first_row = labelled.iloc[0]
    assert pd.isna(first_row['is_agent'])
# ValidateDataStep tests
def test_validate_empty(pipeline_context):
    """Validating an empty frame caches a report whose status is ``'empty'``."""
    step = ValidateDataStep(pipeline_context)
    step.transform(pd.DataFrame())
    cached = pipeline_context.get_cached('validation_report')
    assert cached['status'] == 'empty'
def test_validate_missing_cols(pipeline_context):
    """A frame without ``eventName`` is reported as invalid with that column listed."""
    incomplete = pd.DataFrame({'sessionId': ['s1'], 'ts': ['2025-01-01']})
    step = ValidateDataStep(pipeline_context)
    step.transform(incomplete)
    cached = pipeline_context.get_cached('validation_report')
    assert cached['status'] == 'invalid'
    assert 'eventName' in cached['missing_cols']
def test_validate_valid(pipeline_context, session_interactions):
    """The fixture interactions validate cleanly and report at least one session."""
    step = ValidateDataStep(pipeline_context)
    step.transform(session_interactions)
    cached = pipeline_context.get_cached('validation_report')
    assert cached['status'] == 'valid'
    assert cached['sessions'] > 0

128
lib/separability.py Normal file
View File

@@ -0,0 +1,128 @@
"""Utilities for loading separability artifacts and scoring interaction sessions."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Sequence
import joblib
import numpy as np
from experiments.ml.arch import featurize_trajectory
DEFAULT_ARTIFACT_DIR = Path("data/separability")
@dataclass
class SeparabilityArtifacts:
    """Fitted objects and metadata needed to score a session (see ``load_artifacts``)."""

    # fitted feature scaler; ``score_session`` calls its ``transform``
    scaler: object
    # fitted classifier; ``score_session`` calls its ``predict_proba``
    classifier: object
    # reference state names, handed to ``featurize_trajectory`` as ``mdp["states"]``
    states: List[str]
    # reference transition distributions read from ``metadata.json``.
    # NOTE(review): ``score_session`` looks up the "human" and "agent" entries and
    # treats each as a {src: {dst: prob}} map, i.e. one level deeper than this
    # annotation suggests -- confirm and adjust the annotation.
    event_transitions: Dict[str, Dict[str, float]]
    # feature vector size, handed to ``featurize_trajectory`` as ``input_dim``
    feature_dim: int
def _normalize_events(raw_events: Sequence[object]) -> List[object]:
    """Unwrap enveloped events and return them ordered by their ``ts`` attribute.

    An item exposing ``.value.payload`` is replaced by that payload; any other
    item is kept unchanged. Items without a ``ts`` attribute sort with key ``""``.
    """

    def _unwrap(item: object) -> object:
        if hasattr(item, "value") and hasattr(item.value, "payload"):
            return item.value.payload
        return item

    unwrapped = [_unwrap(item) for item in raw_events]
    return sorted(unwrapped, key=lambda item: getattr(item, "ts", ""))
def _event_transition_distribution(events: Sequence[object]) -> Dict[str, Dict[str, float]]:
    """Return per-source transition probabilities between consecutive events.

    Each adjacent pair contributes one ``src -> dst`` count keyed by the events'
    ``eventName`` (``"unknown"`` when absent); counts are normalised per source.
    """
    tallies: Dict[str, Dict[str, int]] = {}
    for current, following in zip(events, events[1:]):
        origin = getattr(current, "eventName", "unknown")
        target = getattr(following, "eventName", "unknown")
        row = tallies.setdefault(origin, {})
        row[target] = row.get(target, 0) + 1

    result: Dict[str, Dict[str, float]] = {}
    for origin, row in tallies.items():
        denominator = float(sum(row.values()))
        if denominator:
            result[origin] = {target: n / denominator for target, n in row.items()}
        else:
            result[origin] = {}
    return result
def _kl_divergence(p: Dict[str, Dict[str, float]], q: Dict[str, Dict[str, float]]) -> float:
    """Sum the smoothed ``p * log(p / q)`` terms over every transition present in ``p``.

    Both probabilities get ``1e-10`` added before use, so transitions missing
    from ``q`` contribute a large finite penalty instead of dividing by zero.
    """
    eps = 1e-10
    accumulated = 0.0
    for src in p:
        reference_row = q.get(src, {})
        for dst, prob in p[src].items():
            ref = reference_row.get(dst, 0.0)
            accumulated += (prob + eps) * np.log((prob + eps) / (ref + eps))
    return float(accumulated)
def load_artifacts(artifact_dir: Path | str = DEFAULT_ARTIFACT_DIR) -> SeparabilityArtifacts:
    """Load the scaler, classifier and metadata stored in ``artifact_dir``.

    Expects ``scaler.joblib``, ``classifier.joblib`` and ``metadata.json`` in
    the directory; the metadata must provide ``reference_states``,
    ``event_transitions`` and ``feature_dim``.

    Raises:
        FileNotFoundError: if any of the three files is missing.
    """
    root = Path(artifact_dir)
    scaler_file = root / "scaler.joblib"
    classifier_file = root / "classifier.joblib"
    metadata_file = root / "metadata.json"

    everything_present = all(
        candidate.exists() for candidate in (scaler_file, classifier_file, metadata_file)
    )
    if not everything_present:
        raise FileNotFoundError(
            f"Separability artifacts not found in {root}. Run sim.strong_learner.train first."
        )

    # NOTE: joblib.load unpickles its input -- only point this at trusted artifacts.
    fitted_scaler = joblib.load(scaler_file)
    fitted_classifier = joblib.load(classifier_file)
    with open(metadata_file, "r", encoding="utf-8") as handle:
        meta = json.load(handle)

    return SeparabilityArtifacts(
        scaler=fitted_scaler,
        classifier=fitted_classifier,
        states=list(meta["reference_states"]),
        event_transitions=meta["event_transitions"],
        feature_dim=int(meta["feature_dim"]),
    )
def score_session(
    raw_events: Sequence[object],
    artifacts: SeparabilityArtifacts,
) -> dict:
    """Score one session against the loaded artifacts.

    Returns a dict with ``prob_agent`` (the classifier's probability for column
    1 of ``predict_proba``), ``delta_h`` and ``delta_a`` (divergence of the
    session's event transitions from the "human" and "agent" references).
    A session with no events scores 0.0 on all three.
    """
    events = _normalize_events(raw_events)
    if not events:
        return {"prob_agent": 0.0, "delta_h": 0.0, "delta_a": 0.0}

    # classifier probability on the scaled trajectory features
    feature_vector = featurize_trajectory(
        events,
        mdp={"states": artifacts.states},
        input_dim=artifacts.feature_dim,
    )
    scaled_row = artifacts.scaler.transform(feature_vector.reshape(1, -1))
    probabilities = artifacts.classifier.predict_proba(scaled_row)
    result = {"prob_agent": float(probabilities[0, 1])}

    # divergence of this session's transitions from each reference population
    observed = _event_transition_distribution(events)
    references = artifacts.event_transitions
    result["delta_h"] = _kl_divergence(observed, references.get("human", {}))
    result["delta_a"] = _kl_divergence(observed, references.get("agent", {}))
    return result
def estimate_alpha(prob_agent: float, delta_h: float, delta_a: float, temperature: float = 1.0) -> float:
    """Blend the classifier probability with the divergence ratio into one score.

    When ``delta_h + delta_a`` is at most 1e-8 the classifier probability is
    returned unchanged. Otherwise the score is the mean of ``prob_agent`` and
    ``delta_a / (delta_h + delta_a)``; for a positive ``temperature`` it is then
    passed through a logistic centred at 0.5. The result is clipped to [0, 1].
    """
    total_divergence = delta_h + delta_a
    if total_divergence <= 1e-8:
        return float(prob_agent)

    # NOTE(review): the ratio grows with delta_a, i.e. with the distance from
    # the *agent* reference -- confirm this direction is the intended one.
    share = delta_a / total_divergence
    score = 0.5 * prob_agent + 0.5 * share
    if not (temperature <= 0):
        score = 1.0 / (1.0 + np.exp(-temperature * (score - 0.5)))
    return float(np.clip(score, 0.0, 1.0))
def score_sessions(raw_sessions: Iterable[Sequence[object]], artifacts: SeparabilityArtifacts) -> List[dict]:
    """Apply ``score_session`` to every session, preserving input order."""
    scores: List[dict] = []
    for session_events in raw_sessions:
        scores.append(score_session(session_events, artifacts))
    return scores

View File

@@ -0,0 +1,69 @@
\section{Problem Formulation: A Stackelberg Game Approach}
\label{sec:math_formulation}
We formalize the interaction between the dynamic pricing system and non-human actors as a \textit{Stackelberg Game} (Leader-Follower) with incomplete information. This framework captures the hierarchical nature of the problem: the Platform (Leader) sets a pricing policy, and the Actors (Followers)---both Humans and Agents---observe these prices and react strategically.
\subsection{The Players and Objectives}
Let $t \in \{1, \dots, T\}$ denote discrete time steps. At each step, the system interactions are defined by the following entities:
\paragraph{1. The Leader (The Platform)}
The e-commerce platform acts as the leader, choosing a pricing policy $\pi$ to maximize total expected revenue. At time $t$, given a state $s_t \in \mathcal{S}$ (representing inventory, time of day, and historical interactions), the platform sets a price $p_t \in [p_{\min}, p_{\max}]$.
The platform's goal is to maximize the cumulative revenue from genuine human transactions while mitigating the distortion caused by agent interactions.
\paragraph{2. The Followers (The Demand Mixture)}
The observed demand is not a monolithic signal but a mixture of two distinct populations with divergent objective functions. Let $u$ denote an incoming actor. The type of the actor $\theta \in \{H, A\}$ is a latent variable, where $H$ denotes a Human and $A$ denotes an Agent.
\begin{itemize}
\item \textbf{The Human ($H$):} Acts as a \textit{myopic utility maximizer}. A human $i$ has a private valuation $v_i$ for the product. They execute a purchase decision $d_i \in \{0, 1\}$ based on the consumer surplus:
\begin{equation}
d_i(p_t) = \mathbb{I}(v_i - p_t \geq 0)
\end{equation}
where $\mathbb{I}(\cdot)$ is the indicator function. The aggregate human demand $q_H(p_t)$ follows a standard downward-sloping demand curve $D(p_t)$.
\item \textbf{The Agent ($A$):} Acts as an \textit{information maximizer} (reconnaissance). The agent does not intend to purchase at the displayed price $p_t$ unless an arbitrage condition is met. Instead, the agent generates interaction events (queries) to estimate the platform's pricing function $f(p)$. The agent's reward function $R_A$ is defined by Information Gain:
\begin{equation}
R_A(p_t) = H(\mathcal{P}) - H(\mathcal{P} \mid p_t) - c_{query}
\end{equation}
where $H(\mathcal{P})$ is the entropy of the agent's belief regarding the price distribution, and $c_{query}$ is the marginal cost of interaction (assumed $\approx 0$ for LLMs).
\end{itemize}
\subsection{The Demand Contamination Model}
% MAYBE alpha has to be \lambda which we also need to formally define still
The core difficulty in this setting is that the platform observes only the aggregate interaction volume $\hat{q}_t$, which is a contaminated signal. Let $\alpha_t \in [0, 1]$ represent the proportion of traffic generated by agents at time $t$. The observed signal is:
\begin{equation}
\hat{q}_t(p_t) = (1 - \alpha_t) \cdot q_H(p_t) + \alpha_t \cdot q_A(p_t) + \epsilon_t
\end{equation}
where:
\begin{itemize}
\item $q_H(p_t)$ is the \textit{true signal} (conversion intent).
\item $q_A(p_t)$ is the \textit{adversarial noise} (reconnaissance queries).
\item $\epsilon_t$ is random market noise.
\end{itemize}
Crucially, $q_A(p_t)$ is often inversely correlated with $q_H(p_t)$ in terms of utility; agents may flood the system with queries during high-volatility periods to map price boundaries, artificially inflating $\hat{q}_t$ without converting.
\subsection{The Optimization Objective: Robust Revenue}
Standard dynamic pricing algorithms (e.g., Thompson Sampling or UCB) assume $\alpha_t = 0$, estimating demand $\hat{D}(p) \approx \mathbb{E}[\hat{q} \mid p]$. In the presence of agents ($\alpha_t > 0$), this estimator becomes biased, leading to the \textit{Cost of Information} (COI) defined in Section 3.2.
We propose a robust optimization objective. The platform seeks a pricing policy $\pi^*$ that maximizes worst-case revenue over a statistically plausible set of contamination rates $\alpha$:
\begin{equation}
\pi^* = \argmax_{\pi} \sum_{t=1}^T \mathbb{E}_{s_t} \left[ \min_{\alpha} \left( p_t \cdot \hat{q}_t(p_t \mid \theta=H) \right) - \lambda \cdot \mathcal{L}_{detect}(\hat{q}_t) \right]
\end{equation}
Here:
\begin{itemize}
\item The first term, $p_t \cdot \hat{q}_t(p_t \mid \theta=H)$, represents the revenue generated strictly from the estimated human segment.
\item $\mathcal{L}_{detect}$ is a penalty term for failing to separate distributions (the cost of confusion).
\item $\lambda$ is a hyperparameter balancing revenue exploitation vs. robust detection.
\end{itemize}
This formulation effectively transforms the pricing problem into a \textit{Distributionally Robust Optimization (DRO)} problem, where the learner must guard against adversarial perturbations (Agent traffic) in the observed demand distribution.

BIN
paper/src/graphics/gcp.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

BIN
paper/src/graphics/gcp.webp Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 8.3 KiB

12
pyproject.toml Normal file
View File

@@ -0,0 +1,12 @@
[build-system]
# The [project] table below (PEP 621 metadata) is only understood by
# setuptools >= 61; older releases would build the package without its metadata.
requires = ["setuptools>=61", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "phantom"
version = "0.1.0"
description = "Pricing Heuristics Against Non-human Transaction Orchestration Mechanisms"
requires-python = ">=3.8"

[tool.setuptools.packages.find]
# only these top-level packages (and their subpackages) are shipped
include = ["experiments*", "lib*"]

32
scripts/tpu_pod_run.sh Executable file
View File

@@ -0,0 +1,32 @@
#!/usr/bin/env sh
# Executed on each TPU pod worker via `gcloud tpu-vm scp` + `gcloud tpu-vm ssh --worker=all`.
# Authenticates with Artifact Registry using the VM's service account metadata token,
# pulls the TPU trainer image, then runs the W&B sweep agent inside Docker.
# TPU chip devices (/dev/accel*) are exposed via --privileged + /dev volume mount.
# Required env vars: WANDB_API_KEY, SWEEP_ID
# Optional: AGENT_COUNT (default 1, 0 = run until sweep ends)
set -eu

# fail fast with a clear message, before the registry login and image pull,
# instead of tripping `set -u` only when `docker run` is reached
: "${WANDB_API_KEY:?WANDB_API_KEY is required}"
: "${SWEEP_ID:?SWEEP_ID is required}"

IMAGE="us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer:tpu-latest"
AGENT_COUNT="${AGENT_COUNT:-1}"

# use VM service account — no manual key needed on the pod
TOKEN=$(curl -sf -H "Metadata-Flavor: Google" \
  "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" \
  | python3 -c 'import sys, json; print(json.load(sys.stdin)["access_token"])')
echo "$TOKEN" | sudo docker login -u oauth2accesstoken \
  --password-stdin https://us-central1-docker.pkg.dev
sudo docker pull "$IMAGE"

# --privileged + /dev mount gives the container access to /dev/accel* (TPU chips)
# --network host lets JAX reach the other pod workers for distributed init
# values are passed explicitly because `sudo` does not forward the caller's environment by default
sudo docker run --rm \
  --privileged \
  --network host \
  --volume /dev:/dev \
  -e WANDB_API_KEY="$WANDB_API_KEY" \
  -e SWEEP_ID="$SWEEP_ID" \
  -e AGENT_COUNT="$AGENT_COUNT" \
  "$IMAGE"

83
scripts/tpu_sync_repo.sh Normal file
View File

@@ -0,0 +1,83 @@
#!/usr/bin/env sh
# Pack the local working tree into a tarball, copy it to every worker of a
# TPU VM and unpack it into REMOTE_REPO_DIR (which is wiped first).
# Required env vars: TPU_NAME
# Optional: TPU_ZONE, TPU_PROJECT, LOCAL_REPO_DIR, REMOTE_REPO_DIR, ARCHIVE_PATH
set -eu

TPU_NAME="${TPU_NAME:?TPU_NAME is required}"
TPU_ZONE="${TPU_ZONE:-us-central2-b}"
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}"
LOCAL_REPO_DIR="${LOCAL_REPO_DIR:-$(pwd)}"
REMOTE_REPO_DIR="${REMOTE_REPO_DIR:-/tmp/PHANTOM}"
ARCHIVE_PATH="${ARCHIVE_PATH:-/tmp/phantom-sync.tgz}"
FILE_LIST="$(mktemp /tmp/phantom-sync-files.XXXXXX)"
CLEANUP_LIST=true

cleanup() {
  if [ "$CLEANUP_LIST" = "true" ]; then
    rm -f "$FILE_LIST"
  fi
  # the archive is temporary as well: remove it even when scp/ssh fails midway
  rm -f "$ARCHIVE_PATH"
}
trap cleanup EXIT

if [ ! -d "$LOCAL_REPO_DIR" ]; then
  echo "local repo directory not found: $LOCAL_REPO_DIR" >&2
  exit 1
fi

if git -C "$LOCAL_REPO_DIR" rev-parse --is-inside-work-tree >/dev/null 2>&1; then
  # tracked + untracked-but-not-ignored files; quotePath=false keeps non-ASCII
  # names literal so tar can find them
  git -C "$LOCAL_REPO_DIR" -c core.quotePath=false ls-files -co --exclude-standard > "$FILE_LIST"
  python3 - "$FILE_LIST" "$LOCAL_REPO_DIR" <<'PY'
import os
import sys
from pathlib import Path

file_list = Path(sys.argv[1])
repo_root = sys.argv[2]

# directory (or file) names skipped at any depth, mirroring the tar fallback below
skip_names = {
    "wandb",
    ".venv",
    "venv",
    "node_modules",
    ".next",
    ".turbo",
    "__pycache__",
    ".mypy_cache",
    ".pytest_cache",
    ".ruff_cache",
}
# repo-relative paths skipped only at this exact location
skip_prefixes = (
    "paper/build/",
    "tests/e2e/test-results/",
)


def keep(row):
    if not row:
        return False
    if any(row == p.rstrip("/") or row.startswith(p) for p in skip_prefixes):
        return False
    if any(part in skip_names for part in row.split("/")):
        return False
    # `ls-files -c` still lists tracked files deleted from the working tree;
    # tar would abort on them, so drop anything that is not on disk
    return os.path.lexists(os.path.join(repo_root, row))


kept = [row for row in file_list.read_text().splitlines() if keep(row)]
file_list.write_text("\n".join(kept) + ("\n" if kept else ""))
PY
  tar -czf "$ARCHIVE_PATH" -C "$LOCAL_REPO_DIR" -T "$FILE_LIST"
else
  # not a git checkout: archive the whole directory minus caches and build output
  tar \
    --exclude-vcs \
    --exclude=".venv" --exclude="*/.venv" \
    --exclude="venv" --exclude="*/venv" \
    --exclude="node_modules" --exclude="*/node_modules" \
    --exclude=".next" --exclude="*/.next" \
    --exclude=".turbo" --exclude="*/.turbo" \
    --exclude="__pycache__" --exclude="*/__pycache__" \
    --exclude=".mypy_cache" --exclude="*/.mypy_cache" \
    --exclude=".pytest_cache" --exclude="*/.pytest_cache" \
    --exclude=".ruff_cache" --exclude="*/.ruff_cache" \
    --exclude="wandb" --exclude="*/wandb" \
    --exclude="paper/build" \
    --exclude="tests/e2e/test-results" \
    -czf "$ARCHIVE_PATH" \
    -C "$LOCAL_REPO_DIR" .
fi

gcloud compute tpus tpu-vm scp "$ARCHIVE_PATH" "$TPU_NAME:/tmp/phantom-sync.tgz" \
  --zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all

# replace the remote checkout wholesale, then drop the uploaded archive
gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
  --zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all \
  --command="rm -rf '$REMOTE_REPO_DIR' && mkdir -p '$REMOTE_REPO_DIR' && tar -xzf /tmp/phantom-sync.tgz -C '$REMOTE_REPO_DIR' && rm -f /tmp/phantom-sync.tgz"

View File

@@ -0,0 +1,183 @@
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import json
import os
import re
import shlex
import subprocess
import time
from pathlib import Path
import wandb
# Maps a sweep config key to the command-line flag it is forwarded as.
# Keys missing from this table are not forwarded by ``_to_cli_args``.
CLI_MAP: dict[str, str] = {
    "algo": "--algo",
    "total_timesteps": "--total-timesteps",
    "alpha": "--alpha",
    "N": "--N",
    "n_products": "--n-products",
    "lambda_coi": "--lambda-coi",
    "info_value": "--info-value",
    "robust_radius": "--robust-radius",
    "robust_points": "--robust-points",
    "learning_rate": "--learning-rate",
    "gamma": "--gamma",
    "gae_lambda": "--gae-lambda",
    "clip_range": "--clip-range",
    "ent_coef": "--ent-coef",
    "revenue_weight": "--revenue-weight",
    "max_steps": "--max-steps",
    "margin_floor": "--margin-floor",
    "margin_floor_patience": "--margin-floor-patience",
    "arch": "--arch",
    "activation": "--activation",
    "jax_num_envs": "--jax-num-envs",
    "jax_num_steps": "--jax-num-steps",
    "jax_num_minibatches": "--jax-num-minibatches",
    "jax_update_epochs": "--jax-update-epochs",
    "jax_anneal_lr": "--jax-anneal-lr",
    "checkpoint_interval": "--checkpoint-interval",
    "action_levels": "--action-levels",
    "action_scale_low": "--action-scale-low",
    "action_scale_high": "--action-scale-high",
}
def _to_cli_args(cfg: dict) -> str:
    """Render a sweep config as a shell-quoted argument string.

    The string always starts with ``--jax --no-wandb``. Keys listed in
    ``CLI_MAP`` are appended in table order; ``None`` values are skipped.
    Booleans become bare flags when true and are omitted when false, except
    ``jax_anneal_lr`` which is always passed with an explicit ``true``/``false``.
    """
    tokens: list[str] = ["--jax", "--no-wandb"]
    for name, option in CLI_MAP.items():
        setting = cfg.get(name)
        if setting is None:
            # covers both "key absent" and "explicitly None"
            continue
        if not isinstance(setting, bool):
            tokens += [option, str(setting)]
        elif name == "jax_anneal_lr":
            tokens += [option, "true" if setting else "false"]
        elif setting:
            tokens.append(option)
    return " ".join(map(shlex.quote, tokens))
_SENTINEL = "PHANTOM_METRICS:"


def _extract_metrics(output: str) -> dict:
    """Pull the metrics dict out of a trial's captured stdout.

    Lines starting with the ``PHANTOM_METRICS:`` sentinel are tried first; the
    first one whose payload parses to a JSON object wins. If none does, every
    flat ``{...}`` span (no nested braces) in the output is tried and the first
    object containing ``sweep/score`` or ``eval/reward`` is returned.

    Returns:
        The parsed metrics, or an empty dict when nothing usable is found.
        The result is always a dict, so callers can iterate ``.items()``.
    """
    # fast path: dedicated sentinel lines
    for line in output.splitlines():
        if not line.startswith(_SENTINEL):
            continue
        try:
            payload = json.loads(line[len(_SENTINEL):])
        except json.JSONDecodeError:
            # a corrupt sentinel line must not hide a later, valid one
            continue
        if isinstance(payload, dict):
            return payload

    # fallback: the pattern excludes braces inside the match, so it only ever
    # captures flat objects -- nested metric payloads are not recovered here
    for block in re.findall(r"\{[^{}]*\}", output):
        try:
            obj = json.loads(block)
        except json.JSONDecodeError:
            continue
        if isinstance(obj, dict) and ("sweep/score" in obj or "eval/reward" in obj):
            return obj
    return {}
def main() -> None:
    """Prepare the TPU workers once, then serve W&B sweep trials through ``make``.

    Each trial converts the sweep config into CLI arguments, exposes them to
    the ``train.tpu.vm.run`` make target via ``LOCAL_TRAIN_ARGS``, and copies
    the metrics found in the target's stdout into the W&B run.

    Raises:
        RuntimeError: if the ``train.tpu.vm.prepare`` target fails.
    """
    p = argparse.ArgumentParser(
        description="Run W&B sweep where each trial uses full TPU pod"
    )
    p.add_argument("--sweep-id", required=True)
    p.add_argument("--tpu-name", required=True)
    p.add_argument("--tpu-zone", default="us-central2-b")
    p.add_argument("--tpu-project", default="phantom-trc")
    p.add_argument("--tpu-repo-dir", default="/tmp/PHANTOM")
    p.add_argument("--count", type=int, default=0)
    p.add_argument("--workdir", default=str(Path(__file__).resolve().parents[1]))
    args = p.parse_args()

    workdir = Path(args.workdir).resolve()
    env = os.environ.copy()

    # make variables shared by the prepare and run targets
    tpu_vars = [
        f"TPU_NAME={args.tpu_name}",
        f"TPU_ZONE={args.tpu_zone}",
        f"TPU_PROJECT={args.tpu_project}",
        f"TPU_REPO_DIR={args.tpu_repo_dir}",
    ]

    prepare = subprocess.run(
        ["make", "train.tpu.vm.prepare", *tpu_vars],
        cwd=workdir,
        env=env,
        text=True,
        capture_output=False,  # let the preparation output stream to the console
        check=False,
    )
    if prepare.returncode != 0:
        raise RuntimeError("Failed to prepare TPU workers for sweep")

    def run_trial() -> None:
        """Run one sweep trial on the TPU pod and record its outcome in W&B."""
        run = None
        succeeded = False
        try:
            run = wandb.init()
            cfg = dict(wandb.config)
            cli_args = _to_cli_args(cfg)
            env_trial = dict(env)
            env_trial["LOCAL_TRAIN_ARGS"] = cli_args

            proc = subprocess.run(
                ["make", "train.tpu.vm.run", *tpu_vars],
                cwd=workdir,
                env=env_trial,
                text=True,
                capture_output=True,  # stdout is parsed for metrics below
                check=False,
            )
            if proc.stdout:
                print(proc.stdout)
            if proc.stderr:
                print(proc.stderr)
            if proc.returncode != 0:
                if run is not None:
                    run.summary["runner/exit_code"] = proc.returncode
                raise RuntimeError(f"TPU trial failed with exit code {proc.returncode}")

            metrics = _extract_metrics(proc.stdout)
            if metrics:
                wandb.log(metrics)
                for k, v in metrics.items():
                    run.summary[k] = v
            run.summary["runner/exit_code"] = 0
            succeeded = True
        except Exception:
            # short pause before re-raising; the reason is not evident from
            # this file (presumably throttles rapid failure loops -- confirm)
            time.sleep(2)
            raise
        finally:
            if run is not None and wandb.run is not None:
                # a bare finish() would mark a failed trial as successful
                wandb.finish(exit_code=0 if succeeded else 1)

    wandb.agent(
        args.sweep_id,
        function=run_trial,
        # --count 0 (the default) means "no limit"
        count=args.count if args.count > 0 else None,
    )


if __name__ == "__main__":
    main()

43
scripts/tpu_vm_train.sh Normal file
View File

@@ -0,0 +1,43 @@
#!/usr/bin/env sh
# Launch engine.train from an existing checkout, installing the JAX-side
# dependencies only when they cannot be imported yet.
# Optional env vars: REPO_DIR, PYTHON_BIN, TRAIN_ARGS, EXTRA_PIP,
#                    INSTALL_FULL_REQUIREMENTS, WANDB_API_KEY
set -eu

REPO_DIR="${REPO_DIR:-$HOME/PHANTOM}"
PYTHON_BIN="${PYTHON_BIN:-python3}"
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --jax --total-timesteps 200000 --jax-num-envs 32 --jax-num-steps 128 --jax-num-minibatches 4 --jax-update-epochs 4}"
EXTRA_PIP="${EXTRA_PIP:-flax optax distrax}"
INSTALL_FULL_REQUIREMENTS="${INSTALL_FULL_REQUIREMENTS:-0}"

[ -d "$REPO_DIR" ] || {
  echo "repo directory not found: $REPO_DIR"
  exit 1
}
cd "$REPO_DIR"

# drop run output left over from a previous launch
[ ! -d "wandb" ] || rm -rf wandb

# keep install idempotent and avoid re-installing jax/libtpu each run
if [ "$INSTALL_FULL_REQUIREMENTS" = "1" ] && [ -f "requirements.txt" ]; then
  $PYTHON_BIN -m pip install -r requirements.txt
fi

if ! $PYTHON_BIN -c 'import flax, optax, distrax' >/dev/null 2>&1; then
  [ ! -f "engine/jax/requirements.txt" ] || $PYTHON_BIN -m pip install -r engine/jax/requirements.txt
  # EXTRA_PIP is unquoted so that each package becomes its own argument
  $PYTHON_BIN -m pip install -U $EXTRA_PIP
fi

# without an API key there is nothing to log to: train with W&B disabled
if [ -z "${WANDB_API_KEY:-}" ]; then
  exec $PYTHON_BIN -m engine.train $TRAIN_ARGS --no-wandb
fi

if ! $PYTHON_BIN -c 'import wandb; import inspect; assert hasattr(wandb, "init") and callable(wandb.init)' >/dev/null 2>&1; then
  $PYTHON_BIN -m pip install -U wandb
fi
export WANDB_API_KEY
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS

108
scripts/wandb_agent_bootstrap.sh Executable file
View File

@@ -0,0 +1,108 @@
#!/usr/bin/env bash
set -euo pipefail
# Exit with status 1 when the variable whose name is given in $1 is unset or empty.
need_env() {
  local var_name="$1"
  [ -n "${!var_name:-}" ] && return 0
  echo "$var_name is required"
  exit 1
}
# Exit with status 1 when the command named in $1 cannot be found.
need_cmd() {
  local tool="$1"
  if ! command -v "$tool" >/dev/null 2>&1; then
    echo "Missing command: $tool"
    exit 1
  fi
}
# --- preflight: required tools and credentials ---
need_cmd git
need_cmd python3
need_env WANDB_API_KEY
need_env GITHUB_TOKEN
need_env REPO_URL
need_env SWEEP_ID
# --- optional settings with their defaults ---
BRANCH="${BRANCH:-main}"
WORKDIR="${WORKDIR:-$HOME/PHANTOM-agent}"
# AGENT_COUNT=0 means no --count is passed to the agent (see run_agent)
AGENT_COUNT="${AGENT_COUNT:-0}"
# AGENT_LOOP=1 keeps the main loop retrying instead of exiting
AGENT_LOOP="${AGENT_LOOP:-1}"
RETRY_SECONDS="${RETRY_SECONDS:-20}"
PYTHON_BIN="${PYTHON_BIN:-python3}"
mkdir -p "$(dirname "$WORKDIR")"
# --- temporary GIT_ASKPASS helper ---
# git invokes the helper with the prompt text as $1; it answers the username
# prompt with "x-access-token" and the password prompt with $GITHUB_TOKEN read
# from its environment, so the token is never written into the file itself.
ASKPASS_FILE="$(mktemp)"
cat >"$ASKPASS_FILE" <<'EOF'
#!/usr/bin/env sh
case "$1" in
*Username*) echo "x-access-token" ;;
*Password*) echo "$GITHUB_TOKEN" ;;
*) echo "" ;;
esac
EOF
chmod 700 "$ASKPASS_FILE"
# remove the helper whenever the script exits
cleanup() {
rm -f "$ASKPASS_FILE"
}
trap cleanup EXIT
# Run git with the temporary askpass helper and interactive prompts disabled.
# GITHUB_TOKEN is passed explicitly: the helper runs as a child process and
# reads the token from its environment, which would be empty if the variable
# was set in this shell without being exported.
git_auth() {
  GITHUB_TOKEN="$GITHUB_TOKEN" GIT_TERMINAL_PROMPT=0 GIT_ASKPASS="$ASKPASS_FILE" git "$@"
}
# Bring $WORKDIR to the tip of origin/$BRANCH.
# Without an existing checkout the directory is wiped and cloned afresh;
# otherwise the branch is fetched and the checkout is hard-reset to it,
# which discards any local changes in $WORKDIR.
sync_repo() {
if [ ! -d "$WORKDIR/.git" ]; then
rm -rf "$WORKDIR"
git_auth clone --single-branch --branch "$BRANCH" "$REPO_URL" "$WORKDIR"
return
fi
# REPO_URL may have changed since the checkout was created
git -C "$WORKDIR" remote set-url origin "$REPO_URL"
git_auth -C "$WORKDIR" fetch origin "$BRANCH" --prune
git -C "$WORKDIR" checkout -B "$BRANCH" "origin/$BRANCH"
git -C "$WORKDIR" reset --hard "origin/$BRANCH"
}
# Create (or reuse) the checkout's virtualenv and install requirements.txt into it.
install_deps() {
  local venv_dir="$WORKDIR/.venv"
  "$PYTHON_BIN" -m venv "$venv_dir"
  "$venv_dir/bin/pip" install --upgrade pip
  "$venv_dir/bin/pip" install -r "$WORKDIR/requirements.txt"
}
# Start the sweep agent from the checkout's virtualenv.
# --count is only passed when AGENT_COUNT is not "0". The W&B variables are
# handed to the process explicitly; entity and project default to empty.
run_agent() {
local cmd=("$WORKDIR/.venv/bin/python" -m engine.train --sweep-agent --sweep-id "$SWEEP_ID")
if [ "$AGENT_COUNT" != "0" ]; then
cmd+=(--count "$AGENT_COUNT")
fi
# subshell keeps the cd from leaking into the caller
(
cd "$WORKDIR"
WANDB_API_KEY="$WANDB_API_KEY" \
WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-}" \
"${cmd[@]}"
)
}
# Main loop: refresh the checkout and its dependencies, then run the agent.
# - agent failed:    retry after RETRY_SECONDS when AGENT_LOOP=1, else exit 1
# - agent succeeded: start over after RETRY_SECONDS only when AGENT_LOOP=1 and
#                    AGENT_COUNT=0, else exit 0
while true; do
  sync_repo
  install_deps
  if ! run_agent; then
    [ "$AGENT_LOOP" = "1" ] || exit 1
    sleep "$RETRY_SECONDS"
    continue
  fi
  if [ "$AGENT_LOOP" != "1" ] || [ "$AGENT_COUNT" != "0" ]; then
    exit 0
  fi
  sleep "$RETRY_SECONDS"
done

7
sim/requirements.txt Normal file
View File

@@ -0,0 +1,7 @@
gymnasium>=0.29.0
numpy>=1.24.0
pandas>=2.0.0
stable-baselines3>=2.2.0
tensorboard>=2.15.0
jax>=0.4.20
jaxlib>=0.4.20

View File

@@ -0,0 +1,117 @@
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from models import BehaviorModel, AgentBehaviorModel, aggregate_event_transitions, kl_divergence
def event_frequency_distribution(mdp):
    """Return the share of outgoing transition counts per source event.

    The event name is the third ``|``-separated field of each state key in
    ``mdp['transitions']``; counts come from ``mdp['trans_counts']`` for the
    same state. Returns an empty dict when the total count is not positive.
    """
    per_event = defaultdict(int)
    grand_total = 0
    for state in mdp['transitions']:
        event = state.split('|')[2]
        for count in mdp['trans_counts'][state].values():
            per_event[event] += count
            grand_total += count
    if grand_total > 0:
        return {event: count / grand_total for event, count in per_event.items()}
    return {}
def transition_distribution(mdp):
    """Return the share of counts for each ``src->dst`` event pair.

    Source and destination events are the third ``|``-separated field of the
    state keys in ``mdp['trans_counts']``. Returns an empty dict when the total
    count is not positive.
    """
    per_pair = defaultdict(int)
    grand_total = 0
    for state, successors in mdp['trans_counts'].items():
        origin = state.split('|')[2]
        for next_state, count in successors.items():
            target = next_state.split('|')[2]
            per_pair[f"{origin}->{target}"] += count
            grand_total += count
    if grand_total > 0:
        return {pair: count / grand_total for pair, count in per_pair.items()}
    return {}
def kl_color(kl):
    """Map a KL value to a hex colour: above 2.0, above 0.5, or anything else."""
    if kl > 2.0:
        return '#d62828'
    if kl > 0.5:
        return '#f77f00'
    return '#2a9d8f'
def plot_comparison(ax, human_vals, agent_vals, labels, title, ylabel, kl_val=None):
    """Draw side-by-side Human/Agent bars for ``labels`` on ``ax`` and return ``ax``.

    Args:
        ax: axes-like object to draw on.
        human_vals: bar heights for the human series, one per label.
        agent_vals: bar heights for the agent series, one per label.
        labels: x tick labels; more than 10 labels switches to smaller fonts.
        title: plot title.
        ylabel: y axis label.
        kl_val: optional KL value appended to the title as a second line.
            Only ``None`` suppresses it, so a divergence of exactly 0.0 is
            still displayed.
    """
    x, w = np.arange(len(labels)), 0.35
    crowded = len(labels) > 10
    ax.bar(x - w/2, human_vals, w, label='Human', alpha=0.8, color='#2E86AB')
    ax.bar(x + w/2, agent_vals, w, label='Agent', alpha=0.8, color='#A23B72')
    ax.set_ylabel(ylabel, fontsize=9 if crowded else 11, fontweight='bold')
    heading = title if kl_val is None else f"{title}\nKL={kl_val:.4f}"
    ax.set_title(heading, fontsize=10 if crowded else 12, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)
    ax.legend(fontsize=8)
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    return ax
if __name__ == "__main__":
    import os  # local import: only the script entry point needs it

    # Root of the collected interaction data. The default is the original
    # author's checkout; set PHANTOM_EXPERIMENTS_DIR to run on another machine.
    base_dir = os.environ.get(
        "PHANTOM_EXPERIMENTS_DIR",
        "/home/velocitatem/Documents/Projects/PHANTOM/experiments",
    )
    human_dir, agent_dir = f"{base_dir}/collected_data/", f"{base_dir}/agents/collected_data/"
    human_model, agent_model = BehaviorModel(human_dir), AgentBehaviorModel(agent_dir)
    human_mdp, agent_mdp = human_model.build_MDP(), agent_model.build_MDP()
    human_evt, agent_evt = aggregate_event_transitions(human_mdp), aggregate_event_transitions(agent_mdp)

    # per-event KL, restricted to events seen in both populations, largest first
    common = set(human_evt.keys()) & set(agent_evt.keys())
    kl_results = sorted([(e, kl_divergence(human_evt[e], agent_evt[e])) for e in common],
                        key=lambda x: x[1], reverse=True)

    # figure 1: one panel per source event comparing destination probabilities
    fig = plt.figure(figsize=(16, 10))
    n_rows, n_cols = (len(kl_results) + 1) // 2, 2
    for idx, (evt, kl) in enumerate(kl_results):
        ax = plt.subplot(n_rows, n_cols, idx + 1)
        h_dist, a_dist = human_evt.get(evt, {}), agent_evt.get(evt, {})
        dests = sorted(set(h_dist.keys()) | set(a_dist.keys()))
        if not dests: continue
        h_probs, a_probs = [h_dist.get(d, 0) for d in dests], [a_dist.get(d, 0) for d in dests]
        plot_comparison(ax, h_probs, a_probs, dests, f'From: {evt}', 'Probability')
        ax.set_ylim([0, max(max(h_probs + a_probs, default=0) * 1.1, 0.1)])
        ax.text(0.95, 0.95, f'KL={kl:.2f}', transform=ax.transAxes, fontsize=11,
                fontweight='bold', va='top', ha='right',
                bbox=dict(boxstyle='round', facecolor=kl_color(kl), alpha=0.3))
    plt.tight_layout()
    plt.savefig('kl_divergence_comparison.png', dpi=300, bbox_inches='tight')
    print("Saved visualization to kl_divergence_comparison.png")

    # figure 2: horizontal bar summary of the per-event KL values
    fig2, ax2 = plt.subplots(figsize=(10, 6))
    evts, kls = zip(*kl_results) if kl_results else ([], [])
    colors = [kl_color(kl) for kl in kls]
    bars = ax2.barh(evts, kls, color=colors, alpha=0.8)
    ax2.set_xlabel('KL Divergence D(Human || Agent)', fontsize=12, fontweight='bold')
    ax2.set_ylabel('Event Type', fontsize=12, fontweight='bold')
    ax2.set_title('Behavioral Divergence Between Human and Agent Traffic', fontsize=14, fontweight='bold')
    if kls:
        ax2.axvline(x=np.mean(kls), color='black', linestyle='--', linewidth=2,
                    alpha=0.5, label=f'Mean={np.mean(kls):.2f}')
    for bar, kl in zip(bars, kls):
        ax2.text(bar.get_width() + 0.1, bar.get_y() + bar.get_height()/2,
                 f'{kl:.2f}', ha='left', va='center', fontsize=10, fontweight='bold')
    ax2.legend()
    ax2.grid(axis='x', alpha=0.3, linestyle='--')
    plt.tight_layout()
    plt.savefig('kl_summary.png', dpi=300, bbox_inches='tight')
    print("Saved KL summary to kl_summary.png")

    # global distributions: event frequencies and src->dst pair frequencies
    h_freq, a_freq = event_frequency_distribution(human_mdp), event_frequency_distribution(agent_mdp)
    h_trans, a_trans = transition_distribution(human_mdp), transition_distribution(agent_mdp)
    freq_kl, trans_kl = kl_divergence(h_freq, a_freq), kl_divergence(h_trans, a_trans)
    print(f"\n=== Global Distribution KL Divergence ===")
    print(f"Event frequency KL: {freq_kl:.4f}")
    print(f"Transition pair KL: {trans_kl:.4f}")

    # figure 3: the two global distributions side by side (ax2 is rebound here)
    fig3, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    all_evts = sorted(set(h_freq.keys()) | set(a_freq.keys()))
    h_freqs, a_freqs = [h_freq.get(e, 0) for e in all_evts], [a_freq.get(e, 0) for e in all_evts]
    plot_comparison(ax1, h_freqs, a_freqs, all_evts, 'Event Frequency Distribution',
                    'Frequency', freq_kl)
    all_trans = sorted(set(h_trans.keys()) | set(a_trans.keys()))
    # keep only the 15 pairs with the largest combined human + agent mass
    top_trans = [t for t, _ in sorted([(t, h_trans.get(t, 0) + a_trans.get(t, 0))
                                       for t in all_trans], key=lambda x: x[1], reverse=True)[:15]]
    h_tprobs, a_tprobs = [h_trans.get(t, 0) for t in top_trans], [a_trans.get(t, 0) for t in top_trans]
    plot_comparison(ax2, h_tprobs, a_tprobs, top_trans, 'Top Transition Pairs Distribution',
                    'Probability', trans_kl)
    plt.tight_layout()
    plt.savefig('global_distributions.png', dpi=300, bbox_inches='tight')
    print("Saved global distributions to global_distributions.png")

86
sim/rl/thesis_core.py Normal file
View File

@@ -0,0 +1,86 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional
import numpy as np
from sim.case.thesis_simplified.simplified import Session
@dataclass(frozen=True)
class PricingStep:
    """Immutable record of one pricing step's sessions and aggregate outcomes."""

    # sessions belonging to this step
    sessions: list[Session]
    # demand value keyed by session id
    demand_by_session: Dict[str, float]
    # per-product arrays -- presumably indexed by product index like the
    # outputs of aggregate_demand_by_product / aggregate_purchases; confirm
    demand_by_product: np.ndarray
    purchases_by_product: np.ndarray
    # step totals
    revenue: float
    cost: float
    # number of agent sessions
    n_agents: int
def clip_prices(prices: np.ndarray, min_price: float, max_price: float) -> np.ndarray:
    """Bound ``prices`` to ``[min_price, max_price]`` and return them as float32."""
    bounded = np.clip(prices, min_price, max_price)
    return bounded.astype(np.float32)
def constrain_prices(
    prev_prices: Optional[np.ndarray],
    proposed: np.ndarray,
    *,
    costs: np.ndarray,
    min_price: float,
    max_price: float,
    max_adjustment: float,
    min_margin_pct: float,
) -> np.ndarray:
    """Apply price bounds, a cost-based floor and a per-step change limit.

    The proposal is clipped to ``[min_price, max_price]`` and raised to at
    least ``costs * (1 + min_margin_pct)``. When ``prev_prices`` is given, the
    result is further limited so that each price moves by at most
    ``max_adjustment`` (as a fraction) relative to its previous value.
    Returns a float32 array.
    """
    bounded = clip_prices(proposed, min_price, max_price)
    margin_floor = (costs * (1.0 + float(min_margin_pct))).astype(np.float32)
    candidate = np.maximum(bounded, margin_floor)
    if prev_prices is None:
        return candidate

    # NOTE(review): the bounds and the margin floor are not re-applied after
    # the change limit, so the returned prices may fall outside them -- confirm
    # that the change limit is meant to take precedence.
    previous = prev_prices.astype(np.float32)
    lowest_ratio = 1.0 - max_adjustment
    highest_ratio = 1.0 + max_adjustment
    # 1e-6 guards the division against a previous price of zero
    change = np.clip(candidate / (previous + 1e-6), lowest_ratio, highest_ratio)
    return (previous * change).astype(np.float32)
def aggregate_demand_by_product(
    sessions: list[Session],
    demand_by_session: Dict[str, float],
    n_products: int,
) -> np.ndarray:
    """Sum session demand into a float32 vector with one slot per product.

    A session's demand is credited to the product of its first event. Demand
    for unknown session ids, sessions without events, or a product index
    outside ``[0, n_products)`` is ignored.
    """
    totals = np.zeros(n_products, dtype=np.float32)
    lookup = {session.sid: session for session in sessions}
    for session_id, quantity in demand_by_session.items():
        session = lookup.get(session_id)
        if session and session.events:
            product = int(session.events[0].product_idx)
            if 0 <= product < n_products:
                totals[product] += float(quantity)
    return totals
def aggregate_purchases(
    sessions: list[Session],
    costs: np.ndarray,
    n_products: int,
) -> tuple[np.ndarray, float, float, int]:
    """Tally purchase events across sessions.

    Returns ``(purchases, revenue, cost, n_agents)``: a float32 purchase count
    per product, the summed ``price_seen`` of the counted purchases, the summed
    unit cost of the purchased products, and the number of sessions whose
    ``actor`` is ``"A"``. Purchases with a product index outside
    ``[0, n_products)`` are ignored.
    """
    counts = np.zeros(n_products, dtype=np.float32)
    total_revenue = 0.0
    total_cost = 0.0
    agent_sessions = 0
    for session in sessions:
        if session.actor == "A":
            agent_sessions += 1
        for event in session.events:
            if event.action == "purchase":
                product = int(event.product_idx)
                if 0 <= product < n_products:
                    counts[product] += 1.0
                    total_revenue += float(event.price_seen)
                    total_cost += float(costs[product])
    return counts, total_revenue, total_cost, agent_sessions