mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
chore: including new scritps for automation
This commit is contained in:
38
scripts/launch_calibration_screen.sh
Executable file
38
scripts/launch_calibration_screen.sh
Executable file
@@ -0,0 +1,38 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
||||
|
||||
export RAY_MODE="${RAY_MODE:-sweep}"
|
||||
export SWEEP_KIND="${SWEEP_KIND:-ppo_block_a}"
|
||||
export SWEEP_METHOD="${SWEEP_METHOD:-grid}"
|
||||
export SWEEP_PROFILE="${SWEEP_PROFILE:-default}"
|
||||
export SWEEP_RUN_CAP="${SWEEP_RUN_CAP:-27}"
|
||||
export COMPARE_ROBUST="${COMPARE_ROBUST:-1}"
|
||||
export NUM_NODES="${NUM_NODES:-3}"
|
||||
export AGENTS_PER_NODE="${AGENTS_PER_NODE:-4}"
|
||||
export AGENT_COUNT="${AGENT_COUNT:-0}"
|
||||
export INNER_THREADS="${INNER_THREADS:-1}"
|
||||
export PHANTOM_JAX_PLATFORM="${PHANTOM_JAX_PLATFORM:-cpu}"
|
||||
export OUTPUT_ROOT="${OUTPUT_ROOT:-engine/studies/results/block_a_sweep}"
|
||||
|
||||
if [ -z "${WORKER_CPUS:-}" ]; then
|
||||
export WORKER_CPUS="$((AGENTS_PER_NODE * INNER_THREADS))"
|
||||
fi
|
||||
|
||||
printf '%s\n' "Launching Block A PPO calibration sweep"
|
||||
printf '%s\n' "RAY_MODE=$RAY_MODE"
|
||||
printf '%s\n' "SWEEP_KIND=$SWEEP_KIND"
|
||||
printf '%s\n' "SWEEP_METHOD=$SWEEP_METHOD"
|
||||
printf '%s\n' "SWEEP_RUN_CAP=$SWEEP_RUN_CAP"
|
||||
printf '%s\n' "COMPARE_ROBUST=$COMPARE_ROBUST"
|
||||
printf '%s\n' "NUM_NODES=$NUM_NODES"
|
||||
printf '%s\n' "AGENTS_PER_NODE=$AGENTS_PER_NODE"
|
||||
printf '%s\n' "AGENT_COUNT=$AGENT_COUNT"
|
||||
printf '%s\n' "INNER_THREADS=$INNER_THREADS"
|
||||
printf '%s\n' "WORKER_CPUS=$WORKER_CPUS"
|
||||
printf '%s\n' "OUTPUT_ROOT=$OUTPUT_ROOT"
|
||||
|
||||
cd "$ROOT"
|
||||
bash ./submit_ray_job.sh
|
||||
9
scripts/setuptpu.sh
Normal file
9
scripts/setuptpu.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
commands = (
|
||||
"pip install \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
|
||||
"pip install stable-baselines3>=2.2.0 gymnasium wandb tensorboard"
|
||||
|
||||
|
||||
"
|
||||
|
||||
|
||||
)
|
||||
333
scripts/wandb_compare_best.py
Normal file
333
scripts/wandb_compare_best.py
Normal file
@@ -0,0 +1,333 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shlex
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _truthy(value: Any) -> bool:
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
if value is None:
|
||||
return False
|
||||
return str(value).strip().lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
|
||||
def _as_float(value: Any, default: float) -> float:
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return float(default)
|
||||
|
||||
|
||||
def _as_int(value: Any, default: int) -> int:
|
||||
try:
|
||||
return int(float(value))
|
||||
except (TypeError, ValueError):
|
||||
return int(default)
|
||||
|
||||
|
||||
def _normalize_sweep_id(
|
||||
raw: str, entity: str, project: str
|
||||
) -> tuple[str, str, str, str]:
|
||||
sweep_raw = str(raw).strip()
|
||||
if not sweep_raw:
|
||||
raise ValueError("--sweep-id is required")
|
||||
parts = [piece.strip() for piece in sweep_raw.split("/") if piece.strip()]
|
||||
if len(parts) == 3:
|
||||
return f"{parts[0]}/{parts[1]}/{parts[2]}", parts[0], parts[1], parts[2]
|
||||
if len(parts) == 2:
|
||||
if not entity.strip():
|
||||
raise ValueError("--entity is required when --sweep-id is '<project>/<id>'")
|
||||
return f"{entity}/{parts[0]}/{parts[1]}", entity, parts[0], parts[1]
|
||||
if len(parts) == 1:
|
||||
if not entity.strip() or not project.strip():
|
||||
raise ValueError(
|
||||
"--entity and --project are required when --sweep-id is '<id>'"
|
||||
)
|
||||
return f"{entity}/{project}/{parts[0]}", entity, project, parts[0]
|
||||
raise ValueError(f"invalid --sweep-id value: '{raw}'")
|
||||
|
||||
|
||||
def _pick_best_defended_run(
|
||||
sweep: Any,
|
||||
metric: str,
|
||||
*,
|
||||
min_margin: float,
|
||||
min_coi: float,
|
||||
) -> tuple[Any, float]:
|
||||
ranked: list[tuple[float, Any]] = []
|
||||
for run in list(sweep.runs):
|
||||
if str(getattr(run, "state", "")).lower() != "finished":
|
||||
continue
|
||||
cfg = dict(getattr(run, "config", {}) or {})
|
||||
is_baseline = (
|
||||
_truthy(cfg.get("baseline_mode"))
|
||||
if "baseline_mode" in cfg
|
||||
else _truthy(cfg.get("no_robust"))
|
||||
)
|
||||
if is_baseline:
|
||||
continue
|
||||
summary = dict(getattr(run, "summary", {}) or {})
|
||||
margin = _as_float(summary.get("eval/margin_mean"), -1.0)
|
||||
coi_level = _as_float(summary.get("eval/coi_level_mean"), -1.0)
|
||||
if margin < float(min_margin):
|
||||
continue
|
||||
if coi_level < float(min_coi):
|
||||
continue
|
||||
score = summary.get(metric)
|
||||
if score is None and str(metric) == "eval/stress_revenue_worst":
|
||||
score = summary.get("eval/robust_revenue_worst")
|
||||
if score is None:
|
||||
continue
|
||||
try:
|
||||
ranked.append((float(score), run))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if not ranked:
|
||||
raise RuntimeError(
|
||||
f"no finished defended runs found with summary metric '{metric}' and constraints "
|
||||
f"margin>={min_margin}, coi>={min_coi}"
|
||||
)
|
||||
ranked.sort(key=lambda item: item[0], reverse=True)
|
||||
return ranked[0][1], ranked[0][0]
|
||||
|
||||
|
||||
def _format_alpha_values(raw: str, fallback_alpha: float) -> str:
|
||||
cleaned = str(raw).strip()
|
||||
if cleaned:
|
||||
return cleaned
|
||||
return f"{float(fallback_alpha):.6g}"
|
||||
|
||||
|
||||
def _benchmark_tokens(
|
||||
*,
|
||||
project: str,
|
||||
cfg: dict[str, Any],
|
||||
alpha_values: str,
|
||||
episodes: int,
|
||||
) -> list[str]:
|
||||
algo = str(cfg.get("algo", "")).strip().lower()
|
||||
if algo not in {"qtable", "ppo", "a2c", "dqn"}:
|
||||
raise ValueError(f"unsupported algo in best run: '{algo}'")
|
||||
|
||||
total_timesteps = _as_int(cfg.get("total_timesteps"), 80_000)
|
||||
max_steps = _as_int(cfg.get("max_steps"), 100)
|
||||
ambiguity_radius = _as_float(
|
||||
cfg.get("ambiguity_radius", cfg.get("robust_radius")), 0.2
|
||||
)
|
||||
ambiguity_points = _as_int(cfg.get("ambiguity_points", cfg.get("robust_points")), 7)
|
||||
ambiguity_rollouts = _as_int(
|
||||
cfg.get("ambiguity_rollouts", cfg.get("robust_rollouts")), 1
|
||||
)
|
||||
lambda_coi = _as_float(cfg.get("lambda_coi"), 0.2)
|
||||
eta_ux = _as_float(cfg.get("eta_ux"), 0.5)
|
||||
reward_profit_weight = _as_float(cfg.get("reward_profit_weight"), 1.0)
|
||||
learning_rate = _as_float(cfg.get("learning_rate"), 3e-4)
|
||||
batch_size = _as_int(cfg.get("batch_size"), 256)
|
||||
n_steps = _as_int(cfg.get("n_steps"), 2048)
|
||||
sessions = _as_int(cfg.get("N"), 100)
|
||||
action_levels = _as_int(cfg.get("action_levels"), 9)
|
||||
margin_floor = _as_float(cfg.get("margin_floor"), 0.85)
|
||||
seed = _as_int(cfg.get("seed"), 42)
|
||||
|
||||
return [
|
||||
"--project",
|
||||
project,
|
||||
"--tiers",
|
||||
algo,
|
||||
"--alpha-values",
|
||||
alpha_values,
|
||||
"--episodes",
|
||||
str(int(episodes)),
|
||||
"--seed",
|
||||
str(seed),
|
||||
"--total-timesteps",
|
||||
str(total_timesteps),
|
||||
"--max-steps",
|
||||
str(max_steps),
|
||||
"--robust-radius",
|
||||
str(ambiguity_radius),
|
||||
"--robust-points",
|
||||
str(ambiguity_points),
|
||||
"--robust-rollouts",
|
||||
str(ambiguity_rollouts),
|
||||
"--lambda-coi",
|
||||
str(lambda_coi),
|
||||
"--eta-ux",
|
||||
str(eta_ux),
|
||||
"--reward-profit-weight",
|
||||
str(reward_profit_weight),
|
||||
"--learning-rate",
|
||||
str(learning_rate),
|
||||
"--batch-size",
|
||||
str(batch_size),
|
||||
"--n-steps",
|
||||
str(n_steps),
|
||||
"--N",
|
||||
str(sessions),
|
||||
"--action-levels",
|
||||
str(action_levels),
|
||||
"--margin-floor",
|
||||
str(margin_floor),
|
||||
"--device",
|
||||
"cpu",
|
||||
]
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Find best defended sweep run and prepare defended-vs-baseline benchmark"
|
||||
)
|
||||
parser.add_argument("--sweep-id", required=True)
|
||||
parser.add_argument("--entity", default="")
|
||||
parser.add_argument("--project", default="")
|
||||
parser.add_argument("--metric", default="eval/stress_revenue_worst")
|
||||
parser.add_argument("--min-margin", type=float, default=0.90)
|
||||
parser.add_argument("--min-coi", type=float, default=120.0)
|
||||
parser.add_argument("--alpha-values", default="")
|
||||
parser.add_argument("--episodes", type=int, default=15)
|
||||
parser.add_argument("--num-nodes", type=int, default=4)
|
||||
parser.add_argument("--tpu-per-task", type=float, default=0.0)
|
||||
parser.add_argument("--inner-workers", type=int, default=12)
|
||||
parser.add_argument("--inner-threads", type=int, default=1)
|
||||
parser.add_argument("--max-heavy-workers", type=int, default=3)
|
||||
parser.add_argument("--worker-cpus", type=int, default=24)
|
||||
parser.add_argument(
|
||||
"--output-root", default="engine/studies/results/overnight/best_compare"
|
||||
)
|
||||
parser.add_argument("--timeout", type=int, default=120)
|
||||
parser.add_argument("--submit", action="store_true")
|
||||
parser.add_argument("--ray-no-wait", action="store_true")
|
||||
parser.add_argument("--submission-id", default="")
|
||||
parser.add_argument("--output-json", default="")
|
||||
args = parser.parse_args()
|
||||
|
||||
root = Path(__file__).resolve().parents[1]
|
||||
cwd = str(Path.cwd())
|
||||
sys.path = [p for p in sys.path if p not in {"", cwd}]
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError as exc:
|
||||
raise ImportError("wandb is required") from exc
|
||||
|
||||
full_sweep_id, entity, project, _ = _normalize_sweep_id(
|
||||
raw=str(args.sweep_id),
|
||||
entity=str(args.entity).strip(),
|
||||
project=str(args.project).strip(),
|
||||
)
|
||||
api = wandb.Api(timeout=int(args.timeout))
|
||||
sweep = api.sweep(full_sweep_id)
|
||||
best_run, best_score = _pick_best_defended_run(
|
||||
sweep,
|
||||
str(args.metric),
|
||||
min_margin=float(args.min_margin),
|
||||
min_coi=float(args.min_coi),
|
||||
)
|
||||
|
||||
best_cfg = dict(getattr(best_run, "config", {}) or {})
|
||||
best_alpha = _as_float(
|
||||
best_cfg.get(
|
||||
"alpha",
|
||||
getattr(best_run, "summary", {}).get("study/alpha", 0.6),
|
||||
),
|
||||
0.6,
|
||||
)
|
||||
alpha_values = _format_alpha_values(
|
||||
str(args.alpha_values), fallback_alpha=best_alpha
|
||||
)
|
||||
benchmark_tokens = _benchmark_tokens(
|
||||
project=project,
|
||||
cfg=best_cfg,
|
||||
alpha_values=alpha_values,
|
||||
episodes=int(args.episodes),
|
||||
)
|
||||
benchmark_args = shlex.join(benchmark_tokens)
|
||||
|
||||
submission_id = str(args.submission_id).strip()
|
||||
if not submission_id:
|
||||
stamp = datetime.now(timezone.utc).strftime("%m%d-%H%M")
|
||||
submission_id = f"best-compare-{stamp}"
|
||||
|
||||
env_overrides = {
|
||||
"RAY_MODE": "benchmark",
|
||||
"COMPARE_ROBUST": "1",
|
||||
"NUM_NODES": str(int(args.num_nodes)),
|
||||
"TPU_PER_TASK": str(float(args.tpu_per_task)),
|
||||
"PHANTOM_JAX_PLATFORM": "cpu",
|
||||
"WANDB_ENTITY": entity,
|
||||
"WANDB_PROJECT": project,
|
||||
"BENCHMARK_ARGS": benchmark_args,
|
||||
"INNER_WORKERS": str(int(args.inner_workers)),
|
||||
"INNER_THREADS": str(int(args.inner_threads)),
|
||||
"MAX_HEAVY_WORKERS": str(int(args.max_heavy_workers)),
|
||||
"WORKER_CPUS": str(int(args.worker_cpus)),
|
||||
"OUTPUT_ROOT": str(args.output_root),
|
||||
"SUBMISSION_ID": submission_id,
|
||||
}
|
||||
if bool(args.ray_no_wait):
|
||||
env_overrides["RAY_NO_WAIT"] = "1"
|
||||
|
||||
command_str = (
|
||||
"cd "
|
||||
+ shlex.quote(str(root))
|
||||
+ " && "
|
||||
+ " ".join(
|
||||
f"{key}={shlex.quote(str(value))}" for key, value in env_overrides.items()
|
||||
)
|
||||
+ " bash ./submit_ray_job.sh"
|
||||
)
|
||||
|
||||
payload = {
|
||||
"sweep_id": full_sweep_id,
|
||||
"selection_metric": str(args.metric),
|
||||
"constraints": {
|
||||
"min_margin": float(args.min_margin),
|
||||
"min_coi": float(args.min_coi),
|
||||
},
|
||||
"best_run": {
|
||||
"id": str(getattr(best_run, "id", "")),
|
||||
"name": str(getattr(best_run, "name", "")),
|
||||
"url": str(getattr(best_run, "url", "")),
|
||||
"score": float(best_score),
|
||||
"algo": str(best_cfg.get("algo", "")),
|
||||
"alpha": float(best_alpha),
|
||||
"eval_margin_mean": _as_float(
|
||||
getattr(best_run, "summary", {}).get("eval/margin_mean"), 0.0
|
||||
),
|
||||
"eval_coi_level_mean": _as_float(
|
||||
getattr(best_run, "summary", {}).get("eval/coi_level_mean"), 0.0
|
||||
),
|
||||
},
|
||||
"benchmark_compare_command": command_str,
|
||||
}
|
||||
print(json.dumps(payload, indent=2))
|
||||
|
||||
output_json = str(args.output_json).strip()
|
||||
if output_json:
|
||||
out_path = Path(output_json)
|
||||
if not out_path.is_absolute():
|
||||
out_path = root / out_path
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
out_path.write_text(json.dumps(payload, indent=2) + "\n")
|
||||
|
||||
if bool(args.submit):
|
||||
run_env = dict(os.environ)
|
||||
run_env.update({key: str(value) for key, value in env_overrides.items()})
|
||||
subprocess.run(
|
||||
["bash", "./submit_ray_job.sh"],
|
||||
cwd=str(root),
|
||||
env=run_env,
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
313
scripts/wandb_create_sweep.py
Normal file
313
scripts/wandb_create_sweep.py
Normal file
@@ -0,0 +1,313 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _base_sweep(method: str, metric_name: str) -> dict[str, Any]:
|
||||
return {
|
||||
"method": str(method),
|
||||
"metric": {"name": str(metric_name), "goal": "maximize"},
|
||||
}
|
||||
|
||||
|
||||
def _benchmark_sweep(method: str) -> dict[str, Any]:
|
||||
cfg = _base_sweep(method=method, metric_name="objective/score")
|
||||
cfg["name"] = "benchmark-all-algos-defense"
|
||||
cfg["parameters"] = {
|
||||
"tiers": {
|
||||
"values": [
|
||||
"static",
|
||||
"surge",
|
||||
"linear",
|
||||
"qtable",
|
||||
"ppo",
|
||||
"a2c",
|
||||
"dqn",
|
||||
]
|
||||
},
|
||||
"alpha_values": {"values": ["0.0", "0.1", "0.25", "0.4", "0.6", "0.8"]},
|
||||
"baseline_mode": {"values": [False, True]},
|
||||
"seed": {"values": [42, 1337, 2026, 7777]},
|
||||
"episodes": {"values": [8, 12]},
|
||||
"total_timesteps": {"values": [15000, 30000, 50000]},
|
||||
"lambda_coi": {"values": [0.1, 0.2, 0.4]},
|
||||
"ambiguity_radius": {"values": [0.1, 0.2, 0.3]},
|
||||
"ambiguity_points": {"values": [5, 7]},
|
||||
"ambiguity_rollouts": {"values": [1, 2]},
|
||||
"eta_ux": {"values": [0.25, 0.5, 0.75]},
|
||||
"reward_profit_weight": {"values": [0.75, 1.0, 1.25]},
|
||||
"learning_rate": {"values": [1e-4, 3e-4, 1e-3]},
|
||||
"batch_size": {"values": [128, 256, 512]},
|
||||
"n_steps": {"values": [1024, 2048, 4096]},
|
||||
"device": {"value": "cpu"},
|
||||
}
|
||||
return cfg
|
||||
|
||||
|
||||
def _train_sweep(method: str) -> dict[str, Any]:
|
||||
cfg = _base_sweep(method=method, metric_name="objective/score")
|
||||
cfg["name"] = "train-all-algos-defense"
|
||||
cfg["parameters"] = {
|
||||
"algo": {"values": ["qtable", "ppo", "a2c", "dqn"]},
|
||||
"alpha": {"values": [0.0, 0.1, 0.25, 0.4, 0.6]},
|
||||
"baseline_mode": {"values": [False, True]},
|
||||
"seed": {"values": [42, 1337, 2026, 7777]},
|
||||
"total_timesteps": {"values": [30000, 50000, 80000]},
|
||||
"learning_rate": {"values": [1e-4, 3e-4, 1e-3]},
|
||||
"batch_size": {"values": [128, 256, 512]},
|
||||
"n_steps": {"values": [1024, 2048, 4096]},
|
||||
"lambda_coi": {"values": [0.1, 0.2, 0.4]},
|
||||
"ambiguity_radius": {"values": [0.1, 0.2, 0.3]},
|
||||
"ambiguity_points": {"values": [3, 5, 7]},
|
||||
"ambiguity_rollouts": {"values": [1, 2]},
|
||||
"eta_ux": {"values": [0.25, 0.5, 0.75]},
|
||||
"reward_profit_weight": {"values": [0.75, 1.0, 1.25]},
|
||||
"N": {"values": [80, 100, 140]},
|
||||
"max_steps": {"values": [80, 100, 120]},
|
||||
"action_levels": {"values": [7, 9, 11]},
|
||||
"device": {"value": "cpu"},
|
||||
}
|
||||
return cfg
|
||||
|
||||
|
||||
def _train_robust_revenue_sweep(method: str) -> dict[str, Any]:
|
||||
cfg = _base_sweep(method=method, metric_name="eval/stress_revenue_worst")
|
||||
cfg["name"] = "train-defense-revenue-search"
|
||||
cfg["parameters"] = {
|
||||
"algo": {"values": ["qtable", "ppo", "a2c", "dqn"]},
|
||||
"alpha": {"values": [0.4, 0.6, 0.8]},
|
||||
"baseline_mode": {"value": False},
|
||||
"seed": {"values": [42, 1337, 2026, 7777]},
|
||||
"total_timesteps": {"values": [60_000, 80_000, 120_000]},
|
||||
"learning_rate": {"values": [1e-4, 3e-4, 1e-3]},
|
||||
"batch_size": {"values": [128, 256, 512]},
|
||||
"n_steps": {"values": [1024, 2048, 4096]},
|
||||
"lambda_coi": {"values": [0.2, 0.4, 0.6]},
|
||||
"ambiguity_radius": {"values": [0.1, 0.2, 0.3]},
|
||||
"ambiguity_points": {"values": [5, 7, 9]},
|
||||
"ambiguity_rollouts": {"values": [1, 2]},
|
||||
"eta_ux": {"values": [0.25, 0.5, 0.75]},
|
||||
"reward_profit_weight": {"values": [1.0, 1.25]},
|
||||
"N": {"values": [80, 100, 140]},
|
||||
"max_steps": {"values": [80, 100, 120]},
|
||||
"action_levels": {"values": [7, 9, 11]},
|
||||
"margin_floor": {"value": 0.85},
|
||||
"device": {"value": "cpu"},
|
||||
}
|
||||
return cfg
|
||||
|
||||
|
||||
def _ppo_calibration_sweep(method: str) -> dict[str, Any]:
|
||||
cfg = _base_sweep(method=method, metric_name="objective/score")
|
||||
cfg["name"] = "benchmark-ppo-calibration"
|
||||
cfg["parameters"] = {
|
||||
"tiers": {"value": "ppo"},
|
||||
"alpha_values": {"values": ["0.0", "0.1", "0.25", "0.4", "0.6", "0.8"]},
|
||||
"baseline_mode": {"values": [False, True]},
|
||||
"seed": {"values": [42, 1337, 2026, 7777]},
|
||||
"episodes": {"value": 12},
|
||||
"total_timesteps": {"value": 60000},
|
||||
"lambda_coi": {
|
||||
"distribution": "uniform",
|
||||
"min": 0.05,
|
||||
"max": 0.6,
|
||||
},
|
||||
"ambiguity_radius": {
|
||||
"distribution": "uniform",
|
||||
"min": 0.05,
|
||||
"max": 0.45,
|
||||
},
|
||||
"ambiguity_points": {"value": 7},
|
||||
"ambiguity_rollouts": {"value": 1},
|
||||
"eta_ux": {"value": 0.5},
|
||||
"reward_profit_weight": {"value": 1.0},
|
||||
"learning_rate": {
|
||||
"distribution": "log_uniform_values",
|
||||
"min": 1e-4,
|
||||
"max": 1e-3,
|
||||
},
|
||||
"batch_size": {"values": [128, 256, 512]},
|
||||
"n_steps": {"values": [1024, 2048, 4096]},
|
||||
"device": {"value": "cpu"},
|
||||
}
|
||||
return cfg
|
||||
|
||||
|
||||
def _ppo_block_a_sweep(method: str) -> dict[str, Any]:
|
||||
cfg = _base_sweep(method=method, metric_name="objective/score")
|
||||
cfg["name"] = "benchmark-ppo-block-a-calibration"
|
||||
cfg["parameters"] = {
|
||||
"tiers": {"value": "ppo"},
|
||||
"alpha_values": {"value": "0.25,0.6,0.8"},
|
||||
"seed": {"values": [42, 1337, 2026]},
|
||||
"episodes": {"value": 12},
|
||||
"total_timesteps": {"value": 80000},
|
||||
"lambda_coi": {"values": [0.05, 0.1, 0.2]},
|
||||
"ambiguity_radius": {"values": [0.05, 0.1, 0.2]},
|
||||
"ambiguity_points": {"value": 7},
|
||||
"ambiguity_rollouts": {"value": 1},
|
||||
"eta_ux": {"value": 0.5},
|
||||
"reward_profit_weight": {"value": 1.0},
|
||||
"learning_rate": {"value": 3e-4},
|
||||
"batch_size": {"value": 256},
|
||||
"n_steps": {"value": 2048},
|
||||
"device": {"value": "cpu"},
|
||||
}
|
||||
return cfg
|
||||
|
||||
|
||||
def _ppo_shift_screen_sweep(method: str) -> dict[str, Any]:
|
||||
cfg = _base_sweep(method=method, metric_name="objective/score")
|
||||
cfg["name"] = "benchmark-ppo-shift-screen"
|
||||
cfg["parameters"] = {
|
||||
"tiers": {"value": "ppo"},
|
||||
"alpha_values": {"value": "0.25"},
|
||||
"eval_alpha_values": {"value": "0.6,0.8"},
|
||||
"seed": {"values": [42, 1337, 2026]},
|
||||
"episodes": {"value": 20},
|
||||
"total_timesteps": {"value": 80000},
|
||||
"lambda_coi": {"values": [0.0, 0.02, 0.05, 0.1]},
|
||||
"ambiguity_radius": {"values": [0.0, 0.02, 0.05, 0.1]},
|
||||
"ambiguity_points": {"value": 5},
|
||||
"ambiguity_rollouts": {"value": 1},
|
||||
"eta_ux": {"value": 0.0},
|
||||
"reward_profit_weight": {"value": 1.0},
|
||||
"learning_rate": {"value": 3e-4},
|
||||
"batch_size": {"value": 256},
|
||||
"n_steps": {"value": 2048},
|
||||
"device": {"value": "cpu"},
|
||||
}
|
||||
return cfg
|
||||
|
||||
|
||||
def _ppo_rl_study_sweep(method: str) -> dict[str, Any]:
|
||||
cfg = _base_sweep(method=method, metric_name="eval/stress_revenue_worst")
|
||||
cfg["name"] = "train-ppo-standard-vs-defended-equilibrium"
|
||||
cfg["parameters"] = {
|
||||
"algo": {"value": "ppo"},
|
||||
"seed": {"values": [42, 1337, 7777]},
|
||||
"alpha": {"values": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]},
|
||||
"n_products": {"values": [5, 25, 50, 100]},
|
||||
"N": {"value": 100},
|
||||
"no_robust": {"values": [False, True]},
|
||||
"lambda_coi": {"values": [0.05, 0.15, 0.3]},
|
||||
"ambiguity_radius": {"values": [0.1, 0.2, 0.3]},
|
||||
"ambiguity_points": {"value": 7},
|
||||
"ambiguity_rollouts": {"value": 1},
|
||||
"eta_ux": {"value": 0.0},
|
||||
"reward_profit_weight": {"value": 1.0},
|
||||
"total_timesteps": {"value": 100000},
|
||||
"eval_episodes": {"value": 10},
|
||||
"eval_freq": {"value": 1000},
|
||||
"log_freq": {"value": 100},
|
||||
"hist_freq": {"value": 500},
|
||||
"learning_rate": {"value": 3e-4},
|
||||
"batch_size": {"value": 256},
|
||||
"n_steps": {"value": 2048},
|
||||
"device": {"value": "cpu"},
|
||||
}
|
||||
return cfg
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Create W&B sweep for PHANTOM")
|
||||
parser.add_argument(
|
||||
"--kind",
|
||||
choices=[
|
||||
"benchmark",
|
||||
"train",
|
||||
"ppo_calibration",
|
||||
"ppo_block_a",
|
||||
"ppo_shift_screen",
|
||||
"ppo_rl_study",
|
||||
],
|
||||
default="benchmark",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
choices=["default", "robust_revenue"],
|
||||
default="default",
|
||||
)
|
||||
parser.add_argument("--project", required=True)
|
||||
parser.add_argument("--entity", default="")
|
||||
parser.add_argument(
|
||||
"--method", choices=["random", "bayes", "grid"], default="random"
|
||||
)
|
||||
parser.add_argument("--run-cap", type=int, default=0)
|
||||
parser.add_argument("--json", action="store_true")
|
||||
parser.add_argument("--full-id", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
cwd = str(Path.cwd())
|
||||
sys.path = [p for p in sys.path if p not in {"", cwd}]
|
||||
|
||||
try:
|
||||
import wandb
|
||||
except ImportError as exc:
|
||||
raise ImportError("wandb is required to create sweeps") from exc
|
||||
|
||||
if str(args.kind) == "benchmark":
|
||||
if str(args.profile) != "default":
|
||||
raise ValueError("benchmark sweeps only support --profile default")
|
||||
sweep_cfg = _benchmark_sweep(args.method)
|
||||
elif str(args.kind) == "train":
|
||||
if str(args.profile) == "robust_revenue":
|
||||
sweep_cfg = _train_robust_revenue_sweep(args.method)
|
||||
else:
|
||||
sweep_cfg = _train_sweep(args.method)
|
||||
elif str(args.kind) == "ppo_calibration":
|
||||
if str(args.profile) != "default":
|
||||
raise ValueError("ppo_calibration sweeps only support --profile default")
|
||||
sweep_cfg = _ppo_calibration_sweep(args.method)
|
||||
elif str(args.kind) == "ppo_block_a":
|
||||
if str(args.profile) != "default":
|
||||
raise ValueError("ppo_block_a sweeps only support --profile default")
|
||||
sweep_cfg = _ppo_block_a_sweep(args.method)
|
||||
elif str(args.kind) == "ppo_shift_screen":
|
||||
if str(args.profile) != "default":
|
||||
raise ValueError("ppo_shift_screen sweeps only support --profile default")
|
||||
sweep_cfg = _ppo_shift_screen_sweep(args.method)
|
||||
else:
|
||||
if str(args.profile) != "default":
|
||||
raise ValueError("ppo_rl_study sweeps only support --profile default")
|
||||
sweep_cfg = _ppo_rl_study_sweep(args.method)
|
||||
if int(args.run_cap) > 0:
|
||||
sweep_cfg["run_cap"] = int(args.run_cap)
|
||||
|
||||
with contextlib.redirect_stdout(io.StringIO()):
|
||||
sweep_id = wandb.sweep(
|
||||
sweep=sweep_cfg,
|
||||
project=str(args.project),
|
||||
entity=str(args.entity) if str(args.entity).strip() else None,
|
||||
)
|
||||
full_id = (
|
||||
f"{args.entity}/{args.project}/{sweep_id}"
|
||||
if str(args.entity).strip()
|
||||
else f"{args.project}/{sweep_id}"
|
||||
)
|
||||
|
||||
if bool(args.json):
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"kind": str(args.kind),
|
||||
"profile": str(args.profile),
|
||||
"project": str(args.project),
|
||||
"entity": str(args.entity),
|
||||
"sweep_id": str(sweep_id),
|
||||
"full_id": str(full_id),
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
print(full_id if bool(args.full_id) else sweep_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
342
scripts/whoclicked_card.py
Normal file
342
scripts/whoclicked_card.py
Normal file
@@ -0,0 +1,342 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Build and upload a Hugging Face dataset card for whoclickedit."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
DEFAULT_INPUT = PROJECT_ROOT / "experiments" / "exports" / "whoclicked.csv"
|
||||
DEFAULT_OUTPUT = PROJECT_ROOT / "experiments" / "exports" / "whoclicked_dataset_card.md"
|
||||
DEFAULT_REPO = os.getenv("HF_WHOCLICKED_REPO", "velocitatem/whoclickedit")
|
||||
|
||||
|
||||
def _token() -> str | None:
|
||||
return os.getenv("HF_TOKEN") or None
|
||||
|
||||
|
||||
def _exception_details(exc: Exception) -> str:
|
||||
parts = [str(exc).strip()]
|
||||
response = getattr(exc, "response", None)
|
||||
if response is not None:
|
||||
status = getattr(response, "status_code", None)
|
||||
if status is not None:
|
||||
parts.append(f"HTTP {status}")
|
||||
text = getattr(response, "text", "")
|
||||
if text:
|
||||
parts.append(text.strip()[:500])
|
||||
return " | ".join(p for p in parts if p)
|
||||
|
||||
|
||||
def _size_category(n_rows: int) -> str:
|
||||
if n_rows < 1_000:
|
||||
return "n<1K"
|
||||
if n_rows < 10_000:
|
||||
return "1K<n<10K"
|
||||
if n_rows < 100_000:
|
||||
return "10K<n<100K"
|
||||
if n_rows < 1_000_000:
|
||||
return "100K<n<1M"
|
||||
return "1M<n"
|
||||
|
||||
|
||||
def _series_count(df: pd.DataFrame, col: str) -> dict[str, int]:
|
||||
if col not in df.columns:
|
||||
return {}
|
||||
vc = df[col].fillna("<null>").astype(str).value_counts(dropna=False)
|
||||
return {k: int(v) for k, v in vc.items()}
|
||||
|
||||
|
||||
def _group_count(df: pd.DataFrame, left: str, right: str) -> dict[tuple[str, str], int]:
|
||||
if left not in df.columns or right not in df.columns:
|
||||
return {}
|
||||
grouped = (
|
||||
df.groupby([left, right], dropna=False)
|
||||
.size()
|
||||
.reset_index(name="count")
|
||||
.sort_values([left, right])
|
||||
)
|
||||
out: dict[tuple[str, str], int] = {}
|
||||
for _, row in grouped.iterrows():
|
||||
out[(str(row[left]), str(row[right]))] = int(row["count"])
|
||||
return out
|
||||
|
||||
|
||||
def _session_count_by_actor(df: pd.DataFrame) -> dict[str, int]:
|
||||
if "actor_type" not in df.columns or "sessionId" not in df.columns:
|
||||
return {}
|
||||
grouped = (
|
||||
df[["actor_type", "sessionId"]]
|
||||
.dropna(subset=["sessionId"])
|
||||
.drop_duplicates()
|
||||
.groupby("actor_type")
|
||||
.size()
|
||||
)
|
||||
return {str(k): int(v) for k, v in grouped.items()}
|
||||
|
||||
|
||||
def _time_range(df: pd.DataFrame) -> tuple[str, str]:
|
||||
if "ts" not in df.columns:
|
||||
return "unknown", "unknown"
|
||||
ts = pd.to_datetime(df["ts"], errors="coerce", utc=True)
|
||||
ts = ts.dropna()
|
||||
if ts.empty:
|
||||
return "unknown", "unknown"
|
||||
return ts.min().isoformat(), ts.max().isoformat()
|
||||
|
||||
|
||||
def _render_card(df: pd.DataFrame) -> str:
|
||||
total_rows = len(df)
|
||||
total_cols = len(df.columns)
|
||||
size_cat = _size_category(total_rows)
|
||||
|
||||
actor_counts = _series_count(df, "actor_type")
|
||||
record_counts = _series_count(df, "record_type")
|
||||
by_actor_record = _group_count(df, "actor_type", "record_type")
|
||||
store_counts = _series_count(df, "storeMode")
|
||||
session_counts = _session_count_by_actor(df)
|
||||
t_min, t_max = _time_range(df)
|
||||
|
||||
event_counts: dict[str, int] = {}
|
||||
if "record_type" in df.columns and "eventName" in df.columns:
|
||||
interactions = df[df["record_type"] == "interaction"]
|
||||
event_counts = _series_count(interactions, "eventName")
|
||||
|
||||
metadata_cols = sorted(c for c in df.columns if c.startswith("metadata_"))
|
||||
|
||||
actor_lines = (
|
||||
"\n".join(f"- `{k}`: {v}" for k, v in actor_counts.items()) or "- none"
|
||||
)
|
||||
record_lines = (
|
||||
"\n".join(f"- `{k}`: {v}" for k, v in record_counts.items()) or "- none"
|
||||
)
|
||||
pair_lines = (
|
||||
"\n".join(
|
||||
f"- `{a}` / `{r}`: {n}"
|
||||
for (a, r), n in sorted(
|
||||
by_actor_record.items(), key=lambda x: (x[0][0], x[0][1])
|
||||
)
|
||||
)
|
||||
or "- none"
|
||||
)
|
||||
store_lines = (
|
||||
"\n".join(f"- `{k}`: {v}" for k, v in store_counts.items()) or "- none"
|
||||
)
|
||||
session_lines = (
|
||||
"\n".join(f"- `{k}`: {v}" for k, v in session_counts.items()) or "- none"
|
||||
)
|
||||
top_events = list(event_counts.items())[:10]
|
||||
event_lines = "\n".join(f"- `{k}`: {v}" for k, v in top_events) or "- none"
|
||||
metadata_lines = "\n".join(f"- `{c}`" for c in metadata_cols) or "- none"
|
||||
|
||||
return f"""---
|
||||
pretty_name: whoclickedit
|
||||
license: mit
|
||||
language:
|
||||
- en
|
||||
task_categories:
|
||||
- tabular-classification
|
||||
task_ids:
|
||||
- tabular-multi-class-classification
|
||||
tags:
|
||||
- e-commerce
|
||||
- dynamic-pricing
|
||||
- behavioral-telemetry
|
||||
- human-vs-agent
|
||||
- session-data
|
||||
size_categories:
|
||||
- {size_cat}
|
||||
---
|
||||
|
||||
# Dataset Card for whoclickedit
|
||||
|
||||
## Dataset Summary
|
||||
whoclickedit is an event-level behavioral dataset for human versus agent interaction analysis in dynamic pricing experiments.
|
||||
It merges interaction logs and price quote logs into one flat CSV (`whoclicked.csv`) with explicit labels for actor type.
|
||||
|
||||
## Dataset Snapshot
|
||||
- Rows: `{total_rows}`
|
||||
- Columns: `{total_cols}`
|
||||
- Time range (UTC): `{t_min}` to `{t_max}`
|
||||
- Unique sessions by actor:
|
||||
{session_lines}
|
||||
- Rows by actor:
|
||||
{actor_lines}
|
||||
- Rows by record type:
|
||||
{record_lines}
|
||||
- Rows by actor x record type:
|
||||
{pair_lines}
|
||||
- Store modes:
|
||||
{store_lines}
|
||||
|
||||
## Source and Processing
|
||||
Data is collected from two local roots in the PHANTOM project:
|
||||
- `experiments/collected_data` (human sessions)
|
||||
- `experiments/agents/collected_data` (agent sessions)
|
||||
|
||||
Each session folder contains:
|
||||
- `int.json` (interaction events)
|
||||
- `price.json` (price quote logs)
|
||||
|
||||
The ETL does the following:
|
||||
- Normalizes both Kafka-envelope and flat payload formats
|
||||
- Flattens nested metadata fields into `metadata_*` columns
|
||||
- Preserves all raw rows (no deduplication)
|
||||
- Adds labels:
|
||||
- `actor_type` in `{{human, agent}}`
|
||||
- `is_agent` in `{{0, 1}}`
|
||||
- `record_type` in `{{interaction, price_log}}`
|
||||
|
||||
## Data Fields
|
||||
Core fields used for modeling:
|
||||
- `actor_type`, `is_agent`, `record_type`
|
||||
- `sessionId`, `experimentId`, `storeMode`, `ts`
|
||||
- `eventName`, `page`, `productId`, `price`, `userAgent`
|
||||
|
||||
Kafka provenance fields:
|
||||
- `kafka_partition_id`, `kafka_offset`, `kafka_timestamp_ms`, `kafka_compression`
|
||||
- `kafka_is_transactional`, `kafka_headers`, `kafka_key_*`, `kafka_value_*`
|
||||
|
||||
Flattened metadata fields currently present:
|
||||
{metadata_lines}
|
||||
|
||||
Top interaction events:
|
||||
{event_lines}
|
||||
|
||||
## Intended Uses
|
||||
- Human-vs-agent traffic classification
|
||||
- Session-level behavioral modeling
|
||||
- Dynamic pricing robustness analysis under agent-mediated reconnaissance
|
||||
|
||||
## Out-of-Scope Uses
|
||||
- Identity inference or user-level profiling
|
||||
- Credit, employment, insurance, or legal decision making
|
||||
|
||||
## Data Splits
|
||||
No official train/validation/test split is provided in the current release.
|
||||
Users should create time-aware or session-aware splits to avoid leakage.
|
||||
|
||||
## Privacy and Sensitive Content
|
||||
- `userAgent` and referrer metadata can be quasi-identifying in small samples.
|
||||
- Use care before publishing derived artifacts that can re-identify participants.
|
||||
|
||||
## Limitations
|
||||
- Data is generated in a controlled experiment platform, not a full production marketplace.
|
||||
- Agent traffic currently reflects the configured tasking and browser automation setup.
|
||||
- Coverage is stronger for `hotel` than `airline` in the current release.
|
||||
|
||||
## Citation
|
||||
If you use this dataset, cite the PHANTOM thesis project and link this dataset page.
|
||||
"""
|
||||
|
||||
|
||||
def build_card(input_csv: Path, output_md: Path) -> None:
|
||||
if not input_csv.exists():
|
||||
raise FileNotFoundError(f"Input CSV not found: {input_csv}")
|
||||
df = pd.read_csv(input_csv)
|
||||
card = _render_card(df)
|
||||
output_md.parent.mkdir(parents=True, exist_ok=True)
|
||||
output_md.write_text(card)
|
||||
print(f"wrote dataset card to {output_md}")
|
||||
|
||||
|
||||
def upload_card(
|
||||
card_path: Path, repo_id: str, path_in_repo: str, commit_message: str
|
||||
) -> None:
|
||||
if not card_path.exists():
|
||||
raise FileNotFoundError(f"Card file not found: {card_path}")
|
||||
|
||||
api = HfApi(token=_token())
|
||||
try:
|
||||
me = api.whoami(token=_token())
|
||||
except Exception as exc:
|
||||
detail = _exception_details(exc)
|
||||
raise RuntimeError(f"Hugging Face auth failed. Details: {detail}") from exc
|
||||
|
||||
user_name = me.get("name") or me.get("fullname") or "unknown"
|
||||
print(f"authenticated to HF as: {user_name}")
|
||||
|
||||
try:
|
||||
api.repo_info(repo_id=repo_id, repo_type="dataset")
|
||||
except Exception as exc:
|
||||
detail = _exception_details(exc)
|
||||
raise RuntimeError(
|
||||
f"Dataset repo '{repo_id}' is not accessible. Details: {detail}"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
commit = api.upload_file(
|
||||
path_or_fileobj=str(card_path),
|
||||
path_in_repo=path_in_repo,
|
||||
repo_id=repo_id,
|
||||
repo_type="dataset",
|
||||
commit_message=commit_message,
|
||||
)
|
||||
except Exception as exc:
|
||||
detail = _exception_details(exc)
|
||||
raise RuntimeError(
|
||||
f"Card upload failed for '{repo_id}'. Details: {detail}"
|
||||
) from exc
|
||||
|
||||
print(f"uploaded dataset card to https://huggingface.co/datasets/{repo_id}")
|
||||
print(f"commit: {commit}")
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Build or upload whoclickedit dataset card"
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
build = sub.add_parser("build", help="build card markdown from CSV")
|
||||
build.add_argument("--input", type=Path, default=DEFAULT_INPUT)
|
||||
build.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
|
||||
|
||||
upload = sub.add_parser("upload", help="upload existing card as dataset README.md")
|
||||
upload.add_argument("--input", type=Path, default=DEFAULT_OUTPUT)
|
||||
upload.add_argument("--repo", default=DEFAULT_REPO)
|
||||
upload.add_argument("--path-in-repo", default="README.md")
|
||||
upload.add_argument("--message", default="Add dataset card for whoclickedit")
|
||||
|
||||
both = sub.add_parser("build-upload", help="build card and upload to dataset repo")
|
||||
both.add_argument("--csv", type=Path, default=DEFAULT_INPUT)
|
||||
both.add_argument("--card", type=Path, default=DEFAULT_OUTPUT)
|
||||
both.add_argument("--repo", default=DEFAULT_REPO)
|
||||
both.add_argument("--path-in-repo", default="README.md")
|
||||
both.add_argument("--message", default="Add dataset card for whoclickedit")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = _parse_args()
|
||||
try:
|
||||
if args.command == "build":
|
||||
build_card(args.input, args.output)
|
||||
return 0
|
||||
|
||||
if args.command == "upload":
|
||||
upload_card(args.input, args.repo, args.path_in_repo, args.message)
|
||||
return 0
|
||||
|
||||
if args.command == "build-upload":
|
||||
build_card(args.csv, args.card)
|
||||
upload_card(args.card, args.repo, args.path_in_repo, args.message)
|
||||
return 0
|
||||
|
||||
raise ValueError(f"Unknown command: {args.command}")
|
||||
except Exception as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
412
scripts/whoclicked_etl.py
Normal file
412
scripts/whoclicked_etl.py
Normal file
@@ -0,0 +1,412 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Build and upload a flattened who-clicked dataset from local collected_data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parent.parent
|
||||
DEFAULT_HUMAN_DIR = PROJECT_ROOT / "experiments" / "collected_data"
|
||||
DEFAULT_AGENT_DIR = PROJECT_ROOT / "experiments" / "agents" / "collected_data"
|
||||
DEFAULT_OUTPUT = PROJECT_ROOT / "experiments" / "exports" / "whoclicked.csv"
|
||||
DEFAULT_REPO = os.getenv("HF_WHOCLICKED_REPO", "velocitatem/whoclickedit")
|
||||
|
||||
BASE_COLUMNS = [
|
||||
"actor_type",
|
||||
"is_agent",
|
||||
"record_type",
|
||||
"topic",
|
||||
"source_session_dir",
|
||||
"source_file",
|
||||
"source_row_index",
|
||||
"ingest_format",
|
||||
"sessionId",
|
||||
"experimentId",
|
||||
"storeMode",
|
||||
"ts",
|
||||
"eventName",
|
||||
"page",
|
||||
"productId",
|
||||
"price",
|
||||
"userAgent",
|
||||
"kafka_partition_id",
|
||||
"kafka_offset",
|
||||
"kafka_timestamp_ms",
|
||||
"kafka_compression",
|
||||
"kafka_is_transactional",
|
||||
"kafka_headers",
|
||||
"kafka_key_payload",
|
||||
"kafka_key_encoding",
|
||||
"kafka_key_schema_id",
|
||||
"kafka_value_encoding",
|
||||
"kafka_value_schema_id",
|
||||
"kafka_value_size",
|
||||
]
|
||||
|
||||
|
||||
def _token() -> str | None:
|
||||
return os.getenv("HF_TOKEN") or None
|
||||
|
||||
|
||||
def _exception_details(exc: Exception) -> str:
|
||||
parts = [str(exc).strip()]
|
||||
response = getattr(exc, "response", None)
|
||||
if response is not None:
|
||||
status = getattr(response, "status_code", None)
|
||||
if status is not None:
|
||||
parts.append(f"HTTP {status}")
|
||||
text = getattr(response, "text", "")
|
||||
if text:
|
||||
text = text.strip()
|
||||
if text:
|
||||
parts.append(text[:500])
|
||||
return " | ".join(p for p in parts if p)
|
||||
|
||||
|
||||
def _flatten_dict(data: dict[str, Any], prefix: str = "") -> dict[str, Any]:
|
||||
flat: dict[str, Any] = {}
|
||||
for key, value in data.items():
|
||||
normalized_key = str(key).strip().replace(" ", "_")
|
||||
next_key = f"{prefix}_{normalized_key}" if prefix else normalized_key
|
||||
if isinstance(value, dict):
|
||||
flat.update(_flatten_dict(value, next_key))
|
||||
else:
|
||||
flat[next_key] = value
|
||||
return flat
|
||||
|
||||
|
||||
def _as_scalar(value: Any) -> Any:
|
||||
if isinstance(value, (dict, list, tuple)):
|
||||
return json.dumps(value, ensure_ascii=True, sort_keys=True)
|
||||
return value
|
||||
|
||||
|
||||
def _empty_envelope() -> dict[str, Any]:
|
||||
return {
|
||||
"kafka_partition_id": None,
|
||||
"kafka_offset": None,
|
||||
"kafka_timestamp_ms": None,
|
||||
"kafka_compression": None,
|
||||
"kafka_is_transactional": None,
|
||||
"kafka_headers": None,
|
||||
"kafka_key_payload": None,
|
||||
"kafka_key_encoding": None,
|
||||
"kafka_key_schema_id": None,
|
||||
"kafka_value_encoding": None,
|
||||
"kafka_value_schema_id": None,
|
||||
"kafka_value_size": None,
|
||||
}
|
||||
|
||||
|
||||
def _extract_payload_and_envelope(
|
||||
record: Any,
|
||||
) -> tuple[dict[str, Any], dict[str, Any], str]:
|
||||
if (
|
||||
isinstance(record, dict)
|
||||
and isinstance(record.get("value"), dict)
|
||||
and isinstance(record["value"].get("payload"), dict)
|
||||
):
|
||||
key = record.get("key") if isinstance(record.get("key"), dict) else {}
|
||||
value = record["value"]
|
||||
envelope = {
|
||||
"kafka_partition_id": record.get("partitionID"),
|
||||
"kafka_offset": record.get("offset"),
|
||||
"kafka_timestamp_ms": record.get("timestamp"),
|
||||
"kafka_compression": record.get("compression"),
|
||||
"kafka_is_transactional": record.get("isTransactional"),
|
||||
"kafka_headers": _as_scalar(record.get("headers")),
|
||||
"kafka_key_payload": key.get("payload"),
|
||||
"kafka_key_encoding": key.get("encoding"),
|
||||
"kafka_key_schema_id": key.get("schemaId"),
|
||||
"kafka_value_encoding": value.get("encoding"),
|
||||
"kafka_value_schema_id": value.get("schemaId"),
|
||||
"kafka_value_size": value.get("size"),
|
||||
}
|
||||
return dict(value["payload"]), envelope, "kafka_envelope"
|
||||
|
||||
if isinstance(record, dict):
|
||||
return dict(record), _empty_envelope(), "flat_payload"
|
||||
|
||||
return {}, _empty_envelope(), "unknown"
|
||||
|
||||
|
||||
def _load_json_list(path: Path) -> list[Any]:
|
||||
raw = json.loads(path.read_text())
|
||||
if not isinstance(raw, list):
|
||||
raise ValueError(f"Expected list in {path}, got {type(raw).__name__}")
|
||||
return raw
|
||||
|
||||
|
||||
def _normalize_file_rows(
|
||||
actor_type: str,
|
||||
is_agent: int,
|
||||
session_dir_name: str,
|
||||
source_file: str,
|
||||
records: list[Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
record_type = "interaction" if source_file == "int.json" else "price_log"
|
||||
topic = "user-interactions" if record_type == "interaction" else "price-logs"
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for idx, raw_record in enumerate(records):
|
||||
payload, envelope, ingest_format = _extract_payload_and_envelope(raw_record)
|
||||
metadata = payload.pop("metadata", None)
|
||||
|
||||
payload_flat = _flatten_dict(payload)
|
||||
row: dict[str, Any] = {
|
||||
"actor_type": actor_type,
|
||||
"is_agent": is_agent,
|
||||
"record_type": record_type,
|
||||
"topic": topic,
|
||||
"source_session_dir": session_dir_name,
|
||||
"source_file": source_file,
|
||||
"source_row_index": idx,
|
||||
"ingest_format": ingest_format,
|
||||
**envelope,
|
||||
}
|
||||
row.update({k: _as_scalar(v) for k, v in payload_flat.items()})
|
||||
|
||||
if isinstance(metadata, dict):
|
||||
metadata_flat = _flatten_dict(metadata, "metadata")
|
||||
row.update({k: _as_scalar(v) for k, v in metadata_flat.items()})
|
||||
elif metadata is not None:
|
||||
row["metadata_raw"] = _as_scalar(metadata)
|
||||
|
||||
rows.append(row)
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _collect_rows_for_actor(
|
||||
actor_type: str, is_agent: int, base_dir: Path
|
||||
) -> list[dict[str, Any]]:
|
||||
if not base_dir.exists():
|
||||
raise FileNotFoundError(f"Directory not found: {base_dir}")
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for session_dir in sorted(
|
||||
(p for p in base_dir.iterdir() if p.is_dir()), key=lambda p: p.name
|
||||
):
|
||||
for source_file in ("int.json", "price.json"):
|
||||
file_path = session_dir / source_file
|
||||
if not file_path.exists():
|
||||
continue
|
||||
records = _load_json_list(file_path)
|
||||
rows.extend(
|
||||
_normalize_file_rows(
|
||||
actor_type=actor_type,
|
||||
is_agent=is_agent,
|
||||
session_dir_name=session_dir.name,
|
||||
source_file=source_file,
|
||||
records=records,
|
||||
)
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
def build_dataframe(human_dir: Path, agent_dir: Path) -> pd.DataFrame:
|
||||
rows = [
|
||||
*_collect_rows_for_actor("human", 0, human_dir),
|
||||
*_collect_rows_for_actor("agent", 1, agent_dir),
|
||||
]
|
||||
if not rows:
|
||||
return pd.DataFrame(columns=BASE_COLUMNS)
|
||||
|
||||
df = pd.DataFrame(rows)
|
||||
ordered_columns = [
|
||||
*BASE_COLUMNS,
|
||||
*sorted(c for c in df.columns if c not in BASE_COLUMNS),
|
||||
]
|
||||
return df[ordered_columns]
|
||||
|
||||
|
||||
def _print_summary(df: pd.DataFrame, output_path: Path) -> None:
|
||||
print(f"wrote {len(df)} rows and {len(df.columns)} columns to {output_path}")
|
||||
if df.empty:
|
||||
return
|
||||
|
||||
print("rows by actor/record_type:")
|
||||
grouped = (
|
||||
df.groupby(["actor_type", "record_type"], dropna=False)
|
||||
.size()
|
||||
.reset_index(name="count")
|
||||
.sort_values(["actor_type", "record_type"])
|
||||
)
|
||||
for _, row in grouped.iterrows():
|
||||
print(f" - {row['actor_type']} / {row['record_type']}: {int(row['count'])}")
|
||||
|
||||
required = ["actor_type", "is_agent", "record_type", "sessionId", "ts"]
|
||||
missing = {col: int(df[col].isna().sum()) for col in required if col in df.columns}
|
||||
print(f"missing in required columns: {missing}")
|
||||
|
||||
|
||||
def build_csv(human_dir: Path, agent_dir: Path, output: Path) -> pd.DataFrame:
|
||||
df = build_dataframe(human_dir=human_dir, agent_dir=agent_dir)
|
||||
output.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_csv(output, index=False)
|
||||
_print_summary(df, output)
|
||||
return df
|
||||
|
||||
|
||||
def _resolve_repo_id(api: HfApi, repo_id: str) -> str:
|
||||
if "/" in repo_id:
|
||||
return repo_id
|
||||
try:
|
||||
me = api.whoami(token=_token())
|
||||
username = me.get("name")
|
||||
if username:
|
||||
return f"{username}/{repo_id}"
|
||||
except Exception:
|
||||
pass
|
||||
return repo_id
|
||||
|
||||
|
||||
def upload_csv(
|
||||
input_path: Path,
|
||||
repo_id: str,
|
||||
path_in_repo: str,
|
||||
commit_message: str,
|
||||
create_if_missing: bool = False,
|
||||
) -> None:
|
||||
if not input_path.exists():
|
||||
raise FileNotFoundError(f"Input CSV not found: {input_path}")
|
||||
|
||||
api = HfApi(token=_token())
|
||||
|
||||
try:
|
||||
me = api.whoami(token=_token())
|
||||
except Exception as exc:
|
||||
detail = _exception_details(exc)
|
||||
hint = "Set HF_TOKEN with write access or run huggingface-cli login."
|
||||
raise RuntimeError(
|
||||
f"Hugging Face auth failed. {hint} Details: {detail}"
|
||||
) from exc
|
||||
|
||||
user_name = me.get("name") or me.get("fullname") or "unknown"
|
||||
print(f"authenticated to HF as: {user_name}")
|
||||
|
||||
resolved_repo_id = _resolve_repo_id(api, repo_id)
|
||||
if create_if_missing:
|
||||
api.create_repo(repo_id=resolved_repo_id, repo_type="dataset", exist_ok=True)
|
||||
else:
|
||||
try:
|
||||
api.repo_info(repo_id=resolved_repo_id, repo_type="dataset")
|
||||
except Exception as exc:
|
||||
detail = _exception_details(exc)
|
||||
hint = (
|
||||
"Check owner/repo spelling, ensure it is a dataset repo, "
|
||||
"or pass --create-if-missing."
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Dataset repo '{resolved_repo_id}' is not accessible. {hint} Details: {detail}"
|
||||
) from exc
|
||||
|
||||
try:
|
||||
commit = api.upload_file(
|
||||
path_or_fileobj=str(input_path),
|
||||
path_in_repo=path_in_repo,
|
||||
repo_id=resolved_repo_id,
|
||||
repo_type="dataset",
|
||||
commit_message=commit_message,
|
||||
)
|
||||
except Exception as exc:
|
||||
detail = _exception_details(exc)
|
||||
hint = (
|
||||
"Pass --repo <owner>/whoclickedit and ensure HF_TOKEN is set "
|
||||
"(or run huggingface-cli login)."
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Upload failed for '{resolved_repo_id}'. {hint} Details: {detail}"
|
||||
) from exc
|
||||
|
||||
print(
|
||||
f"uploaded {input_path} to https://huggingface.co/datasets/{resolved_repo_id}"
|
||||
)
|
||||
print(f"commit: {commit}")
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="ETL for whoclickedit: flatten local collected_data and upload to HF"
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
build = sub.add_parser("build", help="build flattened CSV locally")
|
||||
build.add_argument("--human-dir", type=Path, default=DEFAULT_HUMAN_DIR)
|
||||
build.add_argument("--agent-dir", type=Path, default=DEFAULT_AGENT_DIR)
|
||||
build.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
|
||||
|
||||
upload = sub.add_parser("upload", help="upload an existing CSV to HF dataset")
|
||||
upload.add_argument("--input", type=Path, default=DEFAULT_OUTPUT)
|
||||
upload.add_argument("--repo", default=DEFAULT_REPO)
|
||||
upload.add_argument("--path-in-repo", default="whoclicked.csv")
|
||||
upload.add_argument("--message", default="Update flattened whoclickedit dataset")
|
||||
upload.add_argument("--create-if-missing", action="store_true")
|
||||
|
||||
build_upload = sub.add_parser(
|
||||
"build-upload", help="build CSV and upload to HF dataset"
|
||||
)
|
||||
build_upload.add_argument("--human-dir", type=Path, default=DEFAULT_HUMAN_DIR)
|
||||
build_upload.add_argument("--agent-dir", type=Path, default=DEFAULT_AGENT_DIR)
|
||||
build_upload.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)
|
||||
build_upload.add_argument("--repo", default=DEFAULT_REPO)
|
||||
build_upload.add_argument("--path-in-repo", default="whoclicked.csv")
|
||||
build_upload.add_argument(
|
||||
"--message", default="Update flattened whoclickedit dataset"
|
||||
)
|
||||
build_upload.add_argument("--create-if-missing", action="store_true")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = _parse_args()
|
||||
|
||||
try:
|
||||
if args.command == "build":
|
||||
build_csv(
|
||||
human_dir=args.human_dir, agent_dir=args.agent_dir, output=args.output
|
||||
)
|
||||
return 0
|
||||
|
||||
if args.command == "upload":
|
||||
upload_csv(
|
||||
input_path=args.input,
|
||||
repo_id=args.repo,
|
||||
path_in_repo=args.path_in_repo,
|
||||
commit_message=args.message,
|
||||
create_if_missing=args.create_if_missing,
|
||||
)
|
||||
return 0
|
||||
|
||||
if args.command == "build-upload":
|
||||
build_csv(
|
||||
human_dir=args.human_dir, agent_dir=args.agent_dir, output=args.output
|
||||
)
|
||||
upload_csv(
|
||||
input_path=args.output,
|
||||
repo_id=args.repo,
|
||||
path_in_repo=args.path_in_repo,
|
||||
commit_message=args.message,
|
||||
create_if_missing=args.create_if_missing,
|
||||
)
|
||||
return 0
|
||||
|
||||
raise ValueError(f"Unknown command: {args.command}")
|
||||
|
||||
except Exception as exc:
|
||||
print(f"error: {exc}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Reference in New Issue
Block a user