-
Notifications
You must be signed in to change notification settings - Fork 56
/
Copy pathDockerfile.jax
129 lines (112 loc) · 4.44 KB
/
Dockerfile.jax
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
# syntax=docker/dockerfile:1-labs

# Image that the builder and mealkit stages start from.
ARG BASE_IMAGE=ghcr.io/nvidia/jax:base

# Upstream sources handed to git-clone.sh, written as "<repository URL>#<ref>".
ARG URLREF_JAX=https://github.com/google/jax.git#main
ARG URLREF_XLA=https://github.com/openxla/xla.git#main
ARG URLREF_FLAX=https://github.com/google/flax.git#main
ARG URLREF_TRANSFORMER_ENGINE=https://github.com/NVIDIA/TransformerEngine.git#main

# Directories the sources above are checked out into.
ARG SRC_PATH_JAX=/opt/jax
ARG SRC_PATH_XLA=/opt/xla
ARG SRC_PATH_FLAX=/opt/flax
ARG SRC_PATH_TRANSFORMER_ENGINE=/opt/transformer-engine

# Build settings: jaxlib output directory and bazel cache passed to build-jax.sh.
ARG BUILD_PATH_JAXLIB=/opt/jaxlibs
ARG BAZEL_CACHE=/tmp

# NOTE(review): GIT_USER_NAME is not referenced directly in this file;
# presumably the helper scripts read it from the environment - confirm.
ARG GIT_USER_NAME="JAX Toolbox"

# No default; exported as an ENV of the same name in the mealkit stage.
ARG BUILD_DATE
###############################################################################
## Build JAX
###############################################################################

FROM ${BASE_IMAGE} AS builder

# Arguments declared before the first FROM are not visible inside a stage
# until they are re-declared, so every one this stage reads is listed here.

# Sources and their checkout directories
ARG URLREF_JAX
ARG SRC_PATH_JAX
ARG URLREF_XLA
ARG SRC_PATH_XLA
ARG URLREF_TRANSFORMER_ENGINE
ARG SRC_PATH_TRANSFORMER_ENGINE

# Build directories
ARG BUILD_PATH_JAXLIB
ARG BAZEL_CACHE

# Git identity (GIT_USER_EMAIL has no default in this file)
ARG GIT_USER_NAME
ARG GIT_USER_EMAIL

# Check out JAX and XLA. The SSH agent and the known_hosts secret are mounted
# for this step only and are not written into the image layer.
RUN --mount=type=ssh \
    --mount=type=secret,id=SSH_KNOWN_HOSTS,target=/root/.ssh/known_hosts \
    <<"EOF" bash -ex
git-clone.sh ${URLREF_JAX} ${SRC_PATH_JAX}
git-clone.sh ${URLREF_XLA} ${SRC_PATH_XLA}
EOF
# Helper scripts for building and testing JAX. These are plain local files, so
# COPY is used: ADD's extra behaviours (archive extraction, URL fetch) are not
# needed here.
COPY build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/

# Install bazelisk under the name "bazel", choosing the release binary that
# matches the architecture reported by dpkg.
# NOTE(review): this downloads the "latest" release with no pinned version or
# checksum, so the build is not reproducible - consider pinning.
RUN ARCH="$(dpkg --print-architecture)" && \
    wget -O /usr/local/bin/bazel https://github.com/bazelbuild/bazelisk/releases/latest/download/bazelisk-linux-${ARCH} && \
    chmod +x /usr/local/bin/bazel
# TODO: move this patch into the manifest
# Plain local file: COPY instead of ADD. The trailing slash makes the
# destination unambiguously a directory, so the patch lands at
# /opt/xla-arm64-neon.patch as referenced below.
COPY xla-arm64-neon.patch /opt/

# Build jaxlib from the checked-out JAX and XLA sources.
RUN build-jax.sh \
    --bazel-cache ${BAZEL_CACHE} \
    --build-path-jaxlib ${BUILD_PATH_JAXLIB} \
    --src-path-jax ${SRC_PATH_JAX} \
    --src-path-xla ${SRC_PATH_XLA} \
    --sm all \
    --xla-arm64-patch /opt/xla-arm64-neon.patch \
    --clean
