mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
refactoring training spc setup and benchmarking
This commit is contained in:
@@ -30,10 +30,20 @@ case "$cmd" in
|
||||
load_sweep_env
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
.venv/bin/python -m engine.train ${LOCAL_TRAIN_ARGS:---algo ppo --total-timesteps 50000}
|
||||
;;
|
||||
benchmark)
|
||||
load_sweep_env
|
||||
if [[ " ${LOCAL_BENCHMARK_ARGS:-} " != *" --no-wandb "* ]]; then
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
fi
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
WANDB_API_KEY="${WANDB_API_KEY:-}" \
|
||||
.venv/bin/python -m engine.train --run-kind benchmark ${LOCAL_BENCHMARK_ARGS:---tiers static,surge,linear,qtable,ppo --alpha-values 0.0,0.3 --episodes 3 --total-timesteps 3000 --max-steps 40 --device cpu}
|
||||
;;
|
||||
train-agent)
|
||||
load_sweep_env
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
@@ -43,10 +53,23 @@ case "$cmd" in
|
||||
args+=(--count "$AGENT_COUNT")
|
||||
fi
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
.venv/bin/python -m engine.train "${args[@]}"
|
||||
;;
|
||||
benchmark-agent)
|
||||
load_sweep_env
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
|
||||
args=(--sweep-agent --sweep-id "$SWEEP_ID")
|
||||
if [ -n "${AGENT_COUNT:-}" ] && [ "${AGENT_COUNT}" != "0" ]; then
|
||||
args+=(--count "$AGENT_COUNT")
|
||||
fi
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
.venv/bin/python -m engine.train --run-kind benchmark "${args[@]}" ${BENCHMARK_AGENT_ARGS:-}
|
||||
;;
|
||||
train-bootstrap)
|
||||
load_sweep_env
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
@@ -55,7 +78,7 @@ case "$cmd" in
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=entity/project/id"
|
||||
WANDB_API_KEY="$WANDB_API_KEY" \
|
||||
WANDB_ENTITY="${WANDB_ENTITY:-}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-phantom-pricing}" \
|
||||
WANDB_PROJECT="${WANDB_PROJECT:-capstone}" \
|
||||
GITHUB_TOKEN="$GITHUB_TOKEN" \
|
||||
REPO_URL="$REPO_URL" \
|
||||
BRANCH="${BRANCH:-main}" \
|
||||
@@ -115,7 +138,7 @@ PY
|
||||
train-tpu-vm-sweep)
|
||||
load_sweep_env
|
||||
require_var TPU_NAME "TPU_NAME required, e.g. TPU_NAME=TPUlong"
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/phantom-pricing/abc123"
|
||||
require_var SWEEP_ID "SWEEP_ID required, e.g. SWEEP_ID=lusiana/capstone/abc123"
|
||||
require_var WANDB_API_KEY "WANDB_API_KEY required - set it in $env_file"
|
||||
args=(
|
||||
--sweep-id "$SWEEP_ID"
|
||||
|
||||
@@ -96,7 +96,11 @@ def _extract_metrics(output: str) -> dict:
|
||||
obj = json.loads(block)
|
||||
except Exception:
|
||||
continue
|
||||
if isinstance(obj, dict) and ("sweep/score" in obj or "eval/reward" in obj):
|
||||
if isinstance(obj, dict) and (
|
||||
"objective/score" in obj
|
||||
or "eval/reward_mean" in obj
|
||||
or "sweep/score" in obj
|
||||
):
|
||||
return obj
|
||||
return {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user