mirror of
https://github.com/velocitatem/PHANTOM.git
synced 2026-05-31 08:33:36 +00:00
cleaning up jax bs
This commit is contained in:
@@ -7,36 +7,9 @@ WORKDIR /app
|
||||
COPY docker/trainer.requirements.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir -r /tmp/requirements.txt
|
||||
|
||||
# Optional for JAX-on-GPU workflows.
|
||||
ARG INSTALL_JAX_GPU=false
|
||||
RUN if [ "${INSTALL_JAX_GPU}" = "true" ]; then \
|
||||
pip install --no-cache-dir "jax[cuda12]==0.4.30" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html; \
|
||||
fi
|
||||
|
||||
COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint
|
||||
COPY engine /app/engine
|
||||
|
||||
ENV PYTHONPATH=/app \
|
||||
XLA_PYTHON_CLIENT_PREALLOCATE=false
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
|
||||
|
||||
|
||||
FROM python:3.11-slim AS tpu
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY docker/trainer.requirements.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir -r /tmp/requirements.txt
|
||||
|
||||
RUN pip install --no-cache-dir "jax[tpu]==0.4.30" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
|
||||
|
||||
COPY --chmod=755 docker/trainer-agent-entrypoint.sh /usr/local/bin/trainer-agent-entrypoint
|
||||
COPY engine /app/engine
|
||||
|
||||
ENV PYTHONPATH=/app \
|
||||
PHANTOM_USE_JAX=1 \
|
||||
PHANTOM_DEFAULT_AGENT_ARGS="--jax" \
|
||||
XLA_PYTHON_CLIENT_PREALLOCATE=false
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/trainer-agent-entrypoint"]
|
||||
|
||||
@@ -5,9 +5,3 @@ gymnasium>=0.29.0
|
||||
stable-baselines3>=2.2.0
|
||||
tensorboard>=2.15.0
|
||||
wandb>=0.17.0
|
||||
tensorflow-probability==0.24.0
|
||||
flax==0.10.7
|
||||
optax==0.2.7
|
||||
distrax==0.1.5
|
||||
orbax-checkpoint==0.11.32
|
||||
chex==0.1.90
|
||||
|
||||
Reference in New Issue
Block a user