From 63f1aad0b9df87f33f35ba54b0321c50069924a9 Mon Sep 17 00:00:00 2001 From: Daniel Rosel Date: Mon, 16 Mar 2026 15:18:38 +0100 Subject: [PATCH] chore: including new scritps for automation --- scripts/launch_calibration_screen.sh | 38 +++ scripts/setuptpu.sh | 9 + scripts/wandb_compare_best.py | 333 ++++++++++++++++++++++ scripts/wandb_create_sweep.py | 313 ++++++++++++++++++++ scripts/whoclicked_card.py | 342 ++++++++++++++++++++++ scripts/whoclicked_etl.py | 412 +++++++++++++++++++++++++++ 6 files changed, 1447 insertions(+) create mode 100755 scripts/launch_calibration_screen.sh create mode 100644 scripts/setuptpu.sh create mode 100644 scripts/wandb_compare_best.py create mode 100644 scripts/wandb_create_sweep.py create mode 100644 scripts/whoclicked_card.py create mode 100644 scripts/whoclicked_etl.py diff --git a/scripts/launch_calibration_screen.sh b/scripts/launch_calibration_screen.sh new file mode 100755 index 0000000..6e312a5 --- /dev/null +++ b/scripts/launch_calibration_screen.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash + +set -euo pipefail + +ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." 
#!/usr/bin/env bash
# Install the Python dependencies this project expects on a TPU host:
# JAX with TPU support plus the RL / logging stack.
#
# Run with `bash scripts/setuptpu.sh`. Do not `source` it: `set -e` would
# leak into the calling shell.

set -euo pipefail

# The -f index is where the libtpu wheels pulled in by jax[tpu] are published.
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# The version specifier must stay quoted; unquoted, the shell would treat
# `>=2.2.0` as an output redirection into a file named `=2.2.0`.
pip install "stable-baselines3>=2.2.0" gymnasium wandb tensorboard
def _truthy(value: Any) -> bool:
    """Interpret a loosely typed flag as a boolean.

    Real booleans pass through, ``None`` is false, and anything else is
    stringified and matched case-insensitively against 1/true/yes/on.
    """
    if isinstance(value, bool):
        return value
    if value is None:
        return False
    return str(value).strip().lower() in {"1", "true", "yes", "on"}


def _as_float(value: Any, default: float) -> float:
    """Return ``float(value)``, or ``float(default)`` when conversion fails.

    ``OverflowError`` is handled as well because ``float()`` raises it for
    integers too large to represent.
    """
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return float(default)


def _as_int(value: Any, default: int) -> int:
    """Return ``value`` truncated to an int, or ``int(default)`` on failure.

    Conversion goes through ``float`` so strings such as ``"12.9"`` are
    accepted. ``int(float("inf"))`` raises ``OverflowError`` and
    ``int(float("nan"))`` raises ``ValueError``; both fall back to the default.
    """
    try:
        return int(float(value))
    except (TypeError, ValueError, OverflowError):
        return int(default)


def _normalize_sweep_id(
    raw: str, entity: str, project: str
) -> tuple[str, str, str, str]:
    """Expand a sweep reference to ``(full_id, entity, project, sweep)``.

    ``raw`` may contain one, two or three ``/``-separated parts; missing
    leading parts are filled from ``entity`` and ``project``.

    Raises:
        ValueError: if ``raw`` is blank, has more than three parts, or a
            part that is needed was supplied neither in ``raw`` nor via the
            ``entity`` / ``project`` arguments.
    """
    sweep_raw = str(raw).strip()
    if not sweep_raw:
        raise ValueError("--sweep-id is required")
    # Strip once so the emptiness checks and the returned values agree.
    entity = entity.strip()
    project = project.strip()
    parts = [piece.strip() for piece in sweep_raw.split("/") if piece.strip()]
    if len(parts) == 3:
        return f"{parts[0]}/{parts[1]}/{parts[2]}", parts[0], parts[1], parts[2]
    if len(parts) == 2:
        if not entity:
            raise ValueError(
                "--entity is required when --sweep-id is '<project>/<sweep>'"
            )
        return f"{entity}/{parts[0]}/{parts[1]}", entity, parts[0], parts[1]
    if len(parts) == 1:
        if not entity or not project:
            raise ValueError(
                "--entity and --project are required when --sweep-id is '<sweep>'"
            )
        return f"{entity}/{project}/{parts[0]}", entity, project, parts[0]
    raise ValueError(f"invalid --sweep-id value: '{raw}'")


