refactoring training spc setup and benchmarking

This commit is contained in:
2026-03-08 18:30:53 +01:00
parent 9fafb26ec8
commit 73246d7dd8
36 changed files with 2180 additions and 613 deletions

View File

@@ -30,10 +30,20 @@ case "$cmd" in
load_sweep_env
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
WANDB_API_KEY="$WANDB_API_KEY" \
.venv/bin/python -m engine.train ${LOCAL_TRAIN_ARGS:---algo ppo --total-timesteps 50000}
;;
benchmark)
load_sweep_env
if [[ " ${LOCAL_BENCHMARK_ARGS:-} " != *" --no-wandb "* ]]; then
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
fi
WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
WANDB_API_KEY="${WANDB_API_KEY:-}" \
.venv/bin/python -m engine.train --run-kind benchmark ${LOCAL_BENCHMARK_ARGS:---tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu}
;;
train-agent)
load_sweep_env
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
@@ -43,10 +53,23 @@ case "$cmd" in
args+=(--count "$AGENT_COUNT")
fi
WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
WANDB_API_KEY="$WANDB_API_KEY" \
.venv/bin/python -m engine.train "${args[@]}"
;;
benchmark-agent)
load_sweep_env
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
args=(--sweep-agent --sweep-id "$SWEEP_ID")
if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then
args+=(--count "$AGENT_COUNT")
fi
WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
WANDB_API_KEY="$WANDB_API_KEY" \
.venv/bin/python -m engine.train --run-kind benchmark "${args[@]}" ${BENCHMARK_AGENT_ARGS:-}
;;
train-bootstrap)
load_sweep_env
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
@@ -55,7 +78,7 @@ case "$cmd" in
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
WANDB_API_KEY="$WANDB_API_KEY" \
WANDB_ENTITY="${WANDB_ENTITY:-}" \
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
GITHUB_TOKEN="$GITHUB_TOKEN" \
REPO_URL="$REPO_URL" \
BRANCH="${BRANCH:-main}" \
@@ -115,7 +138,7 @@ PY
train-tpu-vm-sweep)
load_sweep_env
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/phantom-pricing/abc123"
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123"
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
args=(
--sweep-id "$SWEEP_ID"

View File

@@ -96,7 +96,11 @@ def _extract_metrics(output: str) -> dict:
obj = json.loads(block)
except Exception:
continue
if isinstance(obj, dict) and ("sweep/score" in obj or "eval/reward" in obj):
if isinstance(obj, dict) and (
"objective/score" in obj
or "eval/reward_mean" in obj
or "sweep/score" in obj
):
return obj
return {}