## Transformer engine: check out source and build wheel
RUN <<"EOF" bash -ex -o pipefail
# --no-cache-dir keeps pip from writing a download cache into this layer,
# replacing the separate "rm -rf ~/.cache/pip" clean-up.
pip install --no-cache-dir ninja
# TransformerEngine now needs JAX at build time
git-clone.sh ${URLREF_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
pushd ${SRC_PATH_TRANSFORMER_ENGINE}
export NVTE_BUILD_THREADS_PER_JOB=8
# Build the wheel into dist/ and drop the intermediate build tree in the same step.
python setup.py bdist_wheel && rm -rf build
# Show the produced artifacts in the build log.
ls "${SRC_PATH_TRANSFORMER_ENGINE}/dist"
EOF
###############################################################################
## Pack jaxlib wheel and various source dirs into a pre-installation image
###############################################################################

# BASE_IMAGE is declared before the first FROM of this file, which makes it
# usable in every FROM line; it does not need to be re-declared here.
FROM ${BASE_IMAGE} AS mealkit

# Re-declare the global arguments this stage reads.
ARG URLREF_FLAX
ARG SRC_PATH_JAX
ARG SRC_PATH_XLA
ARG SRC_PATH_FLAX
ARG SRC_PATH_TRANSFORMER_ENGINE
ARG BUILD_DATE
ARG BUILD_PATH_JAXLIB

# Record the build date in the image environment.
ENV BUILD_DATE=${BUILD_DATE}

# The following environment variables tune performance
# XLA_FLAGS is reset to empty and then extended one flag per instruction.
ENV XLA_FLAGS=""
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false"
ENV NCCL_NVLS_ENABLE=0

# Artifacts produced in the builder stage: jaxlib build output, the JAX and XLA
# source trees, and the bazelisk binary.
COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB}
COPY --from=builder ${SRC_PATH_JAX} ${SRC_PATH_JAX}
COPY --from=builder ${SRC_PATH_XLA} ${SRC_PATH_XLA}
COPY --from=builder /usr/local/bin/bazel /usr/local/bin/bazel
# Preserve the versions of jax and xla
COPY --from=builder /opt/manifest.d/git-clone.yaml /opt/manifest.d/git-clone.yaml

# Plain local files: COPY instead of ADD.
COPY build-jax.sh local_cuda_arch test-jax.sh /usr/local/bin/

# Directory that collects the requirements-*.in files written by later steps.
RUN mkdir -p /opt/pip-tools.d
## Editable installations of jax and jaxlib
RUN <<"EOF" bash -ex
# Walk the entries of the jaxlib build directory with a glob rather than by
# parsing `ls` output, which would split names on whitespace. nullglob makes an
# empty directory yield no iterations instead of a literal "*" entry.
shopt -s nullglob
for component in "${BUILD_PATH_JAXLIB}"/*; do
    echo "-e file://${component}" >> /opt/pip-tools.d/requirements-jax.in
done
# JAX itself, installed editable from its source tree with the k8s extra.
echo "-e file://${SRC_PATH_JAX}[k8s]" >> /opt/pip-tools.d/requirements-jax.in
# Constrain numpy to the 1.x series.
echo "numpy<2.0.0" >> /opt/pip-tools.d/requirements-jax.in
EOF
## Flax
# Check out Flax and list its source tree as an editable requirement.
RUN <<"EOF" bash -ex
git-clone.sh ${URLREF_FLAX} ${SRC_PATH_FLAX}
printf '%s\n' "-e file://${SRC_PATH_FLAX}" >> /opt/pip-tools.d/requirements-flax.in
EOF
# Bring the TransformerEngine checkout over from the builder stage; its dist/
# directory holds the wheel that was built there.
ENV NVTE_FRAMEWORK=jax \
    SRC_PATH_TRANSFORMER_ENGINE=${SRC_PATH_TRANSFORMER_ENGINE}
COPY --from=builder ${SRC_PATH_TRANSFORMER_ENGINE} ${SRC_PATH_TRANSFORMER_ENGINE}
RUN <<"EOF" bash -ex
# Listing the wheel first aborts the step (bash -e) when none is present.
ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl
# Point the requirement at the wheel file by path.
te_wheel=$(ls ${SRC_PATH_TRANSFORMER_ENGINE}/dist/*.whl)
echo "transformer-engine @ file://${te_wheel}" > /opt/pip-tools.d/requirements-te.in
EOF
###############################################################################
## Install primary packages and transitive dependencies
###############################################################################
# Starts from the mealkit stage, so everything staged there is inherited.
FROM mealkit AS final
# pip-finalize.sh is not defined in this file; presumably it installs from the
# requirements-*.in files collected under /opt/pip-tools.d - TODO confirm.
RUN pip-finalize.sh