def _pick_best_defended_run(
    sweep: Any,
    metric: str,
    *,
    min_margin: float,
    min_coi: float,
) -> tuple[Any, float]:
    """Return ``(run, score)`` for the highest-scoring eligible sweep run.

    A run is eligible when its state is ``finished``, it is not a baseline
    (``baseline_mode`` if that key is present in the config, otherwise
    ``no_robust``), its ``eval/margin_mean`` and ``eval/coi_level_mean``
    summaries reach the given floors, and ``metric`` is a real number in its
    summary. For ``eval/stress_revenue_worst`` a missing value falls back to
    ``eval/robust_revenue_worst``. Ties keep the earliest run.

    Raises:
        RuntimeError: if no run satisfies every condition.
    """
    import math  # local import: this file has no module-level `math` import

    margin_floor = float(min_margin)
    coi_floor = float(min_coi)
    candidates: list[tuple[float, Any]] = []
    for run in sweep.runs:
        if str(getattr(run, "state", "")).lower() != "finished":
            continue
        cfg = dict(getattr(run, "config", {}) or {})
        baseline_key = "baseline_mode" if "baseline_mode" in cfg else "no_robust"
        if _truthy(cfg.get(baseline_key)):
            continue
        summary = dict(getattr(run, "summary", {}) or {})
        margin = _as_float(summary.get("eval/margin_mean"), -1.0)
        coi_level = _as_float(summary.get("eval/coi_level_mean"), -1.0)
        # Phrased as "not >=" so that NaN, which compares false to everything,
        # is rejected instead of slipping past a "<" test.
        if not (margin >= margin_floor and coi_level >= coi_floor):
            continue
        score = summary.get(metric)
        if score is None and str(metric) == "eval/stress_revenue_worst":
            score = summary.get("eval/robust_revenue_worst")
        if score is None:
            continue
        try:
            value = float(score)
        except (TypeError, ValueError):
            continue
        if math.isnan(value):
            # NaN has no defined ordering and would make the ranking arbitrary.
            continue
        candidates.append((value, run))
    if not candidates:
        raise RuntimeError(
            f"no finished defended runs found with summary metric '{metric}' and constraints "
            f"margin>={min_margin}, coi>={min_coi}"
        )
    # max() returns the first maximal element, so ties keep the earliest run.
    best_score, best_run = max(candidates, key=lambda item: item[0])
    return best_run, best_score


def _format_alpha_values(raw: str, fallback_alpha: float) -> str:
    """Return ``raw`` stripped, or ``fallback_alpha`` formatted with ``.6g`` if blank."""
    cleaned = str(raw).strip()
    if cleaned:
        return cleaned
    return f"{float(fallback_alpha):.6g}"


def _benchmark_tokens(
    *,
    project: str,
    cfg: dict[str, Any],
    alpha_values: str,
    episodes: int,
) -> list[str]:
    """Build the benchmark CLI argument list from a run's config.

    Missing or unparsable config entries fall back to the defaults written
    below. The ``ambiguity_*`` keys take precedence over the ``robust_*``
    keys, which are only consulted when the former are absent from ``cfg``.

    Raises:
        ValueError: if ``cfg["algo"]`` is not qtable, ppo, a2c or dqn.
    """
    algo = str(cfg.get("algo", "")).strip().lower()
    if algo not in {"qtable", "ppo", "a2c", "dqn"}:
        raise ValueError(f"unsupported algo in best run: '{algo}'")

    options: list[tuple[str, Any]] = [
        ("--project", project),
        ("--tiers", algo),
        ("--alpha-values", alpha_values),
        ("--episodes", int(episodes)),
        ("--seed", _as_int(cfg.get("seed"), 42)),
        ("--total-timesteps", _as_int(cfg.get("total_timesteps"), 80_000)),
        ("--max-steps", _as_int(cfg.get("max_steps"), 100)),
        (
            "--robust-radius",
            _as_float(cfg.get("ambiguity_radius", cfg.get("robust_radius")), 0.2),
        ),
        (
            "--robust-points",
            _as_int(cfg.get("ambiguity_points", cfg.get("robust_points")), 7),
        ),
        (
            "--robust-rollouts",
            _as_int(cfg.get("ambiguity_rollouts", cfg.get("robust_rollouts")), 1),
        ),
        ("--lambda-coi", _as_float(cfg.get("lambda_coi"), 0.2)),
        ("--eta-ux", _as_float(cfg.get("eta_ux"), 0.5)),
        ("--reward-profit-weight", _as_float(cfg.get("reward_profit_weight"), 1.0)),
        ("--learning-rate", _as_float(cfg.get("learning_rate"), 3e-4)),
        ("--batch-size", _as_int(cfg.get("batch_size"), 256)),
        ("--n-steps", _as_int(cfg.get("n_steps"), 2048)),
        ("--N", _as_int(cfg.get("N"), 100)),
        ("--action-levels", _as_int(cfg.get("action_levels"), 9)),
        ("--margin-floor", _as_float(cfg.get("margin_floor"), 0.85)),
        ("--device", "cpu"),
    ]
    tokens: list[str] = []
    for flag, value in options:
        tokens.extend((flag, str(value)))
    return tokens
parser.add_argument("--submission-id", default="") + parser.add_argument("--output-json", default="") + args = parser.parse_args() + + root = Path(__file__).resolve().parents[1] + cwd = str(Path.cwd()) + sys.path = [p for p in sys.path if p not in {"", cwd}] + + try: + import wandb + except ImportError as exc: + raise ImportError("wandb is required") from exc + + full_sweep_id, entity, project, _ = _normalize_sweep_id( + raw=str(args.sweep_id), + entity=str(args.entity).strip(), + project=str(args.project).strip(), + ) + api = wandb.Api(timeout=int(args.timeout)) + sweep = api.sweep(full_sweep_id) + best_run, best_score = _pick_best_defended_run( + sweep, + str(args.metric), + min_margin=float(args.min_margin), + min_coi=float(args.min_coi), + ) + + best_cfg = dict(getattr(best_run, "config", {}) or {}) + best_alpha = _as_float( + best_cfg.get( + "alpha", + getattr(best_run, "summary", {}).get("study/alpha", 0.6), + ), + 0.6, + ) + alpha_values = _format_alpha_values( + str(args.alpha_values), fallback_alpha=best_alpha + ) + benchmark_tokens = _benchmark_tokens( + project=project, + cfg=best_cfg, + alpha_values=alpha_values, + episodes=int(args.episodes), + ) + benchmark_args = shlex.join(benchmark_tokens) + + submission_id = str(args.submission_id).strip() + if not submission_id: + stamp = datetime.now(timezone.utc).strftime("%m%d-%H%M") + submission_id = f"best-compare-{stamp}" + + env_overrides = { + "RAY_MODE": "benchmark", + "COMPARE_ROBUST": "1", + "NUM_NODES": str(int(args.num_nodes)), + "TPU_PER_TASK": str(float(args.tpu_per_task)), + "PHANTOM_JAX_PLATFORM": "cpu", + "WANDB_ENTITY": entity, + "WANDB_PROJECT": project, + "BENCHMARK_ARGS": benchmark_args, + "INNER_WORKERS": str(int(args.inner_workers)), + "INNER_THREADS": str(int(args.inner_threads)), + "MAX_HEAVY_WORKERS": str(int(args.max_heavy_workers)), + "WORKER_CPUS": str(int(args.worker_cpus)), + "OUTPUT_ROOT": str(args.output_root), + "SUBMISSION_ID": submission_id, + } + if bool(args.ray_no_wait): + 
def _base_sweep(method: str, metric_name: str) -> dict[str, Any]:
    """Return the skeleton shared by every sweep config.

    The skeleton holds the search ``method`` and a ``metric`` entry that
    asks for ``metric_name`` to be maximized; callers add ``name`` and
    ``parameters`` on top.
    """
    objective = {"name": str(metric_name), "goal": "maximize"}
    skeleton: dict[str, Any] = {"method": str(method)}
    skeleton["metric"] = objective
    return skeleton
