mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
44 lines
1.3 KiB
Bash
44 lines
1.3 KiB
Bash
#!/usr/bin/env sh
#
# Launch PHANTOM training (python -m engine.train) from a local checkout.
#
# Environment overrides (each falls back to its default when unset or empty):
#   REPO_DIR                   checkout to run from              (default: $HOME/PHANTOM)
#   PYTHON_BIN                 interpreter used for pip and train (default: python3)
#   TRAIN_ARGS                 arguments for engine.train, word-split on whitespace
#   EXTRA_PIP                  packages pip-installed when flax/optax/distrax are missing
#   INSTALL_FULL_REQUIREMENTS  set to 1 to also pip-install requirements.txt (default: 0)
#   WANDB_API_KEY              when set, train with wandb logging; otherwise --no-wandb
set -eu

# `${VAR:=default}` assigns the default in place when VAR is unset or empty;
# the leading `:` no-op just gives the expansion somewhere to happen.
: "${REPO_DIR:=$HOME/PHANTOM}"
: "${PYTHON_BIN:=python3}"
: "${TRAIN_ARGS:=--algo ppo --jax --total-timesteps 200000 --jax-num-envs 32 --jax-num-steps 128 --jax-num-minibatches 4 --jax-update-epochs 4}"
: "${EXTRA_PIP:=flax optax distrax}"
: "${INSTALL_FULL_REQUIREMENTS:=0}"

# Fail fast when the checkout is missing, then run everything else from the
# repository root. The diagnostic goes to stderr, and printf is used instead
# of echo because sh's echo may interpret backslashes in the path.
if [ ! -d "$REPO_DIR" ]; then
  printf '%s\n' "repo directory not found: $REPO_DIR" >&2
  exit 1
fi

# `--` keeps a REPO_DIR that starts with '-' from being parsed as an option.
cd -- "$REPO_DIR"

# Remove any ./wandb directory left in the repo root before Python runs here.
# NOTE(review): presumably this is wandb's local run directory, removed so
# that `import wandb` from this cwd resolves to the installed package rather
# than this folder — confirm.
[ ! -d wandb ] || rm -rf -- wandb

# Dependency setup. Installing the full requirements.txt is opt-in
# (INSTALL_FULL_REQUIREMENTS=1) to keep repeated runs idempotent and avoid
# re-installing jax/libtpu each run.
if [ "$INSTALL_FULL_REQUIREMENTS" = "1" ] && [ -f requirements.txt ]; then
  $PYTHON_BIN -m pip install -r requirements.txt
fi

# Only install the JAX training stack when it is not importable yet.
if ! $PYTHON_BIN -c 'import flax, optax, distrax' >/dev/null 2>&1; then
  if [ -f engine/jax/requirements.txt ]; then
    $PYTHON_BIN -m pip install -r engine/jax/requirements.txt
  fi
  # EXTRA_PIP is deliberately unquoted so it splits into one argument per
  # package. Skip the call when it is empty: `pip install -U` with no
  # requirements exits non-zero, which would abort the script under `set -e`.
  if [ -n "$EXTRA_PIP" ]; then
    # shellcheck disable=SC2086
    $PYTHON_BIN -m pip install -U $EXTRA_PIP
  fi
fi

# Launch training, replacing this shell via exec. With WANDB_API_KEY set,
# make sure an importable wandb with a callable `init` exists and log to
# Weights & Biases; otherwise pass --no-wandb. TRAIN_ARGS is intentionally
# unquoted so it word-splits into separate arguments.
if [ -n "${WANDB_API_KEY:-}" ]; then
  # Reinstall wandb when it fails to import or `wandb.init` is not callable.
  # NOTE(review): the `init` check presumably guards against a same-named
  # local directory shadowing the package — confirm.
  # An explicit sys.exit is used instead of `assert`, which Python strips
  # under -O / PYTHONOPTIMIZE and would make this check always pass.
  if ! $PYTHON_BIN -c 'import sys, wandb; sys.exit(0 if callable(getattr(wandb, "init", None)) else 1)' >/dev/null 2>&1; then
    $PYTHON_BIN -m pip install -U wandb
  fi
  export WANDB_API_KEY
  # shellcheck disable=SC2086
  exec $PYTHON_BIN -m engine.train $TRAIN_ARGS
fi

# shellcheck disable=SC2086
exec $PYTHON_BIN -m engine.train $TRAIN_ARGS --no-wandb