mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
new trainer image
This commit is contained in:
43
docker/Trainer.dockerfile
Normal file
43
docker/Trainer.dockerfile
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
# syntax=docker/dockerfile:1.7
|
||||||
|
|
||||||
|
# ---- gpu stage: trainer image on the CUDA-enabled PyTorch runtime base ----
FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-runtime AS gpu

WORKDIR /app

# Dependencies are installed before the source tree is copied so that edits
# under engine/ do not invalidate the dependency layer.
COPY docker/trainer.requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt

# Optional for JAX-on-GPU workflows.
# Opt in with `--build-arg INSTALL_JAX_GPU=true`; any value other than the
# exact string "true" skips the install.
ARG INSTALL_JAX_GPU=false
RUN if [ "${INSTALL_JAX_GPU}" = "true" ]; then \
      pip install --no-cache-dir "jax[cuda12]==0.4.30" \
        -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html; \
    fi

COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint
COPY engine /app/engine

# PYTHONPATH=/app makes the copied /app/engine directory importable as `engine`
#   (the entrypoint runs `python -m engine.train`).
# PYTHONUNBUFFERED=1 writes Python's stdout/stderr straight through, so
#   container logs are not held back in stdio buffers if a run dies.
# XLA_PYTHON_CLIENT_PREALLOCATE=false is the JAX setting that turns off
#   up-front GPU memory preallocation; memory is then allocated on demand.
ENV PYTHONPATH=/app \
    PYTHONUNBUFFERED=1 \
    XLA_PYTHON_CLIENT_PREALLOCATE=false

ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---- tpu stage: trainer image on slim CPython with the JAX TPU wheels ----
FROM python:3.11-slim AS tpu

WORKDIR /app

# Dependencies are installed before the source tree is copied so that edits
# under engine/ do not invalidate the dependency layers.
COPY docker/trainer.requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt

# NOTE(review): this runs after the requirements install and forces
# jax==0.4.30, which may change whatever jax version the pinned
# flax/optax/orbax-checkpoint/chex releases resolved to. Confirm those pins
# are compatible with jax 0.4.30.
RUN pip install --no-cache-dir "jax[tpu]==0.4.30" \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint
COPY engine /app/engine

# PYTHONPATH=/app makes the copied /app/engine directory importable as `engine`
#   (the entrypoint runs `python -m engine.train`).
# PYTHONUNBUFFERED=1 writes Python's stdout/stderr straight through, so
#   container logs are not held back in stdio buffers if a run dies.
# PHANTOM_USE_JAX=1 is a project-specific switch; its consumer is not part of
#   this image definition (presumably read by engine code - confirm).
# PHANTOM_DEFAULT_AGENT_ARGS is appended to the training command by the
#   entrypoint script, so every agent started from this image gets `--jax`.
# JAX_PLATFORMS=tpu,cpu names the JAX backends to initialise, TPU first.
# XLA_PYTHON_CLIENT_PREALLOCATE=false is the JAX client setting that turns off
#   up-front device memory preallocation (documented by JAX for GPU memory).
ENV PYTHONPATH=/app \
    PYTHONUNBUFFERED=1 \
    PHANTOM_USE_JAX=1 \
    PHANTOM_DEFAULT_AGENT_ARGS="--jax" \
    JAX_PLATFORMS=tpu,cpu \
    XLA_PYTHON_CLIENT_PREALLOCATE=false

ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
|
||||||
23
docker/trainer-agent-entrypoint.sh
Normal file
23
docker/trainer-agent-entrypoint.sh
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
#!/usr/bin/env sh
# Entrypoint for the trainer images: assembles the sweep-agent command line
# from environment variables and replaces this shell with it.
#
# Environment:
#   SWEEP_ID                    required; passed as --sweep-id.
#   PHANTOM_DEFAULT_AGENT_ARGS  optional; extra arguments, split on whitespace.
#   TRAIN_ARGS                  optional; extra arguments, split on whitespace,
#                               appended after the defaults above.
#   AGENT_COUNT                 optional; passed as --count unless it is unset,
#                               empty or "0".
#
# Exits with status 1 (message on stderr) when SWEEP_ID is missing.
set -eu

if [ -z "${SWEEP_ID:-}" ]; then
  # Diagnostics belong on stderr so they are not mixed into captured stdout.
  echo "SWEEP_ID is required" >&2
  exit 1
fi

set -- python -m engine.train --sweep-agent --sweep-id "${SWEEP_ID}"

# The two *_ARGS variables are expanded unquoted on purpose so that a single
# variable can carry several arguments. Pathname expansion is switched off
# around them so that a value containing `*`, `?` or `[` is passed through
# literally instead of being matched against files in the working directory.
set -f

if [ -n "${PHANTOM_DEFAULT_AGENT_ARGS:-}" ]; then
  # shellcheck disable=SC2086
  set -- "$@" ${PHANTOM_DEFAULT_AGENT_ARGS}
fi

if [ -n "${TRAIN_ARGS:-}" ]; then
  # shellcheck disable=SC2086
  set -- "$@" ${TRAIN_ARGS}
fi

set +f

if [ "${AGENT_COUNT:-0}" != "0" ]; then
  set -- "$@" --count "${AGENT_COUNT}"
fi

# exec so the Python process replaces this shell and receives container
# signals directly.
exec "$@"
|
||||||
13
docker/trainer.requirements.txt
Normal file
13
docker/trainer.requirements.txt
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# Python dependencies for the trainer images; installed with
# `pip install -r` by both the gpu and tpu stages of docker/Trainer.dockerfile.
#
# NOTE(review): jax itself is not listed here. The Dockerfile installs
# jax==0.4.30 separately (always for tpu, opt-in for gpu), so confirm that the
# exact pins below are compatible with that jax version.

# Numerics / RL stack (lower bounds only).
numpy>=1.24.0
pandas>=2.0.0
scipy>=1.11.0
gymnasium>=0.29.0
stable-baselines3>=2.2.0

# Experiment tracking.
tensorboard>=2.15.0
wandb>=0.17.0

# JAX ecosystem (exact pins).
tensorflow-probability==0.24.0
flax==0.10.7
optax==0.2.7
distrax==0.1.5
orbax-checkpoint==0.11.32
chex==0.1.90
|
||||||
Reference in New Issue
Block a user