def _train_robust_revenue_sweep(method: str) -> dict[str, Any]:
    """Return the training sweep that maximizes ``eval/stress_revenue_worst``.

    Only defended runs are searched: ``baseline_mode`` is pinned to False,
    and ``margin_floor`` and ``device`` are pinned as well. Every other
    entry is a discrete list of candidate values.
    """

    def choices(*options: Any) -> dict[str, Any]:
        # A parameter the sweep picks from a discrete set.
        return {"values": list(options)}

    def fixed(setting: Any) -> dict[str, Any]:
        # A parameter held constant across all runs.
        return {"value": setting}

    search_space = {
        "algo": choices("qtable", "ppo", "a2c", "dqn"),
        "alpha": choices(0.4, 0.6, 0.8),
        "baseline_mode": fixed(False),
        "seed": choices(42, 1337, 2026, 7777),
        "total_timesteps": choices(60_000, 80_000, 120_000),
        "learning_rate": choices(1e-4, 3e-4, 1e-3),
        "batch_size": choices(128, 256, 512),
        "n_steps": choices(1024, 2048, 4096),
        "lambda_coi": choices(0.2, 0.4, 0.6),
        "ambiguity_radius": choices(0.1, 0.2, 0.3),
        "ambiguity_points": choices(5, 7, 9),
        "ambiguity_rollouts": choices(1, 2),
        "eta_ux": choices(0.25, 0.5, 0.75),
        "reward_profit_weight": choices(1.0, 1.25),
        "N": choices(80, 100, 140),
        "max_steps": choices(80, 100, 120),
        "action_levels": choices(7, 9, 11),
        "margin_floor": fixed(0.85),
        "device": fixed("cpu"),
    }
    sweep = _base_sweep(method=method, metric_name="eval/stress_revenue_worst")
    sweep["name"] = "train-defense-revenue-search"
    sweep["parameters"] = search_space
    return sweep
def _ppo_block_a_sweep(method: str) -> dict[str, Any]:
    """Return the PPO "block A" calibration sweep, scored on ``objective/score``.

    Only ``seed``, ``lambda_coi`` and ``ambiguity_radius`` vary, with three
    values each, so an exhaustive search is 27 runs. Everything else is
    pinned to a single value.
    """

    def choices(*options: Any) -> dict[str, Any]:
        # A parameter the sweep picks from a discrete set.
        return {"values": list(options)}

    def fixed(setting: Any) -> dict[str, Any]:
        # A parameter held constant across all runs.
        return {"value": setting}

    search_space = {
        "tiers": fixed("ppo"),
        "alpha_values": fixed("0.25,0.6,0.8"),
        "seed": choices(42, 1337, 2026),
        "episodes": fixed(12),
        "total_timesteps": fixed(80000),
        "lambda_coi": choices(0.05, 0.1, 0.2),
        "ambiguity_radius": choices(0.05, 0.1, 0.2),
        "ambiguity_points": fixed(7),
        "ambiguity_rollouts": fixed(1),
        "eta_ux": fixed(0.5),
        "reward_profit_weight": fixed(1.0),
        "learning_rate": fixed(3e-4),
        "batch_size": fixed(256),
        "n_steps": fixed(2048),
        "device": fixed("cpu"),
    }
    sweep = _base_sweep(method=method, metric_name="objective/score")
    sweep["name"] = "benchmark-ppo-block-a-calibration"
    sweep["parameters"] = search_space
    return sweep
