mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 16:43:36 +00:00
17
.dockerignore
Normal file
17
.dockerignore
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
.git
|
||||||
|
.venv
|
||||||
|
.venv-tpu
|
||||||
|
**/__pycache__
|
||||||
|
**/*.pyc
|
||||||
|
**/*.pyo
|
||||||
|
**/.pytest_cache
|
||||||
|
**/.mypy_cache
|
||||||
|
**/.ruff_cache
|
||||||
|
**/.ipynb_checkpoints
|
||||||
|
wandb
|
||||||
|
build
|
||||||
|
paper/build
|
||||||
|
paper/build-cais
|
||||||
|
node_modules
|
||||||
|
**/node_modules
|
||||||
|
*.egg-info
|
||||||
18
.env.sweep.example
Normal file
18
.env.sweep.example
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# Copy this file to .env.sweep and fill in values.
|
||||||
|
|
||||||
|
# Required for wandb runs and sweep agent workers.
|
||||||
|
WANDB_API_KEY=
|
||||||
|
WANDB_ENTITY=
|
||||||
|
WANDB_PROJECT=phantom-pricing
|
||||||
|
|
||||||
|
# Required for private repo bootstrap workers.
|
||||||
|
GITHUB_TOKEN=
|
||||||
|
|
||||||
|
# Optional defaults for bootstrap mode.
|
||||||
|
# REPO_URL=https://github.com/org/repo.git
|
||||||
|
# BRANCH=main
|
||||||
|
# WORKDIR=$HOME/PHANTOM-agent
|
||||||
|
# SWEEP_ID=entity/project/id
|
||||||
|
# AGENT_COUNT=0
|
||||||
|
# AGENT_LOOP=1
|
||||||
|
# RETRY_SECONDS=20
|
||||||
57
.gitignore
vendored
57
.gitignore
vendored
@@ -1,21 +1,50 @@
|
|||||||
|
# environment and secrets
|
||||||
**/.env
|
**/.env
|
||||||
|
.env.*
|
||||||
|
!.env.*.example
|
||||||
**/.venv
|
**/.venv
|
||||||
|
|
||||||
|
# python build/cache artifacts
|
||||||
**/__pycache__
|
**/__pycache__
|
||||||
|
phantom.egg-info/
|
||||||
|
*.egg-info/
|
||||||
|
|
||||||
|
# notebook artifacts
|
||||||
**/.ipynb_checkpoints/
|
**/.ipynb_checkpoints/
|
||||||
**/.virtual_documents/
|
**/.virtual_documents/
|
||||||
|
|
||||||
|
# editor/tool state
|
||||||
|
**/.pdf-view-restore
|
||||||
|
.nextstep
|
||||||
|
.ignore-gitlogue
|
||||||
|
.cloudflare
|
||||||
|
|
||||||
|
# generated svg/graphics
|
||||||
**/session_*.svg
|
**/session_*.svg
|
||||||
**/*graph.svg
|
**/*graph.svg
|
||||||
**/auto/*.el
|
**/auto/*.el
|
||||||
|
|
||||||
|
# misc generated
|
||||||
*.old
|
*.old
|
||||||
**/package-lock.json
|
**/package-lock.json
|
||||||
**/*.parquet
|
**/*.parquet
|
||||||
**/_build/
|
**/_build/
|
||||||
|
|
||||||
|
# paper build artifacts
|
||||||
paper/src/bib/auto
|
paper/src/bib/auto
|
||||||
**/_build/
|
|
||||||
paper/src/auto/*
|
paper/src/auto/*
|
||||||
paper/src/bib/auto
|
paper/src/bib/auto
|
||||||
paper/template/*
|
paper/template/*
|
||||||
|
paper/build-cais/
|
||||||
|
paper/src/main.pdf
|
||||||
|
paper/src/main-blx.bib
|
||||||
|
paper/src/svg-inkscape/
|
||||||
|
paper/src/mirrors/
|
||||||
|
paper/variations/
|
||||||
|
paper/src/graphics/test_*.png
|
||||||
|
thesis-latest.pdf
|
||||||
|
|
||||||
|
# experiment run artifacts and logs
|
||||||
docs/goals/*.md
|
docs/goals/*.md
|
||||||
PHANTOM.wiki/
|
PHANTOM.wiki/
|
||||||
experiments/airflow/logs/*
|
experiments/airflow/logs/*
|
||||||
@@ -23,11 +52,35 @@ experiments/airflow/logs/scheduler/
|
|||||||
experiments/airflow/logs/dag_processor_manager/
|
experiments/airflow/logs/dag_processor_manager/
|
||||||
experiments/collected_data/
|
experiments/collected_data/
|
||||||
experiments/agents/collected_data/
|
experiments/agents/collected_data/
|
||||||
|
tests/e2e/test-results/
|
||||||
|
tests/e2e/node_modules/**
|
||||||
|
|
||||||
|
# rl/sim run outputs
|
||||||
sim/rl/behavior_loader/*.dot
|
sim/rl/behavior_loader/*.dot
|
||||||
sim/rl/behavior_loader/*.png
|
sim/rl/behavior_loader/*.png
|
||||||
sim/rl/behavior_loader/*.svg
|
sim/rl/behavior_loader/*.svg
|
||||||
sim/rl/behavior_loader/*.pdf
|
sim/rl/behavior_loader/*.pdf
|
||||||
tests/e2e/node_modules/**
|
sim/rl/runs/
|
||||||
lab/case/thesis/runs*/
|
lab/case/thesis/runs*/
|
||||||
sim/case/thesis_simplified/runs*/
|
sim/case/thesis_simplified/runs*/
|
||||||
|
|
||||||
|
# model binaries
|
||||||
|
engine/models/*.zip
|
||||||
|
*.zip
|
||||||
|
|
||||||
|
# wandb local state
|
||||||
|
wandb/
|
||||||
|
|
||||||
|
# data directory (large datasets)
|
||||||
|
data/
|
||||||
|
|
||||||
|
# ktem local app data
|
||||||
|
ktem_app_data/
|
||||||
|
|
||||||
|
# generated visualization pdfs
|
||||||
|
*_mdp_viz.pdf
|
||||||
|
phantom_env_comparison.png
|
||||||
|
sim/phantom_env_comparison.png
|
||||||
|
|
||||||
|
# web clone
|
||||||
PHANTOM_web/*
|
PHANTOM_web/*
|
||||||
|
|||||||
93
engine/sweeps/tpu_jax.yaml
Normal file
93
engine/sweeps/tpu_jax.yaml
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
method: bayes
|
||||||
|
metric:
|
||||||
|
name: sweep/score
|
||||||
|
goal: maximize
|
||||||
|
command:
|
||||||
|
- ${env}
|
||||||
|
- python
|
||||||
|
- -m
|
||||||
|
- engine.train
|
||||||
|
parameters:
|
||||||
|
# fixed: always use JAX backend so TPU chips are actually exercised
|
||||||
|
use_jax:
|
||||||
|
value: true
|
||||||
|
# all four algos have JAX implementations
|
||||||
|
algo:
|
||||||
|
values: [ppo, a2c, dqn, qtable]
|
||||||
|
total_timesteps:
|
||||||
|
values: [50000, 80000, 120000]
|
||||||
|
checkpoint_interval:
|
||||||
|
value: 200000
|
||||||
|
seed:
|
||||||
|
values: [13, 42, 77]
|
||||||
|
n_products:
|
||||||
|
values: [8, 10, 12]
|
||||||
|
# COI framework parameters -- primary research variables
|
||||||
|
alpha:
|
||||||
|
distribution: uniform
|
||||||
|
min: 0.1
|
||||||
|
max: 0.6
|
||||||
|
lambda_coi:
|
||||||
|
distribution: uniform
|
||||||
|
min: 0.05
|
||||||
|
max: 0.6
|
||||||
|
robust_radius:
|
||||||
|
distribution: uniform
|
||||||
|
min: 0.0
|
||||||
|
max: 0.3
|
||||||
|
robust_points:
|
||||||
|
values: [3, 5, 7]
|
||||||
|
info_value:
|
||||||
|
distribution: uniform
|
||||||
|
min: 0.5
|
||||||
|
max: 2.0
|
||||||
|
revenue_weight:
|
||||||
|
values: [0.005, 0.01, 0.02]
|
||||||
|
# shared hyperparameters
|
||||||
|
learning_rate:
|
||||||
|
distribution: log_uniform_values
|
||||||
|
min: 1.0e-5
|
||||||
|
max: 1.0e-3
|
||||||
|
gamma:
|
||||||
|
values: [0.97, 0.99, 0.995]
|
||||||
|
# JAX parallelism -- key lever for TPU throughput
|
||||||
|
jax_num_envs:
|
||||||
|
values: [8, 16, 32]
|
||||||
|
jax_num_steps:
|
||||||
|
values: [64, 128, 256]
|
||||||
|
jax_num_minibatches:
|
||||||
|
values: [2, 4, 8]
|
||||||
|
jax_update_epochs:
|
||||||
|
values: [2, 4, 8]
|
||||||
|
# PPO/A2C specific
|
||||||
|
gae_lambda:
|
||||||
|
values: [0.9, 0.95, 0.98]
|
||||||
|
clip_range:
|
||||||
|
values: [0.1, 0.2, 0.3]
|
||||||
|
ent_coef:
|
||||||
|
values: [0.0, 0.005, 0.01]
|
||||||
|
# DQN specific
|
||||||
|
buffer_size:
|
||||||
|
values: [20000, 50000, 100000]
|
||||||
|
batch_size:
|
||||||
|
values: [128, 256, 512]
|
||||||
|
learning_starts:
|
||||||
|
values: [500, 1000, 3000]
|
||||||
|
exploration_fraction:
|
||||||
|
values: [0.1, 0.2, 0.3]
|
||||||
|
exploration_final_eps:
|
||||||
|
values: [0.01, 0.03, 0.05]
|
||||||
|
# QTable specific
|
||||||
|
q_lr:
|
||||||
|
values: [0.03, 0.05, 0.1, 0.2]
|
||||||
|
eps_end:
|
||||||
|
values: [0.02, 0.05, 0.1]
|
||||||
|
eps_decay:
|
||||||
|
values: [0.999, 0.9995, 0.9999]
|
||||||
|
# action space
|
||||||
|
action_levels:
|
||||||
|
values: [7, 9, 11]
|
||||||
|
action_scale_low:
|
||||||
|
values: [0.75, 0.8, 0.85]
|
||||||
|
action_scale_high:
|
||||||
|
values: [1.15, 1.2, 1.25]
|
||||||
64
engine/sweeps/tpu_pod.yaml
Normal file
64
engine/sweeps/tpu_pod.yaml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
method: bayes
|
||||||
|
metric:
|
||||||
|
name: sweep/score
|
||||||
|
goal: maximize
|
||||||
|
command:
|
||||||
|
- ${env}
|
||||||
|
- python
|
||||||
|
- -m
|
||||||
|
- engine.train
|
||||||
|
parameters:
|
||||||
|
use_jax:
|
||||||
|
value: true
|
||||||
|
# pmap requires all workers to compile the same computation graph shape,
|
||||||
|
# so structural params are fixed -- only research/scalar params are swept
|
||||||
|
algo:
|
||||||
|
values: [ppo, a2c]
|
||||||
|
jax_num_envs:
|
||||||
|
value: 32
|
||||||
|
jax_num_steps:
|
||||||
|
value: 128
|
||||||
|
jax_num_minibatches:
|
||||||
|
value: 4
|
||||||
|
jax_update_epochs:
|
||||||
|
value: 4
|
||||||
|
total_timesteps:
|
||||||
|
value: 100000
|
||||||
|
checkpoint_interval:
|
||||||
|
value: 200000
|
||||||
|
n_products:
|
||||||
|
value: 10
|
||||||
|
action_levels:
|
||||||
|
value: 9
|
||||||
|
# research parameters -- primary sweep targets
|
||||||
|
alpha:
|
||||||
|
distribution: uniform
|
||||||
|
min: 0.1
|
||||||
|
max: 0.6
|
||||||
|
lambda_coi:
|
||||||
|
distribution: uniform
|
||||||
|
min: 0.05
|
||||||
|
max: 0.6
|
||||||
|
robust_radius:
|
||||||
|
distribution: uniform
|
||||||
|
min: 0.0
|
||||||
|
max: 0.3
|
||||||
|
info_value:
|
||||||
|
distribution: uniform
|
||||||
|
min: 0.5
|
||||||
|
max: 2.0
|
||||||
|
revenue_weight:
|
||||||
|
values: [0.005, 0.01, 0.02]
|
||||||
|
# training hyperparameters
|
||||||
|
learning_rate:
|
||||||
|
distribution: log_uniform_values
|
||||||
|
min: 1.0e-5
|
||||||
|
max: 1.0e-3
|
||||||
|
gamma:
|
||||||
|
values: [0.97, 0.99, 0.995]
|
||||||
|
gae_lambda:
|
||||||
|
values: [0.9, 0.95, 0.98]
|
||||||
|
clip_range:
|
||||||
|
values: [0.1, 0.2, 0.3]
|
||||||
|
ent_coef:
|
||||||
|
values: [0.0, 0.005, 0.01]
|
||||||
@@ -91,10 +91,8 @@ DEFAULT_CFG = {
|
|||||||
|
|
||||||
|
|
||||||
def _truthy(value: str | bool | None) -> bool:
|
def _truthy(value: str | bool | None) -> bool:
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool): return value
|
||||||
return value
|
if value is None: return False
|
||||||
if value is None:
|
|
||||||
return False
|
|
||||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
130
engine/wandb_checkpoint.py
Normal file
130
engine/wandb_checkpoint.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from tempfile import TemporaryDirectory
|
||||||
|
from typing import Any, Mapping
|
||||||
|
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
from wandb.errors import CommError
|
||||||
|
|
||||||
|
HAS_WANDB = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_WANDB = False
|
||||||
|
wandb = None # type: ignore[assignment]
|
||||||
|
CommError = RuntimeError # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_value(value: Any) -> Any:
|
||||||
|
if isinstance(value, (str, int, float, bool)) or value is None:
|
||||||
|
return value
|
||||||
|
if isinstance(value, (list, tuple)):
|
||||||
|
return [_safe_value(v) for v in value]
|
||||||
|
if isinstance(value, dict):
|
||||||
|
return {str(k): _safe_value(value[k]) for k in sorted(value)}
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_scope(scope: str | None) -> str:
|
||||||
|
raw = "manual" if scope in (None, "") else str(scope)
|
||||||
|
cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "-", raw).strip("-")
|
||||||
|
return cleaned or "manual"
|
||||||
|
|
||||||
|
|
||||||
|
def checkpoint_artifact_name(
|
||||||
|
cfg: Mapping[str, Any], *, backend: str, sweep_id: str | None = None
|
||||||
|
) -> str:
|
||||||
|
payload = {k: _safe_value(cfg[k]) for k in sorted(cfg)}
|
||||||
|
scope = _safe_scope(sweep_id)
|
||||||
|
canonical = json.dumps(
|
||||||
|
{"backend": backend, "scope": scope, "cfg": payload},
|
||||||
|
sort_keys=True,
|
||||||
|
separators=(",", ":"),
|
||||||
|
)
|
||||||
|
digest = hashlib.sha1(canonical.encode("utf-8")).hexdigest()[:14]
|
||||||
|
return f"phantom-{backend}-ckpt-{scope}-{digest}"[:128]
|
||||||
|
|
||||||
|
|
||||||
|
def _is_missing_artifact_error(exc: Exception) -> bool:
|
||||||
|
if isinstance(exc, CommError):
|
||||||
|
msg = str(exc).lower()
|
||||||
|
return "not found" in msg or "does not exist" in msg
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def download_latest_checkpoint(
|
||||||
|
artifact_name: str, *, file_name: str
|
||||||
|
) -> tuple[Path, dict[str, Any]] | None:
|
||||||
|
if not HAS_WANDB or wandb.run is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
artifact = wandb.run.use_artifact(f"{artifact_name}:latest")
|
||||||
|
except Exception as exc:
|
||||||
|
if _is_missing_artifact_error(exc):
|
||||||
|
return None
|
||||||
|
raise
|
||||||
|
directory = Path(artifact.download())
|
||||||
|
checkpoint_path = directory / file_name
|
||||||
|
if not checkpoint_path.exists():
|
||||||
|
return None
|
||||||
|
metadata = dict(getattr(artifact, "metadata", {}) or {})
|
||||||
|
return checkpoint_path, metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _aliases_from_metadata(metadata: dict[str, Any] | None) -> list[str]:
|
||||||
|
aliases = ["latest"]
|
||||||
|
if metadata is None:
|
||||||
|
return aliases
|
||||||
|
if "step" in metadata:
|
||||||
|
try:
|
||||||
|
aliases.append(f"step-{int(metadata['step'])}")
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
pass
|
||||||
|
return aliases
|
||||||
|
|
||||||
|
|
||||||
|
def log_checkpoint_bytes(
|
||||||
|
artifact_name: str,
|
||||||
|
*,
|
||||||
|
file_name: str,
|
||||||
|
payload: bytes,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> bool:
|
||||||
|
if not HAS_WANDB or wandb.run is None:
|
||||||
|
return False
|
||||||
|
with TemporaryDirectory(prefix="phantom-ckpt-") as tmpdir:
|
||||||
|
path = Path(tmpdir) / file_name
|
||||||
|
path.write_bytes(payload)
|
||||||
|
artifact = wandb.Artifact(
|
||||||
|
name=artifact_name,
|
||||||
|
type="checkpoint",
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
artifact.add_file(path.as_posix(), name=file_name)
|
||||||
|
wandb.log_artifact(artifact, aliases=_aliases_from_metadata(metadata))
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def log_checkpoint_file(
|
||||||
|
artifact_name: str,
|
||||||
|
*,
|
||||||
|
file_path: str | Path,
|
||||||
|
artifact_file_name: str,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> bool:
|
||||||
|
if not HAS_WANDB or wandb.run is None:
|
||||||
|
return False
|
||||||
|
src = Path(file_path)
|
||||||
|
if not src.exists():
|
||||||
|
return False
|
||||||
|
artifact = wandb.Artifact(
|
||||||
|
name=artifact_name,
|
||||||
|
type="checkpoint",
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
artifact.add_file(src.as_posix(), name=artifact_file_name)
|
||||||
|
wandb.log_artifact(artifact, aliases=_aliases_from_metadata(metadata))
|
||||||
|
return True
|
||||||
269
experiments/airflow/dags/session_pricing_pipeline.py
Normal file
269
experiments/airflow/dags/session_pricing_pipeline.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
"""
|
||||||
|
Session-Aware Pricing DAG
|
||||||
|
THIS implements the core pricing computation (policy layer).
|
||||||
|
|
||||||
|
Flow: τ → θ̂ → D → p*
|
||||||
|
1. Fetch recent sessions from Kafka (last 10 active)
|
||||||
|
2. Extract features per session (τ → θ̂)
|
||||||
|
3. Map features to demand proxy (θ̂ → D)
|
||||||
|
4. Compute optimal prices (D → p*)
|
||||||
|
5. Write to Redis session:{sessionId}:prices
|
||||||
|
|
||||||
|
Scheduled: every 1 minute when enabled
|
||||||
|
"""
|
||||||
|
from airflow import DAG
|
||||||
|
from airflow.operators.python import PythonOperator
|
||||||
|
from airflow.utils.dates import days_ago
|
||||||
|
from datetime import timedelta
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
sys.path.insert(0, '/opt/airflow')
|
||||||
|
|
||||||
|
from procesing.context import PipelineContext
|
||||||
|
from procesing.providers import SupabaseProvider, BackendAPIProvider
|
||||||
|
from procesing.steps.session import ExtractSessionFeaturesStep
|
||||||
|
from procesing.pricers.simple import SimpleSurgePricer, session_features_to_demand
|
||||||
|
from procesing.pricing import StateSpace
|
||||||
|
from lib.model_registry import ModelRegistry
|
||||||
|
|
||||||
|
DEFAULT_ARGS = {
|
||||||
|
'owner': 'phantom-research',
|
||||||
|
'depends_on_past': False,
|
||||||
|
'email_on_failure': False,
|
||||||
|
'email_on_retry': False,
|
||||||
|
'retries': 1,
|
||||||
|
'retry_delay': timedelta(seconds=30),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class CompositeProvider(SupabaseProvider, BackendAPIProvider):
|
||||||
|
def __init__(self):
|
||||||
|
SupabaseProvider.__init__(self)
|
||||||
|
BackendAPIProvider.__init__(self)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_context(store_mode: str = 'hotel') -> PipelineContext:
|
||||||
|
return PipelineContext(provider=CompositeProvider(), store_mode=store_mode)
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_recent_sessions(**kwargs):
|
||||||
|
"""
|
||||||
|
Task: Fetch last N active sessions from Kafka.
|
||||||
|
Returns: DataFrame of interaction events for recent sessions.
|
||||||
|
"""
|
||||||
|
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
|
||||||
|
store_mode = dag_conf.get('store_mode', 'hotel')
|
||||||
|
session_limit = dag_conf.get('session_limit', 10)
|
||||||
|
|
||||||
|
ctx = _get_context(store_mode)
|
||||||
|
provider = ctx.provider
|
||||||
|
|
||||||
|
# fetch all recent interactions from Kafka
|
||||||
|
try:
|
||||||
|
interactions_df = provider.fetch_kafka_topic("user-interactions")
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Failed to fetch interactions: {e}")
|
||||||
|
kwargs['ti'].xcom_push(key='sessions_data', value=pickle.dumps(pd.DataFrame()))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if interactions_df.empty or 'sessionId' not in interactions_df.columns:
|
||||||
|
kwargs['ti'].xcom_push(key='sessions_data', value=pickle.dumps(pd.DataFrame()))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
# identify last N active sessions (most recent by event count)
|
||||||
|
recent_sessions = interactions_df['sessionId'].value_counts().head(session_limit).index.tolist()
|
||||||
|
|
||||||
|
# filter to only those sessions
|
||||||
|
filtered_df = interactions_df[interactions_df['sessionId'].isin(recent_sessions)].copy()
|
||||||
|
|
||||||
|
kwargs['ti'].xcom_push(key='sessions_data', value=pickle.dumps(filtered_df))
|
||||||
|
kwargs['ti'].xcom_push(key='session_ids', value=recent_sessions)
|
||||||
|
|
||||||
|
logging.info(f"Fetched {len(filtered_df)} events for {len(recent_sessions)} sessions")
|
||||||
|
return len(recent_sessions)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_session_features(**kwargs):
|
||||||
|
"""
|
||||||
|
Task: Extract behavioral features from session trajectories.
|
||||||
|
THIS implements τ → θ̂ transformation.
|
||||||
|
"""
|
||||||
|
ti = kwargs['ti']
|
||||||
|
sessions_df = pickle.loads(ti.xcom_pull(key='sessions_data'))
|
||||||
|
|
||||||
|
if sessions_df.empty:
|
||||||
|
ti.xcom_push(key='session_features', value=pickle.dumps(pd.DataFrame()))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
|
||||||
|
ctx = _get_context(dag_conf.get('store_mode', 'hotel'))
|
||||||
|
|
||||||
|
# extract features using vectorized pipeline
|
||||||
|
feature_extractor = ExtractSessionFeaturesStep(ctx)
|
||||||
|
features_df = feature_extractor.transform(sessions_df)
|
||||||
|
|
||||||
|
ti.xcom_push(key='session_features', value=pickle.dumps(features_df))
|
||||||
|
|
||||||
|
logging.info(f"Extracted {len(features_df.columns)} features for {len(features_df)} sessions")
|
||||||
|
logging.info(f"Feature columns: {list(features_df.columns)}")
|
||||||
|
logging.info(f"Sample features (first session):\n{features_df.iloc[0].to_dict()}")
|
||||||
|
|
||||||
|
return len(features_df)
|
||||||
|
|
||||||
|
|
||||||
|
def compute_session_prices(**kwargs):
|
||||||
|
"""
|
||||||
|
Task: Compute optimal prices for each session.
|
||||||
|
THIS implements θ̂ → D → p* transformation.
|
||||||
|
"""
|
||||||
|
ti = kwargs['ti']
|
||||||
|
features_df = pickle.loads(ti.xcom_pull(key='session_features'))
|
||||||
|
|
||||||
|
if features_df.empty:
|
||||||
|
ti.xcom_push(key='price_results', value=pickle.dumps({}))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
dag_conf = kwargs.get('dag_run').conf if kwargs.get('dag_run') else {}
|
||||||
|
store_mode = dag_conf.get('store_mode', 'hotel')
|
||||||
|
ctx = _get_context(store_mode)
|
||||||
|
|
||||||
|
# fetch product catalog for base prices
|
||||||
|
products_df = ctx.provider.fetch_products(store_mode)
|
||||||
|
if products_df.empty:
|
||||||
|
logging.error("No products found in catalog")
|
||||||
|
ti.xcom_push(key='price_results', value=pickle.dumps({}))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
products_df['base_price'] = products_df['metadata'].apply(
|
||||||
|
lambda m: m.get('base_price', 100.0) if isinstance(m, dict) else 100.0
|
||||||
|
)
|
||||||
|
|
||||||
|
# initialize pricing model
|
||||||
|
pricer = SimpleSurgePricer(
|
||||||
|
high_threshold=dag_conf.get('high_threshold', 10),
|
||||||
|
low_threshold=dag_conf.get('low_threshold', 2),
|
||||||
|
surge_multiplier=dag_conf.get('surge_multiplier', 1.15),
|
||||||
|
discount_multiplier=dag_conf.get('discount_multiplier', 0.95)
|
||||||
|
)
|
||||||
|
pricer.fit(products_df)
|
||||||
|
|
||||||
|
# compute prices per session
|
||||||
|
price_results = {}
|
||||||
|
n_products = len(products_df)
|
||||||
|
|
||||||
|
logging.info(f"Starting price computation for {len(features_df)} sessions, {n_products} products")
|
||||||
|
logging.info(f"Pricer config: high_thresh={pricer.high_threshold}, low_thresh={pricer.low_threshold}, surge_mult={pricer.surge_multiplier}")
|
||||||
|
|
||||||
|
for idx, session_row in features_df.iterrows():
|
||||||
|
session_id = session_row.get('sessionId')
|
||||||
|
if not session_id:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# map features to demand proxy (θ̂ → D)
|
||||||
|
session_features_single = pd.DataFrame([session_row])
|
||||||
|
demand_proxy = session_features_to_demand(session_features_single)
|
||||||
|
|
||||||
|
logging.info(f"[Session {session_id}] Features → Demand: {demand_proxy:.2f}")
|
||||||
|
logging.info(f"[Session {session_id}] Key features: velocity={session_row.get('interaction_velocity', 0):.2f}, cart_ratio={session_row.get('cart_to_view_ratio', 0):.2f}, item_views={session_row.get('item_views', 0)}")
|
||||||
|
|
||||||
|
# build state space
|
||||||
|
state_space = StateSpace(
|
||||||
|
demand=np.full(n_products, demand_proxy), # broadcast session demand to all products
|
||||||
|
prices=products_df['base_price'].values,
|
||||||
|
session_features=session_features_single
|
||||||
|
)
|
||||||
|
|
||||||
|
# compute optimal prices (D → p*)
|
||||||
|
optimal_prices = pricer.predict(state_space)
|
||||||
|
|
||||||
|
base_avg = products_df['base_price'].mean()
|
||||||
|
optimal_avg = optimal_prices.mean()
|
||||||
|
price_change_pct = ((optimal_avg - base_avg) / base_avg) * 100
|
||||||
|
|
||||||
|
logging.info(f"[Session {session_id}] Price adjustment: base_avg={base_avg:.2f}, optimal_avg={optimal_avg:.2f}, change={price_change_pct:+.1f}%")
|
||||||
|
|
||||||
|
# store as dict {productId: price}
|
||||||
|
price_map = {
|
||||||
|
str(products_df.iloc[i]['id']): float(optimal_prices[i])
|
||||||
|
for i in range(n_products)
|
||||||
|
}
|
||||||
|
|
||||||
|
price_results[session_id] = price_map
|
||||||
|
|
||||||
|
ti.xcom_push(key='price_results', value=pickle.dumps(price_results))
|
||||||
|
|
||||||
|
logging.info(f"Computed prices for {len(price_results)} sessions, {n_products} products each")
|
||||||
|
return len(price_results)
|
||||||
|
|
||||||
|
|
||||||
|
def publish_to_registry(**kwargs):
|
||||||
|
"""
|
||||||
|
Task: Write session prices to Redis registry.
|
||||||
|
THIS is the write path: prices → session:{sessionId}:prices
|
||||||
|
"""
|
||||||
|
ti = kwargs['ti']
|
||||||
|
price_results = pickle.loads(ti.xcom_pull(key='price_results'))
|
||||||
|
|
||||||
|
if not price_results:
|
||||||
|
logging.warning("No prices to publish")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
registry = ModelRegistry()
|
||||||
|
ttl = kwargs.get('dag_run').conf.get('ttl', 1800) if kwargs.get('dag_run') and kwargs.get('dag_run').conf else 1800
|
||||||
|
|
||||||
|
published_count = 0
|
||||||
|
for session_id, price_map in price_results.items():
|
||||||
|
registry.set_session_prices(session_id, price_map, ttl=ttl)
|
||||||
|
published_count += 1
|
||||||
|
|
||||||
|
logging.info(f"Published prices for {published_count} sessions to registry (TTL={ttl}s)")
|
||||||
|
|
||||||
|
return {
|
||||||
|
'sessions_published': published_count,
|
||||||
|
'products_per_session': len(next(iter(price_results.values()))) if price_results else 0,
|
||||||
|
'status': 'success'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# DAG definition
|
||||||
|
with DAG(
|
||||||
|
'session_pricing_pipeline',
|
||||||
|
default_args=DEFAULT_ARGS,
|
||||||
|
description='Session-aware pricing: extract features → compute prices → publish to registry',
|
||||||
|
schedule_interval='*/1 * * * *', # every 1 minute
|
||||||
|
start_date=days_ago(1),
|
||||||
|
catchup=False,
|
||||||
|
max_active_runs=1,
|
||||||
|
tags=['pricing', 'session-aware', 'research', 'real-time'],
|
||||||
|
) as dag:
|
||||||
|
|
||||||
|
t_fetch_sessions = PythonOperator(
|
||||||
|
task_id='fetch_recent_sessions',
|
||||||
|
python_callable=fetch_recent_sessions,
|
||||||
|
provide_context=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
t_extract_features = PythonOperator(
|
||||||
|
task_id='extract_session_features',
|
||||||
|
python_callable=extract_session_features,
|
||||||
|
provide_context=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
t_compute_prices = PythonOperator(
|
||||||
|
task_id='compute_session_prices',
|
||||||
|
python_callable=compute_session_prices,
|
||||||
|
provide_context=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
t_publish = PythonOperator(
|
||||||
|
task_id='publish_to_registry',
|
||||||
|
python_callable=publish_to_registry,
|
||||||
|
provide_context=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# linear dependency: fetch → extract → compute → publish
|
||||||
|
t_fetch_sessions >> t_extract_features >> t_compute_prices >> t_publish
|
||||||
1
experiments/ml/encoder/__init__.py
Normal file
1
experiments/ml/encoder/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
from .encoder import Window, extract_windows, build_windows, WindowDataset, PrototypeClassifier, train, loocv
|
||||||
210
experiments/ml/encoder/encoder.py
Normal file
210
experiments/ml/encoder/encoder.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Contrastive encoder via trajectory windowing. Classification by prototype distance."""
|
||||||
|
import sys
|
||||||
|
sys.path.insert(0, "/home/velocitatem/Documents/Projects/PHANTOM/sim/rl/behavior_loader")
|
||||||
|
sys.path.insert(0, "/home/velocitatem/Documents/Projects/PHANTOM/experiments/ml")
|
||||||
|
|
||||||
|
from sim.rl.behavior_loader.loader import JointLoader, PayloadModel
|
||||||
|
from arch import TrajectoryEncoder, featurize_trajectory, nt_xent_loss
|
||||||
|
from typing import List, Dict, Tuple
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
import numpy as np, torch, torch.nn.functional as F, random, optuna
|
||||||
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from torch.optim import Adam
|
||||||
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
RUNS = "/home/velocitatem/Documents/Projects/PHANTOM/experiments/ml/runs"
|
||||||
|
AGENT_DIR = "/home/velocitatem/Documents/Projects/PHANTOM/experiments/agents/collected_data/"
|
||||||
|
HUMAN_DIR = "/home/velocitatem/Documents/Projects/PHANTOM/experiments/collected_data/"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Window:
|
||||||
|
events: List[PayloadModel]
|
||||||
|
traj_id: str
|
||||||
|
label: int # 0=human, 1=agent
|
||||||
|
|
||||||
|
|
||||||
|
def extract_windows(events: List[PayloadModel], traj_id: str, label: int,
|
||||||
|
sizes: List[int] = [5, 10, 15], stride: int = 2) -> List[Window]:
|
||||||
|
"""Multi-scale overlapping windows from trajectory"""
|
||||||
|
n = len(events)
|
||||||
|
wins = [Window(events[i:i+s], traj_id, label) for s in sizes if n >= s for i in range(0, n-s+1, stride)]
|
||||||
|
if n >= 3: wins.append(Window(events, traj_id, label)) # full traj
|
||||||
|
return wins
|
||||||
|
|
||||||
|
|
||||||
|
def build_windows(data: Dict[str, List], sizes=[5,10,15], stride=2) -> List[Window]:
|
||||||
|
return [w for tid, evts in data.items()
|
||||||
|
for w in extract_windows(evts, tid, 0 if tid.startswith('human_') else 1, sizes, stride)]
|
||||||
|
|
||||||
|
|
||||||
|
class WindowDataset(Dataset):
|
||||||
|
"""Yields (anchor, positive) pairs from same class"""
|
||||||
|
def __init__(self, windows: List[Window], dim: int = 64):
|
||||||
|
self.wins, self.dim = windows, dim
|
||||||
|
self.by_label = {0: [i for i,w in enumerate(windows) if w.label==0],
|
||||||
|
1: [i for i,w in enumerate(windows) if w.label==1]}
|
||||||
|
self.by_traj = {}
|
||||||
|
for i, w in enumerate(windows): self.by_traj.setdefault(w.traj_id, []).append(i)
|
||||||
|
|
||||||
|
def __len__(self): return len(self.wins)
|
||||||
|
|
||||||
|
def _feat(self, evts): return featurize_trajectory(evts, None, self.dim)
|
||||||
|
|
||||||
|
def _aug(self, evts): # subsample 70-100%
|
||||||
|
if len(evts) < 4: return evts
|
||||||
|
k = max(3, int(len(evts) * random.uniform(0.7, 1.0)))
|
||||||
|
start = random.randint(0, len(evts) - k)
|
||||||
|
return evts[start:start+k]
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
w = self.wins[idx]
|
||||||
|
pool = [i for i in self.by_label[w.label] if self.wins[i].traj_id != w.traj_id]
|
||||||
|
pos_idx = random.choice(pool) if pool else idx
|
||||||
|
a = torch.tensor(self._feat(self._aug(w.events)), dtype=torch.float32)
|
||||||
|
p = torch.tensor(self._feat(self._aug(self.wins[pos_idx].events)), dtype=torch.float32)
|
||||||
|
return a, p, w.label
|
||||||
|
|
||||||
|
|
||||||
|
class PrototypeClassifier:
|
||||||
|
"""Classify by distance to class centroids"""
|
||||||
|
def __init__(self, encoder: TrajectoryEncoder, device = 'cuda', dim=64):
|
||||||
|
self.enc, self.dev, self.dim = encoder, device, dim
|
||||||
|
self.centroids = {0: None, 1: None}
|
||||||
|
|
||||||
|
def fit(self, windows: List[Window]):
|
||||||
|
self.enc.eval()
|
||||||
|
embs = {0: [], 1: []}
|
||||||
|
with torch.no_grad():
|
||||||
|
for w in windows:
|
||||||
|
x = torch.tensor(featurize_trajectory(w.events, None, self.dim), dtype=torch.float32)
|
||||||
|
z = self.enc(x.unsqueeze(0).unsqueeze(1).to(self.dev))
|
||||||
|
embs[w.label].append(z)
|
||||||
|
self.centroids = {k: torch.cat(v).mean(0, keepdim=True) if v else None for k, v in embs.items()}
|
||||||
|
return self
|
||||||
|
|
||||||
|
def predict(self, events: List[PayloadModel]) -> Tuple[int, float, Dict]:
|
||||||
|
"""Returns (pred, confidence, debug). Confidence via softmax over -distances."""
|
||||||
|
self.enc.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
x = torch.tensor(featurize_trajectory(events, None, self.dim), dtype=torch.float32)
|
||||||
|
z = self.enc(x.unsqueeze(0).unsqueeze(1).to(self.dev))
|
||||||
|
dists = {k: torch.norm(z - c, dim=1).item() for k, c in self.centroids.items() if c is not None}
|
||||||
|
if not dists: return 0, 0.0, {'d': {}, 'p': [0.5, 0.5]}
|
||||||
|
pred = min(dists, key=dists.get)
|
||||||
|
d0, d1 = dists.get(0, 1e6), dists.get(1, 1e6) # softmax(-d) gives higher prob to closer centroid
|
||||||
|
probs = F.softmax(torch.tensor([[-d0, -d1]]), dim=1).squeeze()
|
||||||
|
return pred, probs[pred].item(), {'d': dists, 'p': probs.tolist()}
|
||||||
|
|
||||||
|
|
||||||
|
def train(epochs=200, lr=5e-4, batch=16, dim=64, emb=32, temp=0.5,
|
||||||
|
sizes=[5,10,15], stride=2, name=None, verbose=True):
|
||||||
|
data = JointLoader(HUMAN_DIR, AGENT_DIR).get_data()
|
||||||
|
wins = build_windows(data, sizes, stride)
|
||||||
|
if verbose: print(f"Windows: {len(wins)} ({sum(w.label==0 for w in wins)}h/{sum(w.label==1 for w in wins)}a)")
|
||||||
|
|
||||||
|
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
enc = TrajectoryEncoder(dim, emb).to(dev)
|
||||||
|
opt = Adam(enc.parameters(), lr=lr)
|
||||||
|
loader = DataLoader(WindowDataset(wins, dim), batch_size=batch, shuffle=True, drop_last=True)
|
||||||
|
|
||||||
|
name = name or f"enc_{dim}_{emb}_{datetime.now():%Y%m%d_%H%M%S}"
|
||||||
|
writer = SummaryWriter(f"{RUNS}/encoder/{name}")
|
||||||
|
|
||||||
|
for ep in range(epochs):
|
||||||
|
enc.train()
|
||||||
|
total, n = 0.0, 0
|
||||||
|
for a, p, _ in loader:
|
||||||
|
loss = nt_xent_loss(enc(a.unsqueeze(1).to(dev)), enc(p.unsqueeze(1).to(dev)), temp)
|
||||||
|
opt.zero_grad(); loss.backward(); opt.step()
|
||||||
|
total += loss.item(); n += 1
|
||||||
|
avg = total / max(n, 1)
|
||||||
|
writer.add_scalar('loss-ntxent', avg, ep)
|
||||||
|
if verbose and (ep+1) % 20 == 0: print(f"Epoch {ep+1}: {avg:.4f}")
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
return enc, wins, dev
|
||||||
|
|
||||||
|
|
||||||
|
def loocv(epochs=100, lr=5e-4, dim=64, emb=32, temp=0.5, sizes=[5,10,15], stride=2, verbose=True):
|
||||||
|
"""Leave-one-trajectory-out CV"""
|
||||||
|
data = JointLoader(HUMAN_DIR, AGENT_DIR).get_data()
|
||||||
|
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for test_id in data:
|
||||||
|
train_data = {k: v for k, v in data.items() if k != test_id}
|
||||||
|
if not any(k.startswith('human_') for k in train_data) or not any(k.startswith('agent_') for k in train_data):
|
||||||
|
continue
|
||||||
|
|
||||||
|
wins = build_windows(train_data, sizes, stride)
|
||||||
|
enc = TrajectoryEncoder(dim, emb).to(dev)
|
||||||
|
opt = Adam(enc.parameters(), lr=lr)
|
||||||
|
loader = DataLoader(WindowDataset(wins, dim), batch_size=min(16, len(wins)//2 or 1),
|
||||||
|
shuffle=True, drop_last=len(wins)>2)
|
||||||
|
|
||||||
|
for _ in range(epochs):
|
||||||
|
enc.train()
|
||||||
|
for a, p, _ in loader:
|
||||||
|
loss = nt_xent_loss(enc(a.unsqueeze(1).to(dev)), enc(p.unsqueeze(1).to(dev)), temp)
|
||||||
|
opt.zero_grad(); loss.backward(); opt.step()
|
||||||
|
|
||||||
|
clf = PrototypeClassifier(enc, dev, dim).fit(wins)
|
||||||
|
pred, conf, dbg = clf.predict(data[test_id])
|
||||||
|
actual = 0 if test_id.startswith('human_') else 1
|
||||||
|
results.append((pred, actual, conf))
|
||||||
|
if verbose: print(f"{test_id[:18]}: pred={pred} conf={conf:.2f} actual={actual} {'OK' if pred==actual else 'MISS'}")
|
||||||
|
|
||||||
|
if results:
|
||||||
|
acc = sum(p==a for p,a,_ in results) / len(results)
|
||||||
|
if verbose: print(f"\nAccuracy: {acc:.1%} ({sum(p==a for p,a,_ in results)}/{len(results)})")
|
||||||
|
return acc, results
|
||||||
|
return 0.0, []
|
||||||
|
|
||||||
|
|
||||||
|
def hparam_tune(n_trials=50, epochs=60, n_jobs=2, verbose=True):
|
||||||
|
"""Optuna hyperparameter search maximizing LOOCV accuracy"""
|
||||||
|
def objective(trial):
|
||||||
|
lr = trial.suggest_float('lr', 1e-5, 1e-2, log=True)
|
||||||
|
dim = trial.suggest_categorical('dim', [32, 64, 128, 256])
|
||||||
|
emb = trial.suggest_categorical('emb', [16, 32, 64, 128])
|
||||||
|
temp = trial.suggest_float('temp', 0.05, 1.0)
|
||||||
|
stride = trial.suggest_int('stride', 1, 4)
|
||||||
|
sizes = [trial.suggest_int(f's{i}', 3, 20) for i in range(3)]
|
||||||
|
sizes = sorted(set(sizes)) # unique sorted
|
||||||
|
acc, _ = loocv(epochs, lr, dim, emb, temp, sizes, stride, verbose=False)
|
||||||
|
return acc
|
||||||
|
|
||||||
|
study = optuna.create_study(direction='maximize', study_name='encoder_hparam',
|
||||||
|
sampler=optuna.samplers.TPESampler(seed=42))
|
||||||
|
study.optimize(objective, n_trials=n_trials, n_jobs=n_jobs, show_progress_bar=verbose)
|
||||||
|
|
||||||
|
best = study.best_params
|
||||||
|
if verbose:
|
||||||
|
print(f"\nBest accuracy: {study.best_value:.1%}")
|
||||||
|
print(f"Best params: {best}")
|
||||||
|
return best, study
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import argparse
|
||||||
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument('--mode', choices=['train', 'eval', 'hparam'], default='train')
|
||||||
|
p.add_argument('--epochs', type=int, default=200)
|
||||||
|
p.add_argument('--lr', type=float, default=5e-4)
|
||||||
|
p.add_argument('--dim', type=int, default=128)
|
||||||
|
p.add_argument('--emb', type=int, default=64)
|
||||||
|
p.add_argument('--temp', type=float, default=0.1)
|
||||||
|
p.add_argument('--sizes', type=str, default='5,10,15')
|
||||||
|
p.add_argument('--stride', type=int, default=2)
|
||||||
|
p.add_argument('--n_trials', type=int, default=50)
|
||||||
|
args = p.parse_args()
|
||||||
|
sizes = [int(x) for x in args.sizes.split(',')]
|
||||||
|
|
||||||
|
if args.mode == 'train':
|
||||||
|
enc, wins, dev = train(args.epochs, args.lr, 16, args.dim, args.emb, args.temp, sizes, args.stride)
|
||||||
|
elif args.mode == 'hparam':
|
||||||
|
best, study = hparam_tune(args.n_trials, min(args.epochs, 60))
|
||||||
|
else:
|
||||||
|
loocv(args.epochs, args.lr, args.dim, args.emb, args.temp, sizes, args.stride)
|
||||||
957
experiments/notebooks/data_export.ipynb
Normal file
957
experiments/notebooks/data_export.ipynb
Normal file
@@ -0,0 +1,957 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"id": "62eafcd9-5462-4063-8873-0e7fb9add907",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"True"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 10,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from kafka import KafkaConsumer\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import json\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import os\n",
|
||||||
|
"from dotenv import load_dotenv\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"from IPython.display import display, SVG, Image\n",
|
||||||
|
"load_dotenv()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"id": "4af65cb4-e8cf-4877-b2db-13ac19f3838f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"<class 'pandas.core.frame.DataFrame'>\n",
|
||||||
|
"RangeIndex: 73 entries, 0 to 72\n",
|
||||||
|
"Data columns (total 13 columns):\n",
|
||||||
|
" # Column Non-Null Count Dtype \n",
|
||||||
|
"--- ------ -------------- ----- \n",
|
||||||
|
" 0 sessionId 73 non-null object \n",
|
||||||
|
" 1 eventName 73 non-null object \n",
|
||||||
|
" 2 page 73 non-null object \n",
|
||||||
|
" 3 productId 67 non-null object \n",
|
||||||
|
" 4 storeMode 73 non-null object \n",
|
||||||
|
" 5 userAgent 73 non-null object \n",
|
||||||
|
" 6 ts 73 non-null object \n",
|
||||||
|
" 7 metadata_referrer 6 non-null object \n",
|
||||||
|
" 8 metadata_roomType 45 non-null object \n",
|
||||||
|
" 9 metadata_price 45 non-null float64\n",
|
||||||
|
" 10 metadata_nights 45 non-null float64\n",
|
||||||
|
" 11 metadata_elementText 22 non-null object \n",
|
||||||
|
" 12 metadata_dwellTime 22 non-null float64\n",
|
||||||
|
"dtypes: float64(3), object(10)\n",
|
||||||
|
"memory usage: 7.5+ KB\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"KAFKA_PORT=os.getenv(\"KAFKA_PORT\", 9092)\n",
|
||||||
|
"topic = \"user-interactions\"\n",
|
||||||
|
"consumer = KafkaConsumer(\n",
|
||||||
|
" topic, \n",
|
||||||
|
" enable_auto_commit=True,\n",
|
||||||
|
" value_deserializer=lambda x: json.loads(x.decode('utf-8')),\n",
|
||||||
|
" auto_offset_reset='earliest', \n",
|
||||||
|
" bootstrap_servers=['localhost:9092'])\n",
|
||||||
|
"messages=consumer.poll(timeout_ms=1000,max_records=10000)\n",
|
||||||
|
"df = []\n",
|
||||||
|
"for m in messages.values():\n",
|
||||||
|
" for i in m:\n",
|
||||||
|
" df.append(i.value)\n",
|
||||||
|
"df = pd.DataFrame(df)\n",
|
||||||
|
"# explode metadata col json\n",
|
||||||
|
"df = df.join(pd.json_normalize(df.pop(\"metadata\"), sep=\".\").add_prefix(\"metadata_\"))\n",
|
||||||
|
"df.info()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"id": "f6819a1c-32ab-49c7-845b-5df7bf60f561",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/html": [
|
||||||
|
"<div>\n",
|
||||||
|
"<style scoped>\n",
|
||||||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||||||
|
" vertical-align: middle;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe tbody tr th {\n",
|
||||||
|
" vertical-align: top;\n",
|
||||||
|
" }\n",
|
||||||
|
"\n",
|
||||||
|
" .dataframe thead th {\n",
|
||||||
|
" text-align: right;\n",
|
||||||
|
" }\n",
|
||||||
|
"</style>\n",
|
||||||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||||||
|
" <thead>\n",
|
||||||
|
" <tr style=\"text-align: right;\">\n",
|
||||||
|
" <th></th>\n",
|
||||||
|
" <th>sessionId</th>\n",
|
||||||
|
" <th>eventName</th>\n",
|
||||||
|
" <th>page</th>\n",
|
||||||
|
" <th>productId</th>\n",
|
||||||
|
" <th>storeMode</th>\n",
|
||||||
|
" <th>userAgent</th>\n",
|
||||||
|
" <th>ts</th>\n",
|
||||||
|
" <th>metadata_referrer</th>\n",
|
||||||
|
" <th>metadata_roomType</th>\n",
|
||||||
|
" <th>metadata_price</th>\n",
|
||||||
|
" <th>metadata_nights</th>\n",
|
||||||
|
" <th>metadata_elementText</th>\n",
|
||||||
|
" <th>metadata_dwellTime</th>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </thead>\n",
|
||||||
|
" <tbody>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>0</th>\n",
|
||||||
|
" <td>d176d7c9-4027-4702-9e31-2a71395cdda0</td>\n",
|
||||||
|
" <td>page_view</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>None</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
|
||||||
|
" <td>2025-11-14T13:23:46.270Z</td>\n",
|
||||||
|
" <td></td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>1</th>\n",
|
||||||
|
" <td>f0317a5d-e424-44e9-b784-c8f7291ffe31</td>\n",
|
||||||
|
" <td>page_view</td>\n",
|
||||||
|
" <td>/</td>\n",
|
||||||
|
" <td>None</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck...</td>\n",
|
||||||
|
" <td>2025-11-14T13:26:00.291Z</td>\n",
|
||||||
|
" <td></td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>2</th>\n",
|
||||||
|
" <td>f0317a5d-e424-44e9-b784-c8f7291ffe31</td>\n",
|
||||||
|
" <td>page_view</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>None</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck...</td>\n",
|
||||||
|
" <td>2025-11-14T13:26:07.769Z</td>\n",
|
||||||
|
" <td></td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>3</th>\n",
|
||||||
|
" <td>f0317a5d-e424-44e9-b784-c8f7291ffe31</td>\n",
|
||||||
|
" <td>view_item_page</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>htl-0</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck...</td>\n",
|
||||||
|
" <td>2025-11-14T13:26:15.010Z</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Premium Room</td>\n",
|
||||||
|
" <td>269.0</td>\n",
|
||||||
|
" <td>1.0</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>4</th>\n",
|
||||||
|
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
|
||||||
|
" <td>page_view</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>None</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
|
||||||
|
" <td>2025-11-14T13:27:15.457Z</td>\n",
|
||||||
|
" <td></td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>5</th>\n",
|
||||||
|
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
|
||||||
|
" <td>view_item_page</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>htl-0</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
|
||||||
|
" <td>2025-11-14T13:27:15.591Z</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Premium Room</td>\n",
|
||||||
|
" <td>264.0</td>\n",
|
||||||
|
" <td>2.0</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>432</th>\n",
|
||||||
|
" <td>214d9fad-9b00-40c3-bd0e-7739b6acd654</td>\n",
|
||||||
|
" <td>click</td>\n",
|
||||||
|
" <td>1762448192425</td>\n",
|
||||||
|
" <td>DIV</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>/</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>1623.0</td>\n",
|
||||||
|
" <td>493.0</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>6</th>\n",
|
||||||
|
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
|
||||||
|
" <td>view_item_page</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>htl-0</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
|
||||||
|
" <td>2025-11-14T13:27:21.483Z</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Premium Room</td>\n",
|
||||||
|
" <td>264.0</td>\n",
|
||||||
|
" <td>2.0</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>7</th>\n",
|
||||||
|
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
|
||||||
|
" <td>hover_over_title</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>htl-0</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
|
||||||
|
" <td>2025-11-14T13:27:22.646Z</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Grand Plaza Hotel</td>\n",
|
||||||
|
" <td>1200.0</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>8</th>\n",
|
||||||
|
" <td>238dc588-a7ab-4c0e-bccd-6abca5076c66</td>\n",
|
||||||
|
" <td>view_item_page</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>htl-0</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7...</td>\n",
|
||||||
|
" <td>2025-11-14T13:27:25.889Z</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Premium Room</td>\n",
|
||||||
|
" <td>264.0</td>\n",
|
||||||
|
" <td>2.0</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>35</th>\n",
|
||||||
|
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
|
||||||
|
" <td>page_view</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>None</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
|
||||||
|
" <td>2025-11-14T13:53:59.993Z</td>\n",
|
||||||
|
" <td></td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>36</th>\n",
|
||||||
|
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
|
||||||
|
" <td>view_item_page</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>htl-0</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
|
||||||
|
" <td>2025-11-14T13:54:10.705Z</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Premium Room</td>\n",
|
||||||
|
" <td>223.0</td>\n",
|
||||||
|
" <td>3.0</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>37</th>\n",
|
||||||
|
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
|
||||||
|
" <td>hover_over_title</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>htl-0</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
|
||||||
|
" <td>2025-11-14T13:54:11.771Z</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>416.0</td>\n",
|
||||||
|
" <td>397.0</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Grand Plaza Hotel</td>\n",
|
||||||
|
" <td>1200.0</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>38</th>\n",
|
||||||
|
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
|
||||||
|
" <td>view_item_page</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>htl-1</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
|
||||||
|
" <td>2025-11-14T13:54:29.772Z</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Standard Room</td>\n",
|
||||||
|
" <td>267.0</td>\n",
|
||||||
|
" <td>5.0</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" <tr>\n",
|
||||||
|
" <th>39</th>\n",
|
||||||
|
" <td>013fc334-4045-4d5a-8739-dd0a8766a63b</td>\n",
|
||||||
|
" <td>hover_over_title</td>\n",
|
||||||
|
" <td>/products</td>\n",
|
||||||
|
" <td>htl-1</td>\n",
|
||||||
|
" <td>hotel</td>\n",
|
||||||
|
" <td>Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53...</td>\n",
|
||||||
|
" <td>2025-11-14T13:54:30.833Z</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>NaN</td>\n",
|
||||||
|
" <td>Seaside Resort</td>\n",
|
||||||
|
" <td>1200.0</td>\n",
|
||||||
|
" </tr>\n",
|
||||||
|
" </tbody>\n",
|
||||||
|
"</table>\n",
|
||||||
|
"</div>"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
" sessionId eventName page \\\n",
|
||||||
|
"0 d176d7c9-4027-4702-9e31-2a71395cdda0 page_view /products \n",
|
||||||
|
"1 f0317a5d-e424-44e9-b784-c8f7291ffe31 page_view / \n",
|
||||||
|
"2 f0317a5d-e424-44e9-b784-c8f7291ffe31 page_view /products \n",
|
||||||
|
"3 f0317a5d-e424-44e9-b784-c8f7291ffe31 view_item_page /products \n",
|
||||||
|
"4 238dc588-a7ab-4c0e-bccd-6abca5076c66 page_view /products \n",
|
||||||
|
"5 238dc588-a7ab-4c0e-bccd-6abca5076c66 view_item_page /products \n",
|
||||||
|
"6 238dc588-a7ab-4c0e-bccd-6abca5076c66 view_item_page /products \n",
|
||||||
|
"7 238dc588-a7ab-4c0e-bccd-6abca5076c66 hover_over_title /products \n",
|
||||||
|
"8 238dc588-a7ab-4c0e-bccd-6abca5076c66 view_item_page /products \n",
|
||||||
|
"35 013fc334-4045-4d5a-8739-dd0a8766a63b page_view /products \n",
|
||||||
|
"36 013fc334-4045-4d5a-8739-dd0a8766a63b view_item_page /products \n",
|
||||||
|
"37 013fc334-4045-4d5a-8739-dd0a8766a63b hover_over_title /products \n",
|
||||||
|
"38 013fc334-4045-4d5a-8739-dd0a8766a63b view_item_page /products \n",
|
||||||
|
"39 013fc334-4045-4d5a-8739-dd0a8766a63b hover_over_title /products \n",
|
||||||
|
"\n",
|
||||||
|
" productId storeMode userAgent \\\n",
|
||||||
|
"0 None hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
|
||||||
|
"1 None hotel Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck... \n",
|
||||||
|
"2 None hotel Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck... \n",
|
||||||
|
"3 htl-0 hotel Mozilla/5.0 (X11; Linux x86_64; rv:143.0) Geck... \n",
|
||||||
|
"4 None hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
|
||||||
|
"5 htl-0 hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
|
||||||
|
"6 htl-0 hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
|
||||||
|
"7 htl-0 hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
|
||||||
|
"8 htl-0 hotel Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7... \n",
|
||||||
|
"35 None hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
|
||||||
|
"36 htl-0 hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
|
||||||
|
"37 htl-0 hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
|
||||||
|
"38 htl-1 hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
|
||||||
|
"39 htl-1 hotel Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/53... \n",
|
||||||
|
"\n",
|
||||||
|
" ts metadata_referrer metadata_roomType \\\n",
|
||||||
|
"0 2025-11-14T13:23:46.270Z NaN \n",
|
||||||
|
"1 2025-11-14T13:26:00.291Z NaN \n",
|
||||||
|
"2 2025-11-14T13:26:07.769Z NaN \n",
|
||||||
|
"3 2025-11-14T13:26:15.010Z NaN Premium Room \n",
|
||||||
|
"4 2025-11-14T13:27:15.457Z NaN \n",
|
||||||
|
"5 2025-11-14T13:27:15.591Z NaN Premium Room \n",
|
||||||
|
"6 2025-11-14T13:27:21.483Z NaN Premium Room \n",
|
||||||
|
"7 2025-11-14T13:27:22.646Z NaN NaN \n",
|
||||||
|
"8 2025-11-14T13:27:25.889Z NaN Premium Room \n",
|
||||||
|
"35 2025-11-14T13:53:59.993Z NaN \n",
|
||||||
|
"36 2025-11-14T13:54:10.705Z NaN Premium Room \n",
|
||||||
|
"37 2025-11-14T13:54:11.771Z NaN NaN \n",
|
||||||
|
"38 2025-11-14T13:54:29.772Z NaN Standard Room \n",
|
||||||
|
"39 2025-11-14T13:54:30.833Z NaN NaN \n",
|
||||||
|
"\n",
|
||||||
|
" metadata_price metadata_nights metadata_elementText metadata_dwellTime \n",
|
||||||
|
"0 NaN NaN NaN NaN \n",
|
||||||
|
"1 NaN NaN NaN NaN \n",
|
||||||
|
"2 NaN NaN NaN NaN \n",
|
||||||
|
"3 269.0 1.0 NaN NaN \n",
|
||||||
|
"4 NaN NaN NaN NaN \n",
|
||||||
|
"5 264.0 2.0 NaN NaN \n",
|
||||||
|
"6 264.0 2.0 NaN NaN \n",
|
||||||
|
"7 NaN NaN Grand Plaza Hotel 1200.0 \n",
|
||||||
|
"8 264.0 2.0 NaN NaN \n",
|
||||||
|
"35 NaN NaN NaN NaN \n",
|
||||||
|
"36 223.0 3.0 NaN NaN \n",
|
||||||
|
"37 NaN NaN Grand Plaza Hotel 1200.0 \n",
|
||||||
|
"38 267.0 5.0 NaN NaN \n",
|
||||||
|
"39 NaN NaN Seaside Resort 1200.0 "
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 12,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"df.groupby('sessionId').head()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"id": "380eca5f-8304-4fb2-be32-e8bcfd312085",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": [
|
||||||
|
"['013fc334-4045-4d5a-8739-dd0a8766a63b',\n",
|
||||||
|
" '238dc588-a7ab-4c0e-bccd-6abca5076c66',\n",
|
||||||
|
" 'd176d7c9-4027-4702-9e31-2a71395cdda0',\n",
|
||||||
|
" 'f0317a5d-e424-44e9-b784-c8f7291ffe31']"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"execution_count": 13,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"sessions = list(set(df['sessionId'])); sessions # 238dc588-a7ab-4c0e-bccd-6abca5076c66"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"id": "f4ae6f81-dcb8-44be-aee7-30dbc3a6bae1",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# map sessions to experiments"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"id": "050d90a4-20a9-47f5-b998-c31178a54cb3",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def build_transition_prob_matrix(df: pd.DataFrame):\n",
|
||||||
|
" df = df.dropna(subset=['eventName'])\n",
|
||||||
|
" events = df['eventName'].tolist()\n",
|
||||||
|
" labels = pd.Index(events).unique().tolist()\n",
|
||||||
|
" idx = {e:i for i,e in enumerate(labels)}\n",
|
||||||
|
" M = np.zeros((len(labels), len(labels)), dtype=float)\n",
|
||||||
|
" for a, b in zip(events, events[1:]):\n",
|
||||||
|
" M[idx[a], idx[b]] += 1\n",
|
||||||
|
" row_sums = M.sum(axis=1, keepdims=True)\n",
|
||||||
|
" with np.errstate(divide='ignore', invalid='ignore'):\n",
|
||||||
|
" P = np.divide(M, row_sums, where=row_sums>0) # row-normalized\n",
|
||||||
|
" return P, labels"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"id": "e68f9004-82f5-4826-aece-e3dc6e15a18f",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# https://medium.com/data-science/time-series-data-markov-transition-matrices-7060771e362b\n",
|
||||||
|
"from graphviz import Digraph\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"\n",
|
||||||
|
"def _as_prob_df(matrix, labels=None):\n",
|
||||||
|
" \"\"\"Return a square DataFrame with index=columns=labels.\"\"\"\n",
|
||||||
|
" if isinstance(matrix, pd.DataFrame):\n",
|
||||||
|
" # Ensure square and aligned\n",
|
||||||
|
" assert (matrix.index == matrix.columns).all(), \"Index/columns must match.\"\n",
|
||||||
|
" return matrix\n",
|
||||||
|
" matrix = np.asarray(matrix, dtype=float)\n",
|
||||||
|
" assert matrix.shape[0] == matrix.shape[1], \"Matrix must be square.\"\n",
|
||||||
|
" if labels is None:\n",
|
||||||
|
" raise ValueError(\"labels are required when matrix is not a DataFrame\")\n",
|
||||||
|
" assert len(labels) == matrix.shape[0], \"labels length must match matrix size.\"\n",
|
||||||
|
" return pd.DataFrame(matrix, index=list(labels), columns=list(labels))\n",
|
||||||
|
"\n",
|
||||||
|
"def _df_to_edgelist(P: pd.DataFrame, threshold=0.0, round_digits=2):\n",
|
||||||
|
" \"\"\"Build weighted edges > threshold.\"\"\"\n",
|
||||||
|
" edges = []\n",
|
||||||
|
" for src in P.index:\n",
|
||||||
|
" for dst in P.columns:\n",
|
||||||
|
" w = float(P.loc[src, dst])\n",
|
||||||
|
" if w > threshold:\n",
|
||||||
|
" edges.append((str(src), str(dst), f\"{w:.{round_digits}f}\"))\n",
|
||||||
|
" return edges\n",
|
||||||
|
"\n",
|
||||||
|
"def render_graph(fname, matrix, ls_index=None, threshold=0.0, fmt=\"svg\", view=False):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" fname: output file stem (no extension)\n",
|
||||||
|
" matrix: NumPy array or pandas DataFrame of transition PROBABILITIES\n",
|
||||||
|
" ls_index: ordered labels (required if matrix is not a DataFrame)\n",
|
||||||
|
" threshold: hide edges with weight <= threshold\n",
|
||||||
|
" fmt: 'svg'|'png'|'pdf' etc.\n",
|
||||||
|
" view: open after rendering\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" P = _as_prob_df(matrix, labels=ls_index)\n",
|
||||||
|
" edges = _df_to_edgelist(P, threshold=threshold)\n",
|
||||||
|
"\n",
|
||||||
|
" g = Digraph(format=fmt)\n",
|
||||||
|
" g.attr(rankdir=\"LR\", size=\"30\")\n",
|
||||||
|
" g.attr(\"node\", shape=\"circle\")\n",
|
||||||
|
"\n",
|
||||||
|
" # ensure isolated nodes appear\n",
|
||||||
|
" for node in P.index:\n",
|
||||||
|
" g.node(str(node), width=\"1\", height=\"1\")\n",
|
||||||
|
"\n",
|
||||||
|
" for src, dst, label in edges:\n",
|
||||||
|
" g.edge(src, dst, label=label)\n",
|
||||||
|
"\n",
|
||||||
|
" g.render(fname, view=view, cleanup=True)\n",
|
||||||
|
" return g\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"id": "e255a2c1-6454-4e5e-89f6-ef8ac51ab6cc",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"013fc334-4045-4d5a-8739-dd0a8766a63b\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"image/svg+xml": [
|
||||||
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||||
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||||
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||||
|
"<!-- Generated by graphviz version 13.1.2 (0)\n",
|
||||||
|
" -->\n",
|
||||||
|
"<!-- Pages: 1 -->\n",
|
||||||
|
"<svg width=\"565pt\" height=\"354pt\"\n",
|
||||||
|
" viewBox=\"0.00 0.00 565.00 354.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||||
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 349.64)\">\n",
|
||||||
|
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-349.64 561.05,-349.64 561.05,4 -4,4\"/>\n",
|
||||||
|
"<!-- page_view -->\n",
|
||||||
|
"<g id=\"node1\" class=\"node\">\n",
|
||||||
|
"<title>page_view</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"48.19\" cy=\"-235.83\" rx=\"48.19\" ry=\"48.19\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-231.16\" font-family=\"Times,serif\" font-size=\"14.00\">page_view</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- view_item_page -->\n",
|
||||||
|
"<g id=\"node2\" class=\"node\">\n",
|
||||||
|
"<title>view_item_page</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"232.88\" cy=\"-235.83\" rx=\"69.01\" ry=\"69.01\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-231.16\" font-family=\"Times,serif\" font-size=\"14.00\">view_item_page</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- page_view->view_item_page -->\n",
|
||||||
|
"<g id=\"edge1\" class=\"edge\">\n",
|
||||||
|
"<title>page_view->view_item_page</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M96.71,-235.83C113.69,-235.83 133.31,-235.83 152.25,-235.83\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"152.1,-239.33 162.1,-235.83 152.1,-232.33 152.1,-239.33\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"130.12\" y=\"-239.78\" font-family=\"Times,serif\" font-size=\"14.00\">1.00</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- view_item_page->view_item_page -->\n",
|
||||||
|
"<g id=\"edge2\" class=\"edge\">\n",
|
||||||
|
"<title>view_item_page->view_item_page</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M214.74,-302.59C217.1,-314.51 223.14,-322.84 232.88,-322.84 239.27,-322.84 244.07,-319.26 247.28,-313.42\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"250.57,-314.62 250.52,-304.02 243.95,-312.33 250.57,-314.62\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-326.79\" font-family=\"Times,serif\" font-size=\"14.00\">0.68</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- hover_over_title -->\n",
|
||||||
|
"<g id=\"node3\" class=\"node\">\n",
|
||||||
|
"<title>hover_over_title</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"463.22\" cy=\"-275.83\" rx=\"69.81\" ry=\"69.81\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463.22\" y=\"-271.16\" font-family=\"Times,serif\" font-size=\"14.00\">hover_over_title</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- view_item_page->hover_over_title -->\n",
|
||||||
|
"<g id=\"edge3\" class=\"edge\">\n",
|
||||||
|
"<title>view_item_page->hover_over_title</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M300.48,-250.14C307.03,-251.43 313.58,-252.69 319.89,-253.83 340.12,-257.51 362.05,-261.1 382.5,-264.27\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"381.77,-267.7 392.19,-265.76 382.83,-260.78 381.77,-267.7\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-263.17\" font-family=\"Times,serif\" font-size=\"14.00\">0.29</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- hover_over_paragraph -->\n",
|
||||||
|
"<g id=\"node4\" class=\"node\">\n",
|
||||||
|
"<title>hover_over_paragraph</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"463.22\" cy=\"-93.83\" rx=\"93.83\" ry=\"93.83\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463.22\" y=\"-89.16\" font-family=\"Times,serif\" font-size=\"14.00\">hover_over_paragraph</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- view_item_page->hover_over_paragraph -->\n",
|
||||||
|
"<g id=\"edge4\" class=\"edge\">\n",
|
||||||
|
"<title>view_item_page->hover_over_paragraph</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M292.09,-199.63C316.79,-184.27 346.14,-166.02 373.44,-149.04\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"375.08,-152.15 381.72,-143.89 371.38,-146.2 375.08,-152.15\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-185.68\" font-family=\"Times,serif\" font-size=\"14.00\">0.04</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- hover_over_title->view_item_page -->\n",
|
||||||
|
"<g id=\"edge5\" class=\"edge\">\n",
|
||||||
|
"<title>hover_over_title->view_item_page</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M399.53,-246.73C384.12,-240.88 367.42,-235.6 351.39,-232.58 339.13,-230.28 326.03,-229.26 313.19,-229.04\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"313.51,-225.54 303.51,-229.04 313.51,-232.54 313.51,-225.54\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-236.53\" font-family=\"Times,serif\" font-size=\"14.00\">1.00</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"</svg>\n"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<graphviz.graphs.Digraph at 0x7f0779e818b0>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[]\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"image/svg+xml": [
|
||||||
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||||
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||||
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||||
|
"<!-- Generated by graphviz version 13.1.2 (0)\n",
|
||||||
|
" -->\n",
|
||||||
|
"<!-- Pages: 1 -->\n",
|
||||||
|
"<svg width=\"8pt\" height=\"8pt\"\n",
|
||||||
|
" viewBox=\"0.00 0.00 8.00 8.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||||
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 4)\">\n",
|
||||||
|
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-4 4,-4 4,4 -4,4\"/>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"</svg>\n"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<graphviz.graphs.Digraph at 0x7f6800fac980>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[[0.00000000e+000 1.00000000e+000 0.00000000e+000 0.00000000e+000]\n",
|
||||||
|
" [0.00000000e+000 6.78571429e-001 2.85714286e-001 3.57142857e-002]\n",
|
||||||
|
" [0.00000000e+000 1.00000000e+000 0.00000000e+000 0.00000000e+000]\n",
|
||||||
|
" [2.05833592e-312 2.29175545e-312 4.94065646e-324 6.92110218e-310]]\n",
|
||||||
|
"238dc588-a7ab-4c0e-bccd-6abca5076c66\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"image/svg+xml": [
|
||||||
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||||
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||||
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||||
|
"<!-- Generated by graphviz version 13.1.2 (0)\n",
|
||||||
|
" -->\n",
|
||||||
|
"<!-- Pages: 1 -->\n",
|
||||||
|
"<svg width=\"565pt\" height=\"354pt\"\n",
|
||||||
|
" viewBox=\"0.00 0.00 565.00 354.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||||
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 349.64)\">\n",
|
||||||
|
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-349.64 561.05,-349.64 561.05,4 -4,4\"/>\n",
|
||||||
|
"<!-- page_view -->\n",
|
||||||
|
"<g id=\"node1\" class=\"node\">\n",
|
||||||
|
"<title>page_view</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"48.19\" cy=\"-109.83\" rx=\"48.19\" ry=\"48.19\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-105.16\" font-family=\"Times,serif\" font-size=\"14.00\">page_view</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- view_item_page -->\n",
|
||||||
|
"<g id=\"node2\" class=\"node\">\n",
|
||||||
|
"<title>view_item_page</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"232.88\" cy=\"-197.83\" rx=\"69.01\" ry=\"69.01\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-193.16\" font-family=\"Times,serif\" font-size=\"14.00\">view_item_page</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- page_view->view_item_page -->\n",
|
||||||
|
"<g id=\"edge1\" class=\"edge\">\n",
|
||||||
|
"<title>page_view->view_item_page</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M92.02,-130.47C112.32,-140.25 137.13,-152.2 160.18,-163.3\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"158.39,-166.32 168.92,-167.51 161.43,-160.02 158.39,-166.32\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"130.12\" y=\"-157.78\" font-family=\"Times,serif\" font-size=\"14.00\">1.00</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- view_item_page->view_item_page -->\n",
|
||||||
|
"<g id=\"edge2\" class=\"edge\">\n",
|
||||||
|
"<title>view_item_page->view_item_page</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M214.74,-264.59C217.1,-276.51 223.14,-284.84 232.88,-284.84 239.27,-284.84 244.07,-281.26 247.28,-275.42\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"250.57,-276.62 250.52,-266.02 243.95,-274.33 250.57,-276.62\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-288.79\" font-family=\"Times,serif\" font-size=\"14.00\">0.19</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- hover_over_title -->\n",
|
||||||
|
"<g id=\"node3\" class=\"node\">\n",
|
||||||
|
"<title>hover_over_title</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"463.22\" cy=\"-275.83\" rx=\"69.81\" ry=\"69.81\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463.22\" y=\"-271.16\" font-family=\"Times,serif\" font-size=\"14.00\">hover_over_title</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- view_item_page->hover_over_title -->\n",
|
||||||
|
"<g id=\"edge3\" class=\"edge\">\n",
|
||||||
|
"<title>view_item_page->hover_over_title</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M289.6,-237.16C299.36,-242.77 309.67,-247.94 319.89,-251.83 339.45,-259.28 361.4,-264.43 382.1,-267.98\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"381.52,-271.43 391.95,-269.55 382.62,-264.52 381.52,-271.43\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-265.16\" font-family=\"Times,serif\" font-size=\"14.00\">0.38</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- hover_over_paragraph -->\n",
|
||||||
|
"<g id=\"node4\" class=\"node\">\n",
|
||||||
|
"<title>hover_over_paragraph</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"463.22\" cy=\"-93.83\" rx=\"93.83\" ry=\"93.83\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"463.22\" y=\"-89.16\" font-family=\"Times,serif\" font-size=\"14.00\">hover_over_paragraph</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- view_item_page->hover_over_paragraph -->\n",
|
||||||
|
"<g id=\"edge4\" class=\"edge\">\n",
|
||||||
|
"<title>view_item_page->hover_over_paragraph</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M300.22,-180.71C317.22,-175.46 335.24,-169.12 351.39,-161.83 358.97,-158.41 366.67,-154.57 374.29,-150.49\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"375.84,-153.63 382.92,-145.75 372.47,-147.5 375.84,-153.63\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-178.15\" font-family=\"Times,serif\" font-size=\"14.00\">0.44</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- hover_over_title->view_item_page -->\n",
|
||||||
|
"<g id=\"edge5\" class=\"edge\">\n",
|
||||||
|
"<title>hover_over_title->view_item_page</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M398.52,-248.36C383.21,-242.16 366.82,-235.87 351.39,-230.58 338.42,-226.15 324.5,-221.86 310.94,-217.93\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"312.2,-214.65 301.62,-215.28 310.28,-221.39 312.2,-214.65\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-234.53\" font-family=\"Times,serif\" font-size=\"14.00\">1.00</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- hover_over_paragraph->page_view -->\n",
|
||||||
|
"<g id=\"edge6\" class=\"edge\">\n",
|
||||||
|
"<title>hover_over_paragraph->page_view</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M369.13,-95.76C310.26,-97.17 232.59,-99.41 163.87,-102.58 145.72,-103.42 125.98,-104.58 108.06,-105.73\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"107.86,-102.24 98.1,-106.38 108.31,-109.22 107.86,-102.24\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-106.53\" font-family=\"Times,serif\" font-size=\"14.00\">0.14</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- hover_over_paragraph->view_item_page -->\n",
|
||||||
|
"<g id=\"edge7\" class=\"edge\">\n",
|
||||||
|
"<title>hover_over_paragraph->view_item_page</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M372.68,-119.15C354.84,-125.32 336.5,-132.51 319.89,-140.58 312.9,-143.98 305.81,-147.87 298.86,-151.98\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"297.49,-148.71 290.78,-156.91 301.14,-154.69 297.49,-148.71\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"335.64\" y=\"-144.53\" font-family=\"Times,serif\" font-size=\"14.00\">0.86</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"</svg>\n"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<graphviz.graphs.Digraph at 0x7f6800f97110>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[[0. 1. 0. 0. ]\n",
|
||||||
|
" [0. 0.1875 0.375 0.4375 ]\n",
|
||||||
|
" [0. 1. 0. 0. ]\n",
|
||||||
|
" [0.14285714 0.85714286 0. 0. ]]\n",
|
||||||
|
"d176d7c9-4027-4702-9e31-2a71395cdda0\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"image/svg+xml": [
|
||||||
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||||
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||||
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||||
|
"<!-- Generated by graphviz version 13.1.2 (0)\n",
|
||||||
|
" -->\n",
|
||||||
|
"<!-- Pages: 1 -->\n",
|
||||||
|
"<svg width=\"104pt\" height=\"104pt\"\n",
|
||||||
|
" viewBox=\"0.00 0.00 104.00 104.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||||
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 100.37)\">\n",
|
||||||
|
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-100.37 100.37,-100.37 100.37,4 -4,4\"/>\n",
|
||||||
|
"<!-- page_view -->\n",
|
||||||
|
"<g id=\"node1\" class=\"node\">\n",
|
||||||
|
"<title>page_view</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"48.19\" cy=\"-48.19\" rx=\"48.19\" ry=\"48.19\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-43.51\" font-family=\"Times,serif\" font-size=\"14.00\">page_view</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"</svg>\n"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<graphviz.graphs.Digraph at 0x7f6800f97110>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[[0.]]\n",
|
||||||
|
"f0317a5d-e424-44e9-b784-c8f7291ffe31\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"image/svg+xml": [
|
||||||
|
"<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n",
|
||||||
|
"<!DOCTYPE svg PUBLIC \"-//W3C//DTD SVG 1.1//EN\"\n",
|
||||||
|
" \"http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd\">\n",
|
||||||
|
"<!-- Generated by graphviz version 13.1.2 (0)\n",
|
||||||
|
" -->\n",
|
||||||
|
"<!-- Pages: 1 -->\n",
|
||||||
|
"<svg width=\"310pt\" height=\"160pt\"\n",
|
||||||
|
" viewBox=\"0.00 0.00 310.00 160.00\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n",
|
||||||
|
"<g id=\"graph0\" class=\"graph\" transform=\"scale(1 1) rotate(0) translate(4 156.44)\">\n",
|
||||||
|
"<polygon fill=\"white\" stroke=\"none\" points=\"-4,4 -4,-156.44 305.89,-156.44 305.89,4 -4,4\"/>\n",
|
||||||
|
"<!-- page_view -->\n",
|
||||||
|
"<g id=\"node1\" class=\"node\">\n",
|
||||||
|
"<title>page_view</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"48.19\" cy=\"-69.01\" rx=\"48.19\" ry=\"48.19\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-64.33\" font-family=\"Times,serif\" font-size=\"14.00\">page_view</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- page_view->page_view -->\n",
|
||||||
|
"<g id=\"edge1\" class=\"edge\">\n",
|
||||||
|
"<title>page_view->page_view</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M33.03,-115.09C34.09,-126.6 39.14,-135.19 48.19,-135.19 53.98,-135.19 58.13,-131.66 60.65,-126.1\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"64.01,-127.11 62.98,-116.56 57.21,-125.45 64.01,-127.11\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"48.19\" y=\"-139.14\" font-family=\"Times,serif\" font-size=\"14.00\">0.50</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- view_item_page -->\n",
|
||||||
|
"<g id=\"node2\" class=\"node\">\n",
|
||||||
|
"<title>view_item_page</title>\n",
|
||||||
|
"<ellipse fill=\"none\" stroke=\"black\" cx=\"232.88\" cy=\"-69.01\" rx=\"69.01\" ry=\"69.01\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"232.88\" y=\"-64.33\" font-family=\"Times,serif\" font-size=\"14.00\">view_item_page</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"<!-- page_view->view_item_page -->\n",
|
||||||
|
"<g id=\"edge2\" class=\"edge\">\n",
|
||||||
|
"<title>page_view->view_item_page</title>\n",
|
||||||
|
"<path fill=\"none\" stroke=\"black\" d=\"M96.71,-69.01C113.69,-69.01 133.31,-69.01 152.25,-69.01\"/>\n",
|
||||||
|
"<polygon fill=\"black\" stroke=\"black\" points=\"152.1,-72.51 162.1,-69.01 152.1,-65.51 152.1,-72.51\"/>\n",
|
||||||
|
"<text xml:space=\"preserve\" text-anchor=\"middle\" x=\"130.12\" y=\"-72.96\" font-family=\"Times,serif\" font-size=\"14.00\">0.50</text>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"</g>\n",
|
||||||
|
"</svg>\n"
|
||||||
|
],
|
||||||
|
"text/plain": [
|
||||||
|
"<graphviz.graphs.Digraph at 0x7f6800bf50f0>"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[[5.0e-001 5.0e-001]\n",
|
||||||
|
" [9.9e-324 1.5e-323]]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"def explore_session(session_id: str):\n",
|
||||||
|
" subset = df[df['sessionId'] == session_id]\n",
|
||||||
|
" print(session_id)\n",
|
||||||
|
" P, labels = build_transition_prob_matrix(subset)\n",
|
||||||
|
" g = render_graph(f\"session_{session_id}\", P, ls_index=labels, threshold=0.01, fmt=\"svg\", view=False)\n",
|
||||||
|
" display(g)\n",
|
||||||
|
" return P\n",
|
||||||
|
"for session in sessions:\n",
|
||||||
|
" print(explore_session(session))"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python (PHANTOM)",
|
||||||
|
"language": "python",
|
||||||
|
"name": "phantom"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.13.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
||||||
1740
experiments/notebooks/states.ipynb
Normal file
1740
experiments/notebooks/states.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
2320
experiments/notebooks/step_breakdown.ipynb
Normal file
2320
experiments/notebooks/step_breakdown.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
165
experiments/procesing/tests/test_session.py
Normal file
165
experiments/procesing/tests/test_session.py
Normal file
@@ -0,0 +1,165 @@
|
|||||||
|
import pytest
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from procesing.steps.session import (
|
||||||
|
TemporalFeatureStep,
|
||||||
|
BehavioralFeatureStep,
|
||||||
|
ProductFeatureStep,
|
||||||
|
UserAgentFeatureStep,
|
||||||
|
ExtractSessionFeaturesStep,
|
||||||
|
JoinLabelsStep,
|
||||||
|
ValidateDataStep,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# TemporalFeatureStep tests
|
||||||
|
def test_temporal_empty(pipeline_context):
|
||||||
|
result = TemporalFeatureStep(pipeline_context).transform(pd.DataFrame())
|
||||||
|
assert 'sessionId' in result.columns
|
||||||
|
assert result.empty
|
||||||
|
|
||||||
|
|
||||||
|
def test_temporal_basic(pipeline_context, session_interactions):
|
||||||
|
result = TemporalFeatureStep(pipeline_context).transform(session_interactions)
|
||||||
|
assert 'session_duration_sec' in result.columns
|
||||||
|
assert 'interaction_velocity' in result.columns
|
||||||
|
assert 'max_velocity_5min' in result.columns
|
||||||
|
assert result['total_interactions'].sum() == len(session_interactions)
|
||||||
|
|
||||||
|
|
||||||
|
def test_temporal_timeout(pipeline_context):
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'sessionId': ['s1', 's1'],
|
||||||
|
'ts': ['2025-01-01T10:00:00Z', '2025-01-01T11:00:00Z'], # 1 hour gap
|
||||||
|
})
|
||||||
|
result = TemporalFeatureStep(pipeline_context, timeout_sec=900).transform(df)
|
||||||
|
assert result.iloc[0]['session_duration_sec'] == 0 # gap exceeds timeout
|
||||||
|
|
||||||
|
|
||||||
|
# BehavioralFeatureStep tests
|
||||||
|
def test_behavioral_empty(pipeline_context):
|
||||||
|
result = BehavioralFeatureStep(pipeline_context).transform(pd.DataFrame())
|
||||||
|
assert 'sessionId' in result.columns
|
||||||
|
|
||||||
|
|
||||||
|
def test_behavioral_counts(pipeline_context, session_interactions):
|
||||||
|
result = BehavioralFeatureStep(pipeline_context).transform(session_interactions)
|
||||||
|
assert 'page_views' in result.columns
|
||||||
|
assert 'item_views' in result.columns
|
||||||
|
assert 'hover_events' in result.columns
|
||||||
|
assert result['total_events'].sum() == len(session_interactions)
|
||||||
|
|
||||||
|
|
||||||
|
def test_behavioral_hover_prefix(pipeline_context):
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'sessionId': ['s1', 's1'],
|
||||||
|
'eventName': ['hover_over_custom', 'hover_over_button'],
|
||||||
|
'page': ['/products', '/products'],
|
||||||
|
})
|
||||||
|
result = BehavioralFeatureStep(pipeline_context).transform(df)
|
||||||
|
assert result.iloc[0]['hover_events'] == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ProductFeatureStep tests
|
||||||
|
def test_product_empty(pipeline_context):
|
||||||
|
result = ProductFeatureStep(pipeline_context).transform(pd.DataFrame())
|
||||||
|
assert 'sessionId' in result.columns
|
||||||
|
|
||||||
|
|
||||||
|
def test_product_features(pipeline_context, session_interactions):
|
||||||
|
result = ProductFeatureStep(pipeline_context).transform(session_interactions)
|
||||||
|
assert 'unique_products_viewed' in result.columns
|
||||||
|
assert 'price_range' in result.columns
|
||||||
|
assert result['unique_products_viewed'].sum() > 0
|
||||||
|
|
||||||
|
|
||||||
|
# UserAgentFeatureStep tests
|
||||||
|
def test_ua_empty(pipeline_context):
|
||||||
|
result = UserAgentFeatureStep(pipeline_context).transform(pd.DataFrame())
|
||||||
|
assert 'sessionId' in result.columns
|
||||||
|
|
||||||
|
|
||||||
|
def test_ua_headless_detection(pipeline_context):
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'sessionId': ['s1', 's2'],
|
||||||
|
'userAgent': ['Mozilla/5.0 Chrome/120', 'HeadlessChrome/120'],
|
||||||
|
})
|
||||||
|
result = UserAgentFeatureStep(pipeline_context).transform(df)
|
||||||
|
assert 'is_headless' in result.columns
|
||||||
|
headless = dict(zip(result['sessionId'], result['is_headless']))
|
||||||
|
assert headless['s1'] == False
|
||||||
|
assert headless['s2'] == True
|
||||||
|
|
||||||
|
|
||||||
|
def test_ua_browser_family(pipeline_context):
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'sessionId': ['s1', 's2', 's3'],
|
||||||
|
'userAgent': ['Mozilla/5.0 Firefox/120', 'Safari/605.1.15', 'Unknown'],
|
||||||
|
})
|
||||||
|
result = UserAgentFeatureStep(pipeline_context).transform(df)
|
||||||
|
browsers = dict(zip(result['sessionId'], result['browser_family']))
|
||||||
|
assert browsers['s1'] == 'Firefox'
|
||||||
|
assert browsers['s2'] == 'Safari'
|
||||||
|
assert browsers['s3'] == 'Other'
|
||||||
|
|
||||||
|
|
||||||
|
def test_ua_automation_detection(pipeline_context):
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'sessionId': ['s1', 's2'],
|
||||||
|
'userAgent': ['Selenium WebDriver', 'Normal Chrome/120'],
|
||||||
|
})
|
||||||
|
result = UserAgentFeatureStep(pipeline_context).transform(df)
|
||||||
|
auto = dict(zip(result['sessionId'], result['is_automation']))
|
||||||
|
assert auto['s1'] == True
|
||||||
|
assert auto['s2'] == False
|
||||||
|
|
||||||
|
|
||||||
|
# ExtractSessionFeaturesStep tests
|
||||||
|
def test_extract_empty(pipeline_context):
|
||||||
|
result = ExtractSessionFeaturesStep(pipeline_context).transform(pd.DataFrame())
|
||||||
|
assert result.empty
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_merges_all(pipeline_context, session_interactions):
|
||||||
|
result = ExtractSessionFeaturesStep(pipeline_context).transform(session_interactions)
|
||||||
|
expected = ['session_duration_sec', 'total_events', 'unique_products_viewed', 'is_headless']
|
||||||
|
for col in expected:
|
||||||
|
assert col in result.columns
|
||||||
|
assert 'experimentId' in result.columns
|
||||||
|
|
||||||
|
|
||||||
|
# JoinLabelsStep tests
|
||||||
|
def test_join_labels_tuple_input(pipeline_context):
|
||||||
|
features = pd.DataFrame({'sessionId': ['s1'], 'experimentId': ['exp1'], 'total_events': [5]})
|
||||||
|
experiments = pd.DataFrame({'id': ['exp1'], 'xp_human_only': [True]})
|
||||||
|
result = JoinLabelsStep(pipeline_context).transform((features, experiments))
|
||||||
|
assert 'is_agent' in result.columns
|
||||||
|
assert result.iloc[0]['is_agent'] == False
|
||||||
|
|
||||||
|
|
||||||
|
def test_join_labels_empty_experiments(pipeline_context):
|
||||||
|
features = pd.DataFrame({'sessionId': ['s1'], 'experimentId': ['exp1']})
|
||||||
|
result = JoinLabelsStep(pipeline_context).transform((features, pd.DataFrame()))
|
||||||
|
assert pd.isna(result.iloc[0]['is_agent'])
|
||||||
|
|
||||||
|
|
||||||
|
# ValidateDataStep tests
|
||||||
|
def test_validate_empty(pipeline_context):
|
||||||
|
ValidateDataStep(pipeline_context).transform(pd.DataFrame())
|
||||||
|
report = pipeline_context.get_cached('validation_report')
|
||||||
|
assert report['status'] == 'empty'
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_missing_cols(pipeline_context):
|
||||||
|
df = pd.DataFrame({'sessionId': ['s1'], 'ts': ['2025-01-01']})
|
||||||
|
ValidateDataStep(pipeline_context).transform(df)
|
||||||
|
report = pipeline_context.get_cached('validation_report')
|
||||||
|
assert report['status'] == 'invalid'
|
||||||
|
assert 'eventName' in report['missing_cols']
|
||||||
|
|
||||||
|
|
||||||
|
def test_validate_valid(pipeline_context, session_interactions):
|
||||||
|
ValidateDataStep(pipeline_context).transform(session_interactions)
|
||||||
|
report = pipeline_context.get_cached('validation_report')
|
||||||
|
assert report['status'] == 'valid'
|
||||||
|
assert report['sessions'] > 0
|
||||||
128
lib/separability.py
Normal file
128
lib/separability.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""Utilities for loading separability artifacts and scoring interaction sessions."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Iterable, List, Sequence
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from experiments.ml.arch import featurize_trajectory
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_ARTIFACT_DIR = Path("data/separability")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SeparabilityArtifacts:
|
||||||
|
scaler: object
|
||||||
|
classifier: object
|
||||||
|
states: List[str]
|
||||||
|
event_transitions: Dict[str, Dict[str, float]]
|
||||||
|
feature_dim: int
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_events(raw_events: Sequence[object]) -> List[object]:
|
||||||
|
events: List[object] = []
|
||||||
|
for evt in raw_events:
|
||||||
|
if hasattr(evt, "value") and hasattr(evt.value, "payload"):
|
||||||
|
events.append(evt.value.payload)
|
||||||
|
else:
|
||||||
|
events.append(evt)
|
||||||
|
events.sort(key=lambda e: getattr(e, "ts", ""))
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
def _event_transition_distribution(events: Sequence[object]) -> Dict[str, Dict[str, float]]:
|
||||||
|
counts: Dict[str, Dict[str, int]] = {}
|
||||||
|
for src_evt, dst_evt in zip(events, events[1:]):
|
||||||
|
src_name = getattr(src_evt, "eventName", "unknown")
|
||||||
|
dst_name = getattr(dst_evt, "eventName", "unknown")
|
||||||
|
counts.setdefault(src_name, {})
|
||||||
|
counts[src_name][dst_name] = counts[src_name].get(dst_name, 0) + 1
|
||||||
|
|
||||||
|
distribution: Dict[str, Dict[str, float]] = {}
|
||||||
|
for src, dsts in counts.items():
|
||||||
|
total = float(sum(dsts.values()))
|
||||||
|
distribution[src] = {dst: val / total for dst, val in dsts.items()} if total else {}
|
||||||
|
return distribution
|
||||||
|
|
||||||
|
|
||||||
|
def _kl_divergence(p: Dict[str, Dict[str, float]], q: Dict[str, Dict[str, float]]) -> float:
|
||||||
|
eps = 1e-10
|
||||||
|
total = 0.0
|
||||||
|
for src, dsts in p.items():
|
||||||
|
for dst, prob in dsts.items():
|
||||||
|
ref = q.get(src, {}).get(dst, 0.0)
|
||||||
|
total += (prob + eps) * np.log((prob + eps) / (ref + eps))
|
||||||
|
return float(total)
|
||||||
|
|
||||||
|
|
||||||
|
def load_artifacts(artifact_dir: Path | str = DEFAULT_ARTIFACT_DIR) -> SeparabilityArtifacts:
|
||||||
|
artifact_dir = Path(artifact_dir)
|
||||||
|
scaler_path = artifact_dir / "scaler.joblib"
|
||||||
|
model_path = artifact_dir / "classifier.joblib"
|
||||||
|
metadata_path = artifact_dir / "metadata.json"
|
||||||
|
|
||||||
|
if not (scaler_path.exists() and model_path.exists() and metadata_path.exists()):
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Separability artifacts not found in {artifact_dir}. Run sim.strong_learner.train first."
|
||||||
|
)
|
||||||
|
|
||||||
|
scaler = joblib.load(scaler_path)
|
||||||
|
classifier = joblib.load(model_path)
|
||||||
|
with open(metadata_path, "r", encoding="utf-8") as fin:
|
||||||
|
metadata = json.load(fin)
|
||||||
|
|
||||||
|
return SeparabilityArtifacts(
|
||||||
|
scaler=scaler,
|
||||||
|
classifier=classifier,
|
||||||
|
states=list(metadata["reference_states"]),
|
||||||
|
event_transitions=metadata["event_transitions"],
|
||||||
|
feature_dim=int(metadata["feature_dim"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def score_session(
|
||||||
|
raw_events: Sequence[object],
|
||||||
|
artifacts: SeparabilityArtifacts,
|
||||||
|
) -> dict:
|
||||||
|
events = _normalize_events(raw_events)
|
||||||
|
if not events:
|
||||||
|
return {"prob_agent": 0.0, "delta_h": 0.0, "delta_a": 0.0}
|
||||||
|
|
||||||
|
reference_mdp = {"states": artifacts.states}
|
||||||
|
features = featurize_trajectory(events, mdp=reference_mdp, input_dim=artifacts.feature_dim)
|
||||||
|
scaled = artifacts.scaler.transform(features.reshape(1, -1))
|
||||||
|
prob_agent = float(artifacts.classifier.predict_proba(scaled)[0, 1])
|
||||||
|
|
||||||
|
session_dist = _event_transition_distribution(events)
|
||||||
|
delta_h = _kl_divergence(session_dist, artifacts.event_transitions.get("human", {}))
|
||||||
|
delta_a = _kl_divergence(session_dist, artifacts.event_transitions.get("agent", {}))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"prob_agent": prob_agent,
|
||||||
|
"delta_h": delta_h,
|
||||||
|
"delta_a": delta_a,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_alpha(prob_agent: float, delta_h: float, delta_a: float, temperature: float = 1.0) -> float:
|
||||||
|
divergence_mass = delta_h + delta_a
|
||||||
|
if divergence_mass <= 1e-8:
|
||||||
|
return float(prob_agent)
|
||||||
|
|
||||||
|
ratio = delta_a / divergence_mass
|
||||||
|
blended = 0.5 * prob_agent + 0.5 * ratio
|
||||||
|
if temperature <= 0:
|
||||||
|
return float(np.clip(blended, 0.0, 1.0))
|
||||||
|
|
||||||
|
scaled = 1.0 / (1.0 + np.exp(-temperature * (blended - 0.5)))
|
||||||
|
return float(np.clip(scaled, 0.0, 1.0))
|
||||||
|
|
||||||
|
|
||||||
|
def score_sessions(raw_sessions: Iterable[Sequence[object]], artifacts: SeparabilityArtifacts) -> List[dict]:
|
||||||
|
return [score_session(events, artifacts) for events in raw_sessions]
|
||||||
69
paper/src/chapters/slacberger.tex
Normal file
69
paper/src/chapters/slacberger.tex
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
|
||||||
|
\section{Problem Formulation: A Stackelberg Game Approach}
|
||||||
|
\label{sec:math_formulation}
|
||||||
|
|
||||||
|
We formalize the interaction between the dynamic pricing system and non-human actors as a \textit{Stackelberg Game} (Leader-Follower) with incomplete information. This framework captures the hierarchical nature of the problem: the Platform (Leader) sets a pricing policy, and the Actors (Followers)---both Humans and Agents---observe these prices and react strategically.
|
||||||
|
|
||||||
|
\subsection{The Players and Objectives}
|
||||||
|
|
||||||
|
Let $t \in \{1, \dots, T\}$ denote discrete time steps. At each step, the system interactions are defined by the following entities:
|
||||||
|
|
||||||
|
\paragraph{1. The Leader (The Platform)}
|
||||||
|
The e-commerce platform acts as the leader, choosing a pricing policy $\pi$ to maximize total expected revenue. At time $t$, given a state $s_t \in \mathcal{S}$ (representing inventory, time of day, and historical interactions), the platform sets a price $p_t \in [p_{\min}, p_{\max}]$.
|
||||||
|
|
||||||
|
The platform's goal is to maximize the cumulative revenue from genuine human transactions while mitigating the distortion caused by agent interactions.
|
||||||
|
|
||||||
|
\paragraph{2. The Followers (The Demand Mixture)}
|
||||||
|
The observed demand is not a monolithic signal but a mixture of two distinct populations with divergent objective functions. Let $u$ denote an incoming actor. The type of the actor $\theta \in \{H, A\}$ is a latent variable, where $H$ denotes a Human and $A$ denotes an Agent.
|
||||||
|
|
||||||
|
\begin{itemize}
|
||||||
|
\item \textbf{The Human ($H$):} Acts as a \textit{myopic utility maximizer}. A human $i$ has a private valuation $v_i$ for the product. They execute a purchase decision $d_i \in \{0, 1\}$ based on the consumer surplus:
|
||||||
|
\begin{equation}
|
||||||
|
d_i(p_t) = \mathbb{I}(v_i - p_t \geq 0)
|
||||||
|
\end{equation}
|
||||||
|
where $\mathbb{I}(\cdot)$ is the indicator function. The aggregate human demand $q_H(p_t)$ follows a standard downward-sloping demand curve $D(p_t)$.
|
||||||
|
|
||||||
|
\item \textbf{The Agent ($A$):} Acts as an \textit{information maximizer} (reconnaissance). The agent does not intend to purchase at the displayed price $p_t$ unless an arbitrage condition is met. Instead, the agent generates interaction events (queries) to estimate the platform's pricing function $f(p)$. The agent's reward function $R_A$ is defined by Information Gain:
|
||||||
|
\begin{equation}
|
||||||
|
R_A(p_t) = H(\mathcal{P}) - H(\mathcal{P} \mid p_t) - c_{query}
|
||||||
|
\end{equation}
|
||||||
|
where $H(\mathcal{P})$ is the entropy of the agent's belief regarding the price distribution, and $c_{query}$ is the marginal cost of interaction (assumed $\approx 0$ for LLMs).
|
||||||
|
\end{itemize}
|
||||||
|
|
||||||
|
\subsection{The Demand Contamination Model}
|
||||||
|
|
||||||
|
% MAYBE alpha has to be \lambda which we also need to formally define still
|
||||||
|
|
||||||
|
The core difficulty in this setting is that the platform observes only the aggregate interaction volume $\hat{q}_t$, which is a contaminated signal. Let $\alpha_t \in [0, 1]$ represent the proportion of traffic generated by agents at time $t$. The observed signal is:
|
||||||
|
|
||||||
|
\begin{equation}
|
||||||
|
\hat{q}_t(p_t) = (1 - \alpha_t) \cdot q_H(p_t) + \alpha_t \cdot q_A(p_t) + \epsilon_t
|
||||||
|
\end{equation}
|
||||||
|
|
||||||
|
where:
|
||||||
|
\begin{itemize}
|
||||||
|
\item $q_H(p_t)$ is the \textit{true signal} (conversion intent).
|
||||||
|
\item $q_A(p_t)$ is the \textit{adversarial noise} (reconnaissance queries).
|
||||||
|
\item $\epsilon_t$ is random market noise.
|
||||||
|
\end{itemize}
|
||||||
|
|
||||||
|
Crucially, $q_A(p_t)$ is often inversely correlated with $q_H(p_t)$ in terms of utility; agents may flood the system with queries during high-volatility periods to map price boundaries, artificially inflating $\hat{q}_t$ without converting.
|
||||||
|
|
||||||
|
\subsection{The Optimization Objective: Robust Revenue}
|
||||||
|
|
||||||
|
Standard dynamic pricing algorithms (e.g., Thompson Sampling or UCB) assume $\alpha_t = 0$, estimating demand $\hat{D}(p) \approx \mathbb{E}[\hat{q} | p]$. In the presence of agents ($\alpha_t > 0$), this estimator becomes biased, leading to the \textit{Cost of Information} (COI) defined in Section 3.2.
|
||||||
|
|
||||||
|
We propose a robust optimization objective. The platform seeks a pricing policy $\pi^*$ that maximizes worst-case revenue over a statistically plausible set of contamination rates $\alpha$:
|
||||||
|
|
||||||
|
\begin{equation}
|
||||||
|
\pi^* = \argmax_{\pi} \sum_{t=1}^T \mathbb{E}_{s_t} \left[ \min_{\alpha} \left( p_t \cdot \hat{q}_t(p_t | \theta=H) \right) - \lambda \cdot \mathcal{L}_{detect}(\hat{q}_t) \right]
|
||||||
|
\end{equation}
|
||||||
|
|
||||||
|
Here:
|
||||||
|
\begin{itemize}
|
||||||
|
\item The first term, $p_t \cdot \hat{q}_t(p_t | \theta=H)$, represents the revenue generated strictly from the estimated human segment.
|
||||||
|
\item $\mathcal{L}_{detect}$ is a penalty term for failing to separate distributions (the cost of confusion).
|
||||||
|
\item $\lambda$ is a hyperparameter balancing revenue exploitation vs. robust detection.
|
||||||
|
\end{itemize}
|
||||||
|
|
||||||
|
This formulation effectively transforms the pricing problem into a \textit{Distributionally Robust Optimization (DRO)} problem, where the learner must guard against adversarial perturbations (Agent traffic) in the observed demand distribution.
|
||||||
BIN
paper/src/graphics/gcp.png
Normal file
BIN
paper/src/graphics/gcp.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
BIN
paper/src/graphics/gcp.webp
Normal file
BIN
paper/src/graphics/gcp.webp
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 8.3 KiB |
12
pyproject.toml
Normal file
12
pyproject.toml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["setuptools>=45", "wheel"]
|
||||||
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "phantom"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Pricing Heuristics Against Non-human Transaction Orchestration Mechanisms"
|
||||||
|
requires-python = ">=3.8"
|
||||||
|
|
||||||
|
[tool.setuptools.packages.find]
|
||||||
|
include = ["experiments*", "lib*"]
|
||||||
32
scripts/tpu_pod_run.sh
Executable file
32
scripts/tpu_pod_run.sh
Executable file
@@ -0,0 +1,32 @@
|
|||||||
|
#!/usr/bin/env sh
|
||||||
|
# Executed on each TPU pod worker via `gcloud tpu-vm scp` + `gcloud tpu-vm ssh --worker=all`.
|
||||||
|
# Authenticates with Artifact Registry using the VM's service account metadata token,
|
||||||
|
# pulls the TPU trainer image, then runs the W&B sweep agent inside Docker.
|
||||||
|
# TPU chip devices (/dev/accel*) are exposed via --privileged + /dev volume mount.
|
||||||
|
# Required env vars: WANDB_API_KEY, SWEEP_ID
|
||||||
|
# Optional: AGENT_COUNT (default 1, 0 = run until sweep ends)
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
IMAGE="us-central1-docker.pkg.dev/phantom-trc/phantom/phantom-trainer:tpu-latest"
|
||||||
|
AGENT_COUNT="${AGENT_COUNT:-1}"
|
||||||
|
|
||||||
|
# use VM service account — no manual key needed on the pod
|
||||||
|
TOKEN=$(curl -sf -H "Metadata-Flavor: Google" \
|
||||||
|
"http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token" \
|
||||||
|
| python3 -c 'import sys, json; print(json.load(sys.stdin)["access_token"])')
|
||||||
|
|
||||||
|
echo "$TOKEN" | sudo docker login -u oauth2accesstoken \
|
||||||
|
--password-stdin https://us-central1-docker.pkg.dev
|
||||||
|
|
||||||
|
sudo docker pull "$IMAGE"
|
||||||
|
|
||||||
|
# --privileged + /dev mount gives the container access to /dev/accel* (TPU chips)
|
||||||
|
# --network host lets JAX reach the other pod workers for distributed init
|
||||||
|
sudo docker run --rm \
|
||||||
|
--privileged \
|
||||||
|
--network host \
|
||||||
|
--volume /dev:/dev \
|
||||||
|
-e WANDB_API_KEY="$WANDB_API_KEY" \
|
||||||
|
-e SWEEP_ID="$SWEEP_ID" \
|
||||||
|
-e AGENT_COUNT="$AGENT_COUNT" \
|
||||||
|
"$IMAGE"
|
||||||
83
scripts/tpu_sync_repo.sh
Normal file
83
scripts/tpu_sync_repo.sh
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
#!/usr/bin/env sh
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
TPU_NAME="${TPU_NAME:?TPU_NAME is required}"
|
||||||
|
TPU_ZONE="${TPU_ZONE:-us-central2-b}"
|
||||||
|
TPU_PROJECT="${TPU_PROJECT:-phantom-trc}"
|
||||||
|
LOCAL_REPO_DIR="${LOCAL_REPO_DIR:-$(pwd)}"
|
||||||
|
REMOTE_REPO_DIR="${REMOTE_REPO_DIR:-/tmp/PHANTOM}"
|
||||||
|
ARCHIVE_PATH="${ARCHIVE_PATH:-/tmp/phantom-sync.tgz}"
|
||||||
|
|
||||||
|
FILE_LIST="$(mktemp /tmp/phantom-sync-files.XXXXXX)"
|
||||||
|
CLEANUP_LIST=true
|
||||||
|
|
||||||
|
cleanup() {
|
||||||
|
if [ "$CLEANUP_LIST" = "true" ]; then
|
||||||
|
rm -f "$FILE_LIST"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
if [ ! -d "$LOCAL_REPO_DIR" ]; then
|
||||||
|
echo "local repo directory not found: $LOCAL_REPO_DIR"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if git -C "$LOCAL_REPO_DIR" rev-parse --is-inside-work-tree >/dev/null 2>&1; then
|
||||||
|
git -C "$LOCAL_REPO_DIR" ls-files -co --exclude-standard > "$FILE_LIST"
|
||||||
|
python3 - "$FILE_LIST" <<'PY'
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
file_list = Path(sys.argv[1])
|
||||||
|
skip_prefixes = (
|
||||||
|
"wandb/",
|
||||||
|
".venv/",
|
||||||
|
"venv/",
|
||||||
|
"node_modules/",
|
||||||
|
".next/",
|
||||||
|
".turbo/",
|
||||||
|
"__pycache__/",
|
||||||
|
".mypy_cache/",
|
||||||
|
".pytest_cache/",
|
||||||
|
".ruff_cache/",
|
||||||
|
"paper/build/",
|
||||||
|
"tests/e2e/test-results/",
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = file_list.read_text().splitlines()
|
||||||
|
kept = [
|
||||||
|
row
|
||||||
|
for row in rows
|
||||||
|
if row and not any(row == p.rstrip("/") or row.startswith(p) for p in skip_prefixes)
|
||||||
|
]
|
||||||
|
file_list.write_text("\n".join(kept) + ("\n" if kept else ""))
|
||||||
|
PY
|
||||||
|
tar -czf "$ARCHIVE_PATH" -C "$LOCAL_REPO_DIR" -T "$FILE_LIST"
|
||||||
|
else
|
||||||
|
tar \
|
||||||
|
--exclude-vcs \
|
||||||
|
--exclude=".venv" --exclude="*/.venv" \
|
||||||
|
--exclude="venv" --exclude="*/venv" \
|
||||||
|
--exclude="node_modules" --exclude="*/node_modules" \
|
||||||
|
--exclude=".next" --exclude="*/.next" \
|
||||||
|
--exclude=".turbo" --exclude="*/.turbo" \
|
||||||
|
--exclude="__pycache__" --exclude="*/__pycache__" \
|
||||||
|
--exclude=".mypy_cache" --exclude="*/.mypy_cache" \
|
||||||
|
--exclude=".pytest_cache" --exclude="*/.pytest_cache" \
|
||||||
|
--exclude=".ruff_cache" --exclude="*/.ruff_cache" \
|
||||||
|
--exclude="wandb" --exclude="*/wandb" \
|
||||||
|
--exclude="paper/build" \
|
||||||
|
--exclude="tests/e2e/test-results" \
|
||||||
|
-czf "$ARCHIVE_PATH" \
|
||||||
|
-C "$LOCAL_REPO_DIR" .
|
||||||
|
fi
|
||||||
|
|
||||||
|
gcloud compute tpus tpu-vm scp "$ARCHIVE_PATH" "$TPU_NAME:/tmp/phantom-sync.tgz" \
|
||||||
|
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all
|
||||||
|
|
||||||
|
gcloud compute tpus tpu-vm ssh "$TPU_NAME" \
|
||||||
|
--zone="$TPU_ZONE" --project="$TPU_PROJECT" --worker=all \
|
||||||
|
--command="rm -rf '$REMOTE_REPO_DIR' && mkdir -p '$REMOTE_REPO_DIR' && tar -xzf /tmp/phantom-sync.tgz -C '$REMOTE_REPO_DIR' && rm -f /tmp/phantom-sync.tgz"
|
||||||
|
|
||||||
|
rm -f "$ARCHIVE_PATH"
|
||||||
183
scripts/tpu_vm_sweep_agent.py
Normal file
183
scripts/tpu_vm_sweep_agent.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import shlex
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import wandb
|
||||||
|
|
||||||
|
|
||||||
|
CLI_MAP: dict[str, str] = {
|
||||||
|
"algo": "--algo",
|
||||||
|
"total_timesteps": "--total-timesteps",
|
||||||
|
"alpha": "--alpha",
|
||||||
|
"N": "--N",
|
||||||
|
"n_products": "--n-products",
|
||||||
|
"lambda_coi": "--lambda-coi",
|
||||||
|
"info_value": "--info-value",
|
||||||
|
"robust_radius": "--robust-radius",
|
||||||
|
"robust_points": "--robust-points",
|
||||||
|
"learning_rate": "--learning-rate",
|
||||||
|
"gamma": "--gamma",
|
||||||
|
"gae_lambda": "--gae-lambda",
|
||||||
|
"clip_range": "--clip-range",
|
||||||
|
"ent_coef": "--ent-coef",
|
||||||
|
"revenue_weight": "--revenue-weight",
|
||||||
|
"max_steps": "--max-steps",
|
||||||
|
"margin_floor": "--margin-floor",
|
||||||
|
"margin_floor_patience": "--margin-floor-patience",
|
||||||
|
"arch": "--arch",
|
||||||
|
"activation": "--activation",
|
||||||
|
"jax_num_envs": "--jax-num-envs",
|
||||||
|
"jax_num_steps": "--jax-num-steps",
|
||||||
|
"jax_num_minibatches": "--jax-num-minibatches",
|
||||||
|
"jax_update_epochs": "--jax-update-epochs",
|
||||||
|
"jax_anneal_lr": "--jax-anneal-lr",
|
||||||
|
"checkpoint_interval": "--checkpoint-interval",
|
||||||
|
"action_levels": "--action-levels",
|
||||||
|
"action_scale_low": "--action-scale-low",
|
||||||
|
"action_scale_high": "--action-scale-high",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _to_cli_args(cfg: dict) -> str:
|
||||||
|
parts: list[str] = ["--jax", "--no-wandb"]
|
||||||
|
for key, flag in CLI_MAP.items():
|
||||||
|
if key not in cfg:
|
||||||
|
continue
|
||||||
|
value = cfg[key]
|
||||||
|
if value is None:
|
||||||
|
continue
|
||||||
|
if isinstance(value, bool):
|
||||||
|
if key == "jax_anneal_lr":
|
||||||
|
parts.extend([flag, "true" if value else "false"])
|
||||||
|
elif value:
|
||||||
|
parts.append(flag)
|
||||||
|
continue
|
||||||
|
parts.extend([flag, str(value)])
|
||||||
|
return " ".join(shlex.quote(p) for p in parts)
|
||||||
|
|
||||||
|
|
||||||
|
_SENTINEL = "PHANTOM_METRICS:"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_metrics(output: str) -> dict:
|
||||||
|
# fast path: look for the dedicated sentinel line emitted by run_local
|
||||||
|
for line in output.splitlines():
|
||||||
|
if line.startswith(_SENTINEL):
|
||||||
|
try:
|
||||||
|
return json.loads(line[len(_SENTINEL) :])
|
||||||
|
except Exception:
|
||||||
|
break
|
||||||
|
# fallback: scan for any JSON block containing eval/sweep keys;
|
||||||
|
# use greedy match to capture the largest possible block first
|
||||||
|
for block in re.findall(r"\{[^{}]*\}", output):
|
||||||
|
try:
|
||||||
|
obj = json.loads(block)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
if isinstance(obj, dict) and ("sweep/score" in obj or "eval/reward" in obj):
|
||||||
|
return obj
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
p = argparse.ArgumentParser(
|
||||||
|
description="Run W&B sweep where each trial uses full TPU pod"
|
||||||
|
)
|
||||||
|
p.add_argument("--sweep-id", required=True)
|
||||||
|
p.add_argument("--tpu-name", required=True)
|
||||||
|
p.add_argument("--tpu-zone", default="us-central2-b")
|
||||||
|
p.add_argument("--tpu-project", default="phantom-trc")
|
||||||
|
p.add_argument("--tpu-repo-dir", default="/tmp/PHANTOM")
|
||||||
|
p.add_argument("--count", type=int, default=0)
|
||||||
|
p.add_argument("--workdir", default=str(Path(__file__).resolve().parents[1]))
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
workdir = Path(args.workdir).resolve()
|
||||||
|
env = os.environ.copy()
|
||||||
|
|
||||||
|
prepare_cmd = [
|
||||||
|
"make",
|
||||||
|
"train.tpu.vm.prepare",
|
||||||
|
f"TPU_NAME={args.tpu_name}",
|
||||||
|
f"TPU_ZONE={args.tpu_zone}",
|
||||||
|
f"TPU_PROJECT={args.tpu_project}",
|
||||||
|
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
||||||
|
]
|
||||||
|
prepare = subprocess.run(
|
||||||
|
prepare_cmd,
|
||||||
|
cwd=workdir,
|
||||||
|
env=env,
|
||||||
|
text=True,
|
||||||
|
capture_output=False,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if prepare.returncode != 0:
|
||||||
|
raise RuntimeError("Failed to prepare TPU workers for sweep")
|
||||||
|
|
||||||
|
def run_trial() -> None:
|
||||||
|
run = None
|
||||||
|
try:
|
||||||
|
run = wandb.init()
|
||||||
|
cfg = dict(wandb.config)
|
||||||
|
cli_args = _to_cli_args(cfg)
|
||||||
|
env_trial = dict(env)
|
||||||
|
env_trial["LOCAL_TRAIN_ARGS"] = cli_args
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
"make",
|
||||||
|
"train.tpu.vm.run",
|
||||||
|
f"TPU_NAME={args.tpu_name}",
|
||||||
|
f"TPU_ZONE={args.tpu_zone}",
|
||||||
|
f"TPU_PROJECT={args.tpu_project}",
|
||||||
|
f"TPU_REPO_DIR={args.tpu_repo_dir}",
|
||||||
|
]
|
||||||
|
|
||||||
|
proc = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
cwd=workdir,
|
||||||
|
env=env_trial,
|
||||||
|
text=True,
|
||||||
|
capture_output=True,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
if proc.stdout:
|
||||||
|
print(proc.stdout)
|
||||||
|
if proc.stderr:
|
||||||
|
print(proc.stderr)
|
||||||
|
|
||||||
|
if proc.returncode != 0:
|
||||||
|
if run is not None:
|
||||||
|
run.summary["runner/exit_code"] = proc.returncode
|
||||||
|
raise RuntimeError(f"TPU trial failed with exit code {proc.returncode}")
|
||||||
|
|
||||||
|
metrics = _extract_metrics(proc.stdout)
|
||||||
|
if metrics:
|
||||||
|
wandb.log(metrics)
|
||||||
|
for k, v in metrics.items():
|
||||||
|
run.summary[k] = v
|
||||||
|
run.summary["runner/exit_code"] = 0
|
||||||
|
except Exception:
|
||||||
|
time.sleep(2)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if run is not None and wandb.run is not None:
|
||||||
|
wandb.finish()
|
||||||
|
|
||||||
|
wandb.agent(
|
||||||
|
args.sweep_id,
|
||||||
|
function=run_trial,
|
||||||
|
count=args.count if args.count > 0 else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
43
scripts/tpu_vm_train.sh
Normal file
43
scripts/tpu_vm_train.sh
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
#!/usr/bin/env sh
|
||||||
|
set -eu
|
||||||
|
|
||||||
|
REPO_DIR="${REPO_DIR:-$HOME/PHANTOM}"
|
||||||
|
PYTHON_BIN="${PYTHON_BIN:-python3}"
|
||||||
|
TRAIN_ARGS="${TRAIN_ARGS:---algo ppo --jax --total-timesteps 200000 --jax-num-envs 32 --jax-num-steps 128 --jax-num-minibatches 4 --jax-update-epochs 4}"
|
||||||
|
EXTRA_PIP="${EXTRA_PIP:-flax optax distrax}"
|
||||||
|
INSTALL_FULL_REQUIREMENTS="${INSTALL_FULL_REQUIREMENTS:-0}"
|
||||||
|
|
||||||
|
if [ ! -d "$REPO_DIR" ]; then
|
||||||
|
echo "repo directory not found: $REPO_DIR"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
cd "$REPO_DIR"
|
||||||
|
|
||||||
|
if [ -d "wandb" ]; then
|
||||||
|
rm -rf wandb
|
||||||
|
fi
|
||||||
|
|
||||||
|
# keep install idempotent and avoid re-installing jax/libtpu each run
|
||||||
|
if [ "$INSTALL_FULL_REQUIREMENTS" = "1" ] && [ -f "requirements.txt" ]; then
|
||||||
|
$PYTHON_BIN -m pip install -r requirements.txt
|
||||||
|
fi
|
||||||
|
if ! $PYTHON_BIN -c 'import flax, optax, distrax' >/dev/null 2>&1; then
|
||||||
|
if [ -f "engine/jax/requirements.txt" ]; then
|
||||||
|
$PYTHON_BIN -m pip install -r engine/jax/requirements.txt
|
||||||
|
fi
|
||||||
|
$PYTHON_BIN -m pip install -U $EXTRA_PIP
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -n "${WANDB_API_KEY:-}" ]; then
|
||||||
|
if ! $PYTHON_BIN -c 'import wandb; import inspect; assert hasattr(wandb, "init") and callable(wandb.init)' >/dev/null 2>&1; then
|
||||||
|
$PYTHON_BIN -m pip install -U wandb
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ -n "${WANDB_API_KEY:-}" ]; then
|
||||||
|
export WANDB_API_KEY
|
||||||
|
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS
|
||||||
|
fi
|
||||||
|
|
||||||
|
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS --no-wandb
|
||||||
108
scripts/wandb_agent_bootstrap.sh
Executable file
108
scripts/wandb_agent_bootstrap.sh
Executable file
@@ -0,0 +1,108 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
need_env() {
|
||||||
|
local name="$1"
|
||||||
|
if [ -z "${!name:-}" ]; then
|
||||||
|
echo "$name is required"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
need_cmd() {
|
||||||
|
local c="$1"
|
||||||
|
command -v "$c" >/dev/null 2>&1 || {
|
||||||
|
echo "Missing command: $c"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
need_cmd git
|
||||||
|
need_cmd python3
|
||||||
|
|
||||||
|
need_env WANDB_API_KEY
|
||||||
|
need_env GITHUB_TOKEN
|
||||||
|
need_env REPO_URL
|
||||||
|
need_env SWEEP_ID
|
||||||
|
|
||||||
|
BRANCH="${BRANCH:-main}"
|
||||||
|
WORKDIR="${WORKDIR:-$HOME/PHANTOM-agent}"
|
||||||
|
AGENT_COUNT="${AGENT_COUNT:-0}"
|
||||||
|
AGENT_LOOP="${AGENT_LOOP:-1}"
|
||||||
|
RETRY_SECONDS="${RETRY_SECONDS:-20}"
|
||||||
|
PYTHON_BIN="${PYTHON_BIN:-python3}"
|
||||||
|
|
||||||
|
mkdir -p "$(dirname "$WORKDIR")"
|
||||||
|
|
||||||
|
ASKPASS_FILE="$(mktemp)"
|
||||||
|
cat >"$ASKPASS_FILE" <<'EOF'
|
||||||
|
#!/usr/bin/env sh
|
||||||
|
case "$1" in
|
||||||
|
*Username*) echo "x-access-token" ;;
|
||||||
|
*Password*) echo "$GITHUB_TOKEN" ;;
|
||||||
|
*) echo "" ;;
|
||||||
|
esac
|
||||||
|
EOF
|
||||||
|
chmod 700 "$ASKPASS_FILE"
|
||||||
|
|
||||||
|
cleanup() {
|
||||||
|
rm -f "$ASKPASS_FILE"
|
||||||
|
}
|
||||||
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
git_auth() {
|
||||||
|
GIT_TERMINAL_PROMPT=0 GIT_ASKPASS="$ASKPASS_FILE" git "$@"
|
||||||
|
}
|
||||||
|
|
||||||
|
sync_repo() {
|
||||||
|
if [ ! -d "$WORKDIR/.git" ]; then
|
||||||
|
rm -rf "$WORKDIR"
|
||||||
|
git_auth clone --single-branch --branch "$BRANCH" "$REPO_URL" "$WORKDIR"
|
||||||
|
return
|
||||||
|
fi
|
||||||
|
|
||||||
|
git -C "$WORKDIR" remote set-url origin "$REPO_URL"
|
||||||
|
git_auth -C "$WORKDIR" fetch origin "$BRANCH" --prune
|
||||||
|
git -C "$WORKDIR" checkout -B "$BRANCH" "origin/$BRANCH"
|
||||||
|
git -C "$WORKDIR" reset --hard "origin/$BRANCH"
|
||||||
|
}
|
||||||
|
|
||||||
|
install_deps() {
|
||||||
|
"$PYTHON_BIN" -m venv "$WORKDIR/.venv"
|
||||||
|
"$WORKDIR/.venv/bin/pip" install --upgrade pip
|
||||||
|
"$WORKDIR/.venv/bin/pip" install -r "$WORKDIR/requirements.txt"
|
||||||
|
}
|
||||||
|
|
||||||
|
run_agent() {
|
||||||
|
local cmd=("$WORKDIR/.venv/bin/python" -m engine.train --sweep-agent --sweep-id "$SWEEP_ID")
|
||||||
|
if [ "$AGENT_COUNT" != "0" ]; then
|
||||||
|
cmd+=(--count "$AGENT_COUNT")
|
||||||
|
fi
|
||||||
|
|
||||||
|
(
|
||||||
|
cd "$WORKDIR"
|
||||||
|
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||||
|
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||||
|
WANDB_PROJECT="${WANDB_PROJECT:-}" \
|
||||||
|
"${cmd[@]}"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
while true; do
|
||||||
|
sync_repo
|
||||||
|
install_deps
|
||||||
|
|
||||||
|
if run_agent; then
|
||||||
|
if [ "$AGENT_LOOP" = "1" ] && [ "$AGENT_COUNT" = "0" ]; then
|
||||||
|
sleep "$RETRY_SECONDS"
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ "$AGENT_LOOP" != "1" ]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
sleep "$RETRY_SECONDS"
|
||||||
|
done
|
||||||
7
sim/requirements.txt
Normal file
7
sim/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
gymnasium>=0.29.0
|
||||||
|
numpy>=1.24.0
|
||||||
|
pandas>=2.0.0
|
||||||
|
stable-baselines3>=2.2.0
|
||||||
|
tensorboard>=2.15.0
|
||||||
|
jax>=0.4.20
|
||||||
|
jaxlib>=0.4.20
|
||||||
@@ -254,3 +254,5 @@ if __name__ == "__main__":
|
|||||||
f"{sum(len(t) for t in joint_mdp['transitions'].values())} transitions")
|
f"{sum(len(t) for t in joint_mdp['transitions'].values())} transitions")
|
||||||
if joint_mdp['states']:
|
if joint_mdp['states']:
|
||||||
visualize_mdp(joint_model, threshold=0.05, output="joint_mdp_viz", fmt="pdf", export_dot=True)
|
visualize_mdp(joint_model, threshold=0.05, output="joint_mdp_viz", fmt="pdf", export_dot=True)
|
||||||
|
|
||||||
|
# TODO: setup intra class divergence as baseline for evaluating and adding significance to the divergence which we observe across class
|
||||||
|
|||||||
117
sim/rl/behavior_loader/visualize_kl.py
Normal file
117
sim/rl/behavior_loader/visualize_kl.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from collections import defaultdict
|
||||||
|
from models import BehaviorModel, AgentBehaviorModel, aggregate_event_transitions, kl_divergence
|
||||||
|
|
||||||
|
def event_frequency_distribution(mdp):
|
||||||
|
evt_cnt, total = defaultdict(int), 0
|
||||||
|
for s, trans in mdp['transitions'].items():
|
||||||
|
evt = s.split('|')[2]
|
||||||
|
for cnt in mdp['trans_counts'][s].values():
|
||||||
|
evt_cnt[evt] += cnt
|
||||||
|
total += cnt
|
||||||
|
return {evt: cnt/total for evt, cnt in evt_cnt.items()} if total > 0 else {}
|
||||||
|
|
||||||
|
def transition_distribution(mdp):
|
||||||
|
trans_cnt, total = defaultdict(int), 0
|
||||||
|
for s, trans in mdp['trans_counts'].items():
|
||||||
|
src = s.split('|')[2]
|
||||||
|
for s_next, cnt in trans.items():
|
||||||
|
dst = s_next.split('|')[2]
|
||||||
|
trans_cnt[f"{src}->{dst}"] += cnt
|
||||||
|
total += cnt
|
||||||
|
return {t: cnt/total for t, cnt in trans_cnt.items()} if total > 0 else {}
|
||||||
|
|
||||||
|
def kl_color(kl):
|
||||||
|
return '#d62828' if kl > 2.0 else '#f77f00' if kl > 0.5 else '#2a9d8f'
|
||||||
|
|
||||||
|
def plot_comparison(ax, human_vals, agent_vals, labels, title, ylabel, kl_val=None):
|
||||||
|
x, w = np.arange(len(labels)), 0.35
|
||||||
|
ax.bar(x - w/2, human_vals, w, label='Human', alpha=0.8, color='#2E86AB')
|
||||||
|
ax.bar(x + w/2, agent_vals, w, label='Agent', alpha=0.8, color='#A23B72')
|
||||||
|
ax.set_ylabel(ylabel, fontsize=9 if len(labels) > 10 else 11, fontweight='bold')
|
||||||
|
ax.set_title(title if not kl_val else f"{title}\nKL={kl_val:.4f}",
|
||||||
|
fontsize=10 if len(labels) > 10 else 12, fontweight='bold')
|
||||||
|
ax.set_xticks(x)
|
||||||
|
ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)
|
||||||
|
ax.legend(fontsize=8)
|
||||||
|
ax.grid(axis='y', alpha=0.3, linestyle='--')
|
||||||
|
return ax
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
base_dir = "/home/velocitatem/Documents/Projects/PHANTOM/experiments"
|
||||||
|
human_dir, agent_dir = f"{base_dir}/collected_data/", f"{base_dir}/agents/collected_data/"
|
||||||
|
|
||||||
|
human_model, agent_model = BehaviorModel(human_dir), AgentBehaviorModel(agent_dir)
|
||||||
|
human_mdp, agent_mdp = human_model.build_MDP(), agent_model.build_MDP()
|
||||||
|
|
||||||
|
human_evt, agent_evt = aggregate_event_transitions(human_mdp), aggregate_event_transitions(agent_mdp)
|
||||||
|
common = set(human_evt.keys()) & set(agent_evt.keys())
|
||||||
|
kl_results = sorted([(e, kl_divergence(human_evt[e], agent_evt[e])) for e in common],
|
||||||
|
key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
fig = plt.figure(figsize=(16, 10))
|
||||||
|
n_rows, n_cols = (len(kl_results) + 1) // 2, 2
|
||||||
|
|
||||||
|
for idx, (evt, kl) in enumerate(kl_results):
|
||||||
|
ax = plt.subplot(n_rows, n_cols, idx + 1)
|
||||||
|
h_dist, a_dist = human_evt.get(evt, {}), agent_evt.get(evt, {})
|
||||||
|
dests = sorted(set(h_dist.keys()) | set(a_dist.keys()))
|
||||||
|
if not dests: continue
|
||||||
|
|
||||||
|
h_probs, a_probs = [h_dist.get(d, 0) for d in dests], [a_dist.get(d, 0) for d in dests]
|
||||||
|
plot_comparison(ax, h_probs, a_probs, dests, f'From: {evt}', 'Probability')
|
||||||
|
ax.set_ylim([0, max(max(h_probs + a_probs, default=0) * 1.1, 0.1)])
|
||||||
|
ax.text(0.95, 0.95, f'KL={kl:.2f}', transform=ax.transAxes, fontsize=11,
|
||||||
|
fontweight='bold', va='top', ha='right',
|
||||||
|
bbox=dict(boxstyle='round', facecolor=kl_color(kl), alpha=0.3))
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('kl_divergence_comparison.png', dpi=300, bbox_inches='tight')
|
||||||
|
print("Saved visualization to kl_divergence_comparison.png")
|
||||||
|
|
||||||
|
fig2, ax2 = plt.subplots(figsize=(10, 6))
|
||||||
|
evts, kls = zip(*kl_results) if kl_results else ([], [])
|
||||||
|
colors = [kl_color(kl) for kl in kls]
|
||||||
|
bars = ax2.barh(evts, kls, color=colors, alpha=0.8)
|
||||||
|
ax2.set_xlabel('KL Divergence D(Human || Agent)', fontsize=12, fontweight='bold')
|
||||||
|
ax2.set_ylabel('Event Type', fontsize=12, fontweight='bold')
|
||||||
|
ax2.set_title('Behavioral Divergence Between Human and Agent Traffic', fontsize=14, fontweight='bold')
|
||||||
|
if kls:
|
||||||
|
ax2.axvline(x=np.mean(kls), color='black', linestyle='--', linewidth=2,
|
||||||
|
alpha=0.5, label=f'Mean={np.mean(kls):.2f}')
|
||||||
|
for bar, kl in zip(bars, kls):
|
||||||
|
ax2.text(bar.get_width() + 0.1, bar.get_y() + bar.get_height()/2,
|
||||||
|
f'{kl:.2f}', ha='left', va='center', fontsize=10, fontweight='bold')
|
||||||
|
ax2.legend()
|
||||||
|
ax2.grid(axis='x', alpha=0.3, linestyle='--')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('kl_summary.png', dpi=300, bbox_inches='tight')
|
||||||
|
print("Saved KL summary to kl_summary.png")
|
||||||
|
|
||||||
|
h_freq, a_freq = event_frequency_distribution(human_mdp), event_frequency_distribution(agent_mdp)
|
||||||
|
h_trans, a_trans = transition_distribution(human_mdp), transition_distribution(agent_mdp)
|
||||||
|
freq_kl, trans_kl = kl_divergence(h_freq, a_freq), kl_divergence(h_trans, a_trans)
|
||||||
|
|
||||||
|
print(f"\n=== Global Distribution KL Divergence ===")
|
||||||
|
print(f"Event frequency KL: {freq_kl:.4f}")
|
||||||
|
print(f"Transition pair KL: {trans_kl:.4f}")
|
||||||
|
|
||||||
|
fig3, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
|
||||||
|
|
||||||
|
all_evts = sorted(set(h_freq.keys()) | set(a_freq.keys()))
|
||||||
|
h_freqs, a_freqs = [h_freq.get(e, 0) for e in all_evts], [a_freq.get(e, 0) for e in all_evts]
|
||||||
|
plot_comparison(ax1, h_freqs, a_freqs, all_evts, 'Event Frequency Distribution',
|
||||||
|
'Frequency', freq_kl)
|
||||||
|
|
||||||
|
all_trans = sorted(set(h_trans.keys()) | set(a_trans.keys()))
|
||||||
|
top_trans = [t for t, _ in sorted([(t, h_trans.get(t, 0) + a_trans.get(t, 0))
|
||||||
|
for t in all_trans], key=lambda x: x[1], reverse=True)[:15]]
|
||||||
|
h_tprobs, a_tprobs = [h_trans.get(t, 0) for t in top_trans], [a_trans.get(t, 0) for t in top_trans]
|
||||||
|
plot_comparison(ax2, h_tprobs, a_tprobs, top_trans, 'Top Transition Pairs Distribution',
|
||||||
|
'Probability', trans_kl)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
plt.savefig('global_distributions.png', dpi=300, bbox_inches='tight')
|
||||||
|
print("Saved global distributions to global_distributions.png")
|
||||||
86
sim/rl/thesis_core.py
Normal file
86
sim/rl/thesis_core.py
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from sim.case.thesis_simplified.simplified import Session
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class PricingStep:
|
||||||
|
sessions: list[Session]
|
||||||
|
demand_by_session: Dict[str, float]
|
||||||
|
demand_by_product: np.ndarray
|
||||||
|
purchases_by_product: np.ndarray
|
||||||
|
revenue: float
|
||||||
|
cost: float
|
||||||
|
n_agents: int
|
||||||
|
|
||||||
|
|
||||||
|
def clip_prices(prices: np.ndarray, min_price: float, max_price: float) -> np.ndarray:
|
||||||
|
return np.clip(prices, min_price, max_price).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def constrain_prices(
|
||||||
|
prev_prices: Optional[np.ndarray],
|
||||||
|
proposed: np.ndarray,
|
||||||
|
*,
|
||||||
|
costs: np.ndarray,
|
||||||
|
min_price: float,
|
||||||
|
max_price: float,
|
||||||
|
max_adjustment: float,
|
||||||
|
min_margin_pct: float,
|
||||||
|
) -> np.ndarray:
|
||||||
|
prices = clip_prices(proposed, min_price, max_price)
|
||||||
|
floor = (costs * (1.0 + float(min_margin_pct))).astype(np.float32)
|
||||||
|
prices = np.maximum(prices, floor)
|
||||||
|
if prev_prices is None:
|
||||||
|
return prices
|
||||||
|
prev_prices = prev_prices.astype(np.float32)
|
||||||
|
ratio = np.clip(prices / (prev_prices + 1e-6), 1.0 - max_adjustment, 1.0 + max_adjustment)
|
||||||
|
return (prev_prices * ratio).astype(np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_demand_by_product(
|
||||||
|
sessions: list[Session],
|
||||||
|
demand_by_session: Dict[str, float],
|
||||||
|
n_products: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
demand = np.zeros(n_products, dtype=np.float32)
|
||||||
|
sessions_by_id = {s.sid: s for s in sessions}
|
||||||
|
for sid, q in demand_by_session.items():
|
||||||
|
sess = sessions_by_id.get(sid)
|
||||||
|
if not sess or not sess.events:
|
||||||
|
continue
|
||||||
|
pidx = int(sess.events[0].product_idx)
|
||||||
|
if 0 <= pidx < n_products:
|
||||||
|
demand[pidx] += float(q)
|
||||||
|
return demand
|
||||||
|
|
||||||
|
|
||||||
|
def aggregate_purchases(
|
||||||
|
sessions: list[Session],
|
||||||
|
costs: np.ndarray,
|
||||||
|
n_products: int,
|
||||||
|
) -> tuple[np.ndarray, float, float, int]:
|
||||||
|
purchases = np.zeros(n_products, dtype=np.float32)
|
||||||
|
revenue = 0.0
|
||||||
|
cost = 0.0
|
||||||
|
n_agents = 0
|
||||||
|
|
||||||
|
for sess in sessions:
|
||||||
|
if sess.actor == "A":
|
||||||
|
n_agents += 1
|
||||||
|
for e in sess.events:
|
||||||
|
if e.action != "purchase":
|
||||||
|
continue
|
||||||
|
pidx = int(e.product_idx)
|
||||||
|
if 0 <= pidx < n_products:
|
||||||
|
purchases[pidx] += 1.0
|
||||||
|
revenue += float(e.price_seen)
|
||||||
|
cost += float(costs[pidx])
|
||||||
|
|
||||||
|
return purchases, revenue, cost, n_agents
|
||||||
|
|
||||||
Reference in New Issue
Block a user