def main() -> None:
    """Create a W&B sweep from CLI arguments and print its identifier.

    Prints the bare sweep id by default, the ``[entity/]project/id`` form
    with ``--full-id``, or a JSON description with ``--json``.

    Raises:
        ImportError: if ``wandb`` cannot be imported.
        ValueError: if ``--profile`` is not ``default`` for a kind other
            than ``train``.
    """
    # Kinds that only exist in the default profile, mapped to their builders.
    default_only = {
        "benchmark": _benchmark_sweep,
        "ppo_calibration": _ppo_calibration_sweep,
        "ppo_block_a": _ppo_block_a_sweep,
        "ppo_shift_screen": _ppo_shift_screen_sweep,
        "ppo_rl_study": _ppo_rl_study_sweep,
    }

    parser = argparse.ArgumentParser(description="Create W&B sweep for PHANTOM")
    parser.add_argument(
        "--kind",
        choices=[
            "benchmark",
            "train",
            "ppo_calibration",
            "ppo_block_a",
            "ppo_shift_screen",
            "ppo_rl_study",
        ],
        default="benchmark",
    )
    parser.add_argument(
        "--profile",
        choices=["default", "robust_revenue"],
        default="default",
    )
    parser.add_argument("--project", required=True)
    parser.add_argument("--entity", default="")
    parser.add_argument(
        "--method", choices=["random", "bayes", "grid"], default="random"
    )
    parser.add_argument("--run-cap", type=int, default=0)
    parser.add_argument("--json", action="store_true")
    parser.add_argument("--full-id", action="store_true")
    args = parser.parse_args()

    # Drop the current directory from the import path before importing wandb,
    # so a local file or directory named `wandb` cannot shadow the package.
    cwd = str(Path.cwd())
    sys.path = [p for p in sys.path if p not in {"", cwd}]

    try:
        import wandb
    except ImportError as exc:
        raise ImportError("wandb is required to create sweeps") from exc

    kind = str(args.kind)
    profile = str(args.profile)
    if kind == "train":
        build = (
            _train_robust_revenue_sweep
            if profile == "robust_revenue"
            else _train_sweep
        )
    else:
        if profile != "default":
            raise ValueError(f"{kind} sweeps only support --profile default")
        build = default_only[kind]
    sweep_cfg = build(args.method)

    run_cap = int(args.run_cap)
    if run_cap > 0:
        sweep_cfg["run_cap"] = run_cap

    # Normalize once so the value sent to W&B, the printed id and the JSON
    # output all agree; a blank entity means "use the account default".
    entity = str(args.entity).strip()
    project = str(args.project)

    # Silence whatever wandb prints so stdout carries only the id / JSON
    # emitted below.
    with contextlib.redirect_stdout(io.StringIO()):
        sweep_id = wandb.sweep(
            sweep=sweep_cfg,
            project=project,
            entity=entity or None,
        )
    full_id = f"{entity}/{project}/{sweep_id}" if entity else f"{project}/{sweep_id}"

    if bool(args.json):
        print(
            json.dumps(
                {
                    "kind": kind,
                    "profile": profile,
                    "project": project,
                    "entity": entity,
                    "sweep_id": str(sweep_id),
                    "full_id": str(full_id),
                }
            )
        )
        return
    print(full_id if bool(args.full_id) else sweep_id)
def _token() -> str | None:
    """Return the ``HF_TOKEN`` environment variable, or None if unset or empty."""
    return os.getenv("HF_TOKEN") or None


def _exception_details(exc: Exception) -> str:
    """Summarize an exception as ``message | HTTP <status> | <body excerpt>``.

    The status and body come from ``exc.response`` when that attribute
    exists; the body is stripped and cut to 500 characters. Empty pieces
    are left out.
    """
    parts = [str(exc).strip()]
    response = getattr(exc, "response", None)
    if response is not None:
        status = getattr(response, "status_code", None)
        if status is not None:
            parts.append(f"HTTP {status}")
        text = getattr(response, "text", "")
        if text:
            parts.append(text.strip()[:500])
    return " | ".join(p for p in parts if p)


def _size_category(n_rows: int) -> str:
    """Map a row count to a dataset-card ``size_categories`` label.

    NOTE(review): the first two buckets are as originally written; the
    larger ones follow the Hugging Face vocabulary and should be confirmed
    against the Hub's accepted values.
    """
    buckets = (
        (1_000, "n<1K"),
        (10_000, "1K<n<10K"),
        (100_000, "10K<n<100K"),
        (1_000_000, "100K<n<1M"),
        (10_000_000, "1M<n<10M"),
        (100_000_000, "10M<n<100M"),
        (1_000_000_000, "100M<n<1B"),
        (10_000_000_000, "1B<n<10B"),
        (100_000_000_000, "10B<n<100B"),
        (1_000_000_000_000, "100B<n<1T"),
    )
    for upper_bound, label in buckets:
        if n_rows < upper_bound:
            return label
    return "n>1T"


def _series_count(df: pd.DataFrame, col: str) -> dict[str, int]:
    """Count rows per value of ``col``; missing values are counted under ``""``.

    Returns an empty dict when the column does not exist.
    """
    if col not in df.columns:
        return {}
    vc = df[col].fillna("").astype(str).value_counts(dropna=False)
    return {k: int(v) for k, v in vc.items()}


def _group_count(df: pd.DataFrame, left: str, right: str) -> dict[tuple[str, str], int]:
    """Count rows per ``(left, right)`` value pair, keyed by stringified values.

    Missing values form their own group (``dropna=False``). Entries are
    inserted sorted by ``left`` then ``right``. Returns an empty dict when
    either column does not exist.
    """
    if left not in df.columns or right not in df.columns:
        return {}
    grouped = (
        df.groupby([left, right], dropna=False)
        .size()
        .reset_index(name="count")
        .sort_values([left, right])
    )
    # Zip the columns instead of iterrows(): iterrows builds one Series per
    # row and upcasts mixed dtypes, which can change how a key stringifies.
    return {
        (str(a), str(b)): int(n)
        for a, b, n in zip(grouped[left], grouped[right], grouped["count"])
    }


def _session_count_by_actor(df: pd.DataFrame) -> dict[str, int]:
    """Count distinct non-null ``sessionId`` values per ``actor_type``.

    Returns an empty dict when either column does not exist.
    """
    if "actor_type" not in df.columns or "sessionId" not in df.columns:
        return {}
    grouped = (
        df[["actor_type", "sessionId"]]
        .dropna(subset=["sessionId"])
        .drop_duplicates()
        .groupby("actor_type")
        .size()
    )
    return {str(k): int(v) for k, v in grouped.items()}


def _time_range(df: pd.DataFrame) -> tuple[str, str]:
    """Return the earliest and latest ``ts`` as UTC ISO-8601 strings.

    Unparsable timestamps are ignored. Returns ``("unknown", "unknown")``
    when the column is missing or holds no parsable value.
    """
    if "ts" not in df.columns:
        return "unknown", "unknown"
    ts = pd.to_datetime(df["ts"], errors="coerce", utc=True)
    ts = ts.dropna()
    if ts.empty:
        return "unknown", "unknown"
    return ts.min().isoformat(), ts.max().isoformat()
dynamic-pricing +- behavioral-telemetry +- human-vs-agent +- session-data +size_categories: +- {size_cat} +--- + +# Dataset Card for whoclickedit + +## Dataset Summary +whoclickedit is an event-level behavioral dataset for human versus agent interaction analysis in dynamic pricing experiments. +It merges interaction logs and price quote logs into one flat CSV (`whoclicked.csv`) with explicit labels for actor type. + +## Dataset Snapshot +- Rows: `{total_rows}` +- Columns: `{total_cols}` +- Time range (UTC): `{t_min}` to `{t_max}` +- Unique sessions by actor: +{session_lines} +- Rows by actor: +{actor_lines} +- Rows by record type: +{record_lines} +- Rows by actor x record type: +{pair_lines} +- Store modes: +{store_lines} + +## Source and Processing +Data is collected from two local roots in the PHANTOM project: +- `experiments/collected_data` (human sessions) +- `experiments/agents/collected_data` (agent sessions) + +Each session folder contains: +- `int.json` (interaction events) +- `price.json` (price quote logs) + +The ETL does the following: +- Normalizes both Kafka-envelope and flat payload formats +- Flattens nested metadata fields into `metadata_*` columns +- Preserves all raw rows (no deduplication) +- Adds labels: + - `actor_type` in `{{human, agent}}` + - `is_agent` in `{{0, 1}}` + - `record_type` in `{{interaction, price_log}}` + +## Data Fields +Core fields used for modeling: +- `actor_type`, `is_agent`, `record_type` +- `sessionId`, `experimentId`, `storeMode`, `ts` +- `eventName`, `page`, `productId`, `price`, `userAgent` + +Kafka provenance fields: +- `kafka_partition_id`, `kafka_offset`, `kafka_timestamp_ms`, `kafka_compression` +- `kafka_is_transactional`, `kafka_headers`, `kafka_key_*`, `kafka_value_*` + +Flattened metadata fields currently present: +{metadata_lines} + +Top interaction events: +{event_lines} + +## Intended Uses +- Human-vs-agent traffic classification +- Session-level behavioral modeling +- Dynamic pricing robustness analysis 
def build_card(input_csv: Path, output_md: Path) -> None:
    """Render the dataset card for ``input_csv`` and write it to ``output_md``.

    Parent directories of ``output_md`` are created as needed.

    Raises:
        FileNotFoundError: if ``input_csv`` does not exist.
    """
    if not input_csv.exists():
        raise FileNotFoundError(f"Input CSV not found: {input_csv}")
    df = pd.read_csv(input_csv)
    card = _render_card(df)
    output_md.parent.mkdir(parents=True, exist_ok=True)
    # Write UTF-8 explicitly: the card embeds values taken from the data, and
    # the platform default encoding may fail on, or mangle, non-ASCII text.
    output_md.write_text(card, encoding="utf-8")
    print(f"wrote dataset card to {output_md}")


def upload_card(
    card_path: Path, repo_id: str, path_in_repo: str, commit_message: str
) -> None:
    """Upload ``card_path`` to the dataset repo ``repo_id`` as ``path_in_repo``.

    Authentication and repo access are checked first so each failure is
    reported with its own message.

    Raises:
        FileNotFoundError: if ``card_path`` does not exist.
        RuntimeError: if authentication, the repo lookup, or the upload
            fails; the original exception is chained as the cause.
    """
    if not card_path.exists():
        raise FileNotFoundError(f"Card file not found: {card_path}")

    # Read the token once so the client and the auth check use the same value.
    token = _token()
    api = HfApi(token=token)
    try:
        me = api.whoami(token=token)
    except Exception as exc:
        detail = _exception_details(exc)
        raise RuntimeError(f"Hugging Face auth failed. Details: {detail}") from exc

    user_name = me.get("name") or me.get("fullname") or "unknown"
    print(f"authenticated to HF as: {user_name}")

    try:
        api.repo_info(repo_id=repo_id, repo_type="dataset")
    except Exception as exc:
        detail = _exception_details(exc)
        raise RuntimeError(
            f"Dataset repo '{repo_id}' is not accessible. Details: {detail}"
        ) from exc

    try:
        commit = api.upload_file(
            path_or_fileobj=str(card_path),
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
    except Exception as exc:
        detail = _exception_details(exc)
        raise RuntimeError(
            f"Card upload failed for '{repo_id}'. Details: {detail}"
        ) from exc

    print(f"uploaded dataset card to https://huggingface.co/datasets/{repo_id}")
    print(f"commit: {commit}")
# Default locations are resolved relative to the repository root (the parent
# of the directory containing this script).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_HUMAN_DIR = PROJECT_ROOT / "experiments" / "collected_data"
DEFAULT_AGENT_DIR = PROJECT_ROOT / "experiments" / "agents" / "collected_data"
DEFAULT_OUTPUT = PROJECT_ROOT / "experiments" / "exports" / "whoclicked.csv"
DEFAULT_REPO = os.getenv("HF_WHOCLICKED_REPO", "velocitatem/whoclickedit")

# Columns that always lead the exported table, in this order. Any other column
# discovered in the data is appended after them, sorted by name.
BASE_COLUMNS = [
    "actor_type",
    "is_agent",
    "record_type",
    "topic",
    "source_session_dir",
    "source_file",
    "source_row_index",
    "ingest_format",
    "sessionId",
    "experimentId",
    "storeMode",
    "ts",
    "eventName",
    "page",
    "productId",
    "price",
    "userAgent",
    "kafka_partition_id",
    "kafka_offset",
    "kafka_timestamp_ms",
    "kafka_compression",
    "kafka_is_transactional",
    "kafka_headers",
    "kafka_key_payload",
    "kafka_key_encoding",
    "kafka_key_schema_id",
    "kafka_value_encoding",
    "kafka_value_schema_id",
    "kafka_value_size",
]

# Envelope columns emitted for every row, whether or not the record carried a
# Kafka envelope.
_ENVELOPE_COLUMNS = (
    "kafka_partition_id",
    "kafka_offset",
    "kafka_timestamp_ms",
    "kafka_compression",
    "kafka_is_transactional",
    "kafka_headers",
    "kafka_key_payload",
    "kafka_key_encoding",
    "kafka_key_schema_id",
    "kafka_value_encoding",
    "kafka_value_schema_id",
    "kafka_value_size",
)


def _token() -> str | None:
    """Return the ``HF_TOKEN`` environment value, or ``None`` if unset or empty."""
    return os.getenv("HF_TOKEN") or None


def _exception_details(exc: Exception) -> str:
    """Summarise an exception as ``message | HTTP status | response text``.

    The status and (truncated) body are only included when the exception
    carries a ``response`` attribute exposing them; empty parts are dropped.
    """
    parts = [str(exc).strip()]
    response = getattr(exc, "response", None)
    if response is not None:
        status = getattr(response, "status_code", None)
        if status is not None:
            parts.append(f"HTTP {status}")
        body = (getattr(response, "text", "") or "").strip()
        if body:
            # Cap the body so one large error page cannot flood the message.
            parts.append(body[:500])
    return " | ".join(part for part in parts if part)


def _flatten_dict(data: dict[str, Any], prefix: str = "") -> dict[str, Any]:
    """Flatten nested dicts into one level, joining key paths with ``_``.

    Keys are stringified, stripped, and have spaces replaced by underscores.
    Non-dict values (including lists) are kept as-is.
    """
    flat: dict[str, Any] = {}
    for key, value in data.items():
        name = str(key).strip().replace(" ", "_")
        path = f"{prefix}_{name}" if prefix else name
        if isinstance(value, dict):
            flat.update(_flatten_dict(value, path))
        else:
            flat[path] = value
    return flat


def _as_scalar(value: Any) -> Any:
    """Return ``value`` unchanged unless it is a container, which is JSON-encoded."""
    if isinstance(value, (dict, list, tuple)):
        return json.dumps(value, ensure_ascii=True, sort_keys=True)
    return value


def _empty_envelope() -> dict[str, Any]:
    """Return a fresh envelope mapping with every Kafka column set to ``None``."""
    return dict.fromkeys(_ENVELOPE_COLUMNS)


def _extract_payload_and_envelope(
    record: Any,
) -> tuple[dict[str, Any], dict[str, Any], str]:
    """Split a raw record into ``(payload, envelope, ingest_format)``.

    Three shapes are recognised:

    * ``"kafka_envelope"``: a dict whose ``value`` is a dict holding a dict
      ``payload``; the surrounding fields populate the envelope columns.
    * ``"flat_payload"``: any other dict, used as the payload itself.
    * ``"unknown"``: anything that is not a dict; the payload is empty.

    The returned payload is always a copy, so callers may mutate it.
    """
    if not isinstance(record, dict):
        return {}, _empty_envelope(), "unknown"

    value = record.get("value")
    if not (isinstance(value, dict) and isinstance(value.get("payload"), dict)):
        return dict(record), _empty_envelope(), "flat_payload"

    key = record.get("key")
    if not isinstance(key, dict):
        key = {}

    envelope = {
        "kafka_partition_id": record.get("partitionID"),
        "kafka_offset": record.get("offset"),
        "kafka_timestamp_ms": record.get("timestamp"),
        "kafka_compression": record.get("compression"),
        "kafka_is_transactional": record.get("isTransactional"),
        "kafka_headers": _as_scalar(record.get("headers")),
        # Scalarise like the headers so a structured key becomes JSON in the
        # CSV cell rather than a Python repr.
        "kafka_key_payload": _as_scalar(key.get("payload")),
        "kafka_key_encoding": key.get("encoding"),
        "kafka_key_schema_id": key.get("schemaId"),
        "kafka_value_encoding": value.get("encoding"),
        "kafka_value_schema_id": value.get("schemaId"),
        "kafka_value_size": value.get("size"),
    }
    return dict(value["payload"]), envelope, "kafka_envelope"


def _load_json_list(path: Path) -> list[Any]:
    """Load ``path`` as JSON and require the top-level value to be a list.

    Raises:
        ValueError: If the decoded JSON is not a list.
    """
    # Read as UTF-8 explicitly instead of relying on the locale encoding.
    raw = json.loads(path.read_text(encoding="utf-8"))
    if not isinstance(raw, list):
        raise ValueError(f"Expected list in {path}, got {type(raw).__name__}")
    return raw


def _normalize_file_rows(
    actor_type: str,
    is_agent: int,
    session_dir_name: str,
    source_file: str,
    records: list[Any],
) -> list[dict[str, Any]]:
    """Turn the records of one session file into flat row dicts.

    Each row carries provenance columns, the envelope columns, the flattened
    payload, and the payload's ``metadata`` flattened under a ``metadata_``
    prefix (or stored in ``metadata_raw`` when it is not a dict). ``int.json``
    yields ``interaction`` rows; any other file name yields ``price_log`` rows.
    """
    is_interaction = source_file == "int.json"
    record_type = "interaction" if is_interaction else "price_log"
    topic = "user-interactions" if is_interaction else "price-logs"

    rows: list[dict[str, Any]] = []
    for index, raw_record in enumerate(records):
        payload, envelope, ingest_format = _extract_payload_and_envelope(raw_record)
        metadata = payload.pop("metadata", None)

        row: dict[str, Any] = {
            "actor_type": actor_type,
            "is_agent": is_agent,
            "record_type": record_type,
            "topic": topic,
            "source_session_dir": session_dir_name,
            "source_file": source_file,
            "source_row_index": index,
            "ingest_format": ingest_format,
            **envelope,
        }
        # NOTE(review): payload keys are applied last, so a payload field that
        # shares a name with a provenance/envelope column replaces it.
        for name, value in _flatten_dict(payload).items():
            row[name] = _as_scalar(value)

        if isinstance(metadata, dict):
            for name, value in _flatten_dict(metadata, "metadata").items():
                row[name] = _as_scalar(value)
        elif metadata is not None:
            row["metadata_raw"] = _as_scalar(metadata)

        rows.append(row)

    return rows


def _collect_rows_for_actor(
    actor_type: str, is_agent: int, base_dir: Path
) -> list[dict[str, Any]]:
    """Collect rows from every session directory under ``base_dir``.

    Session directories are visited in name order, and within each one
    ``int.json`` is read before ``price.json``; missing files are skipped.

    Raises:
        FileNotFoundError: If ``base_dir`` does not exist.
    """
    if not base_dir.exists():
        raise FileNotFoundError(f"Directory not found: {base_dir}")

    session_dirs = sorted(
        (entry for entry in base_dir.iterdir() if entry.is_dir()),
        key=lambda entry: entry.name,
    )

    rows: list[dict[str, Any]] = []
    for session_dir in session_dirs:
        for source_file in ("int.json", "price.json"):
            file_path = session_dir / source_file
            if not file_path.exists():
                continue
            rows.extend(
                _normalize_file_rows(
                    actor_type=actor_type,
                    is_agent=is_agent,
                    session_dir_name=session_dir.name,
                    source_file=source_file,
                    records=_load_json_list(file_path),
                )
            )
    return rows


def build_dataframe(human_dir: Path, agent_dir: Path) -> pd.DataFrame:
    """Build the flattened table from the human and agent data directories.

    Human rows come first, then agent rows. The result always starts with
    every ``BASE_COLUMNS`` entry (filled with missing values when no record
    provided that field), followed by the remaining columns sorted by name.

    Raises:
        FileNotFoundError: If either directory does not exist.
    """
    rows = [
        *_collect_rows_for_actor("human", 0, human_dir),
        *_collect_rows_for_actor("agent", 1, agent_dir),
    ]
    if not rows:
        return pd.DataFrame(columns=BASE_COLUMNS)

    df = pd.DataFrame(rows)
    base = set(BASE_COLUMNS)
    extra_columns = sorted(column for column in df.columns if column not in base)
    # reindex (rather than df[...]) so a base column that no record supplied
    # is created empty instead of raising KeyError.
    return df.reindex(columns=[*BASE_COLUMNS, *extra_columns])


def _print_summary(df: pd.DataFrame, output_path: Path) -> None:
    """Print row/column totals, per actor/record-type counts, and missing-value counts."""
    print(f"wrote {len(df)} rows and {len(df.columns)} columns to {output_path}")
    if df.empty:
        return

    print("rows by actor/record_type:")
    # groupby sorts its keys, so the counts come out ordered by actor and then
    # record type.
    counts = df.groupby(["actor_type", "record_type"], dropna=False).size()
    for (actor, record), count in counts.items():
        print(f"  - {actor} / {record}: {int(count)}")

    required = ["actor_type", "is_agent", "record_type", "sessionId", "ts"]
    missing = {col: int(df[col].isna().sum()) for col in required if col in df.columns}
    print(f"missing in required columns: {missing}")


def build_csv(human_dir: Path, agent_dir: Path, output: Path) -> pd.DataFrame:
    """Build the flattened table, write it to ``output`` as CSV, and return it.

    Missing parent directories of ``output`` are created and a summary is
    printed after writing.
    """
    df = build_dataframe(human_dir=human_dir, agent_dir=agent_dir)
    output.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(output, index=False)
    _print_summary(df, output)
    return df
def _resolve_repo_id(api: HfApi, repo_id: str) -> str:
    """Qualify a bare repository name with the authenticated user's name.

    An id that already contains ``/`` is returned unchanged. Otherwise the
    user name reported by ``api.whoami`` is prepended; if that lookup fails or
    reports no name, the bare id is returned as given.
    """
    if "/" in repo_id:
        return repo_id
    try:
        username = api.whoami(token=_token()).get("name")
    except Exception:
        # Deliberate best effort: keep the bare name and let the later Hub
        # calls report any real problem.
        return repo_id
    return f"{username}/{repo_id}" if username else repo_id


def upload_csv(
    input_path: Path,
    repo_id: str,
    path_in_repo: str,
    commit_message: str,
    create_if_missing: bool = False,
) -> None:
    """Upload a CSV file to a Hugging Face dataset repository.

    Args:
        input_path: Local CSV to upload.
        repo_id: Target repository; a bare name is qualified with the
            authenticated user's name when possible.
        path_in_repo: Destination path inside the repository.
        commit_message: Commit message used for the upload.
        create_if_missing: Create the dataset repository (no error if it
            already exists) instead of requiring it to be accessible.

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
        RuntimeError: If authentication, repository creation/lookup, or the
            upload fails; the original exception is chained as the cause.
    """
    if not input_path.exists():
        raise FileNotFoundError(f"Input CSV not found: {input_path}")

    token = _token()
    api = HfApi(token=token)

    try:
        me = api.whoami(token=token)
    except Exception as exc:
        detail = _exception_details(exc)
        hint = "Set HF_TOKEN with write access or run huggingface-cli login."
        raise RuntimeError(
            f"Hugging Face auth failed. {hint} Details: {detail}"
        ) from exc

    user_name = me.get("name") or me.get("fullname") or "unknown"
    print(f"authenticated to HF as: {user_name}")

    resolved_repo_id = _resolve_repo_id(api, repo_id)
    if create_if_missing:
        try:
            api.create_repo(
                repo_id=resolved_repo_id, repo_type="dataset", exist_ok=True
            )
        except Exception as exc:
            # Report creation failures with the same detail as every other
            # Hub call in this function.
            detail = _exception_details(exc)
            raise RuntimeError(
                f"Could not create dataset repo '{resolved_repo_id}'. "
                f"Details: {detail}"
            ) from exc
    else:
        try:
            api.repo_info(repo_id=resolved_repo_id, repo_type="dataset")
        except Exception as exc:
            detail = _exception_details(exc)
            hint = (
                "Check owner/repo spelling, ensure it is a dataset repo, "
                "or pass --create-if-missing."
            )
            raise RuntimeError(
                f"Dataset repo '{resolved_repo_id}' is not accessible. {hint} Details: {detail}"
            ) from exc

    try:
        commit = api.upload_file(
            path_or_fileobj=str(input_path),
            path_in_repo=path_in_repo,
            repo_id=resolved_repo_id,
            repo_type="dataset",
            commit_message=commit_message,
        )
    except Exception as exc:
        detail = _exception_details(exc)
        # The owner placeholder was missing from this hint, leaving the
        # unusable suggestion "--repo /whoclickedit".
        hint = (
            "Pass --repo <owner>/whoclickedit and ensure HF_TOKEN is set "
            "(or run huggingface-cli login)."
        )
        raise RuntimeError(
            f"Upload failed for '{resolved_repo_id}'. {hint} Details: {detail}"
        ) from exc

    print(
        f"uploaded {input_path} to https://huggingface.co/datasets/{resolved_repo_id}"
    )
    print(f"commit: {commit}")


def _parse_args() -> argparse.Namespace:
    """Parse the command line for the ``build``, ``upload`` and ``build-upload`` commands."""
    parser = argparse.ArgumentParser(
        description="ETL for whoclickedit: flatten local collected_data and upload to HF"
    )
    sub = parser.add_subparsers(dest="command", required=True)

    build = sub.add_parser("build", help="build flattened CSV locally")
    upload = sub.add_parser("upload", help="upload an existing CSV to HF dataset")
    build_upload = sub.add_parser(
        "build-upload", help="build CSV and upload to HF dataset"
    )

    # Options shared by the commands that build the CSV.
    for building in (build, build_upload):
        building.add_argument("--human-dir", type=Path, default=DEFAULT_HUMAN_DIR)
        building.add_argument("--agent-dir", type=Path, default=DEFAULT_AGENT_DIR)
        building.add_argument("--output", type=Path, default=DEFAULT_OUTPUT)

    upload.add_argument("--input", type=Path, default=DEFAULT_OUTPUT)

    # Options shared by the commands that upload to the Hub.
    for uploading in (upload, build_upload):
        uploading.add_argument("--repo", default=DEFAULT_REPO)
        uploading.add_argument("--path-in-repo", default="whoclicked.csv")
        uploading.add_argument(
            "--message", default="Update flattened whoclickedit dataset"
        )
        uploading.add_argument("--create-if-missing", action="store_true")

    return parser.parse_args()


def main() -> int:
    """Run the selected sub-command.

    Returns:
        ``0`` on success, ``1`` when any step raised (the error is printed to
        stderr).
    """
    args = _parse_args()

    try:
        if args.command not in ("build", "upload", "build-upload"):
            raise ValueError(f"Unknown command: {args.command}")

        if args.command != "upload":
            build_csv(
                human_dir=args.human_dir, agent_dir=args.agent_dir, output=args.output
            )

        if args.command != "build":
            # "upload" ships the file named by --input; "build-upload" ships
            # the CSV that was just written to --output.
            source = args.input if args.command == "upload" else args.output
            upload_csv(
                input_path=source,
                repo_id=args.repo,
                path_in_repo=args.path_in_repo,
                commit_message=args.message,
                create_if_missing=args.create_if_missing,
            )
    except Exception as exc:
        # Top-level boundary: turn any failure into a message plus exit code.
        print(f"error: {exc}", file=sys.stderr)
        return 1

    return 0


if __name__ == "__main__":
    raise SystemExit(main())