FROM quay.io/pypa/manylinux_2_28_aarch64 as base

# Cuda ARM build needs gcc 11
ARG DEVTOOLSET_VERSION=11

# Language variables
ENV LC_ALL=en_US.UTF-8
ENV LANG=en_US.UTF-8
ENV LANGUAGE=en_US.UTF-8

# Installed needed OS packages. This is to support all
# the binary builds (torch, vision, audio, text, data)
RUN yum -y install epel-release
RUN yum -y update
RUN yum install -y \
  autoconf \
  automake \
  bison \
  bzip2 \
  curl \
  diffutils \
  file \
  git \
  make \
  patch \
  perl \
  unzip \
  util-linux \
  wget \
  which \
  xz \
  yasm \
  less \
  zstd \
  libgomp \
  sudo \
  gcc-toolset-${DEVTOOLSET_VERSION}-toolchain

# Ensure the expected devtoolset is used
ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH
ENV LD_LIBRARY_PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib64:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/lib:$LD_LIBRARY_PATH

# git236+ would refuse to run git commands in repos owned by other users
# Which causes version check to fail, as pytorch repo is bind-mounted into the image
# Override this behaviour by treating every folder as safe
# For more details see https://github.com/pytorch/pytorch/issues/78659#issuecomment-1144107327
RUN git config --global --add safe.directory "*"


FROM base as openssl
# Install openssl (this must precede `build python` step)
# (In order to have a proper SSL module, Python is compiled
# against a recent openssl [see env vars above], which is linked
# statically. We delete openssl afterwards.)
ADD ./common/install_openssl.sh install_openssl.sh
RUN bash ./install_openssl.sh && rm install_openssl.sh
ENV SSL_CERT_FILE=/opt/_internal/certs.pem

FROM openssl as final
# remove unncessary python versions
RUN rm -rf /opt/python/cp26-cp26m /opt/_internal/cpython-2.6.9-ucs2
RUN rm -rf /opt/python/cp26-cp26mu /opt/_internal/cpython-2.6.9-ucs4
RUN rm -rf /opt/python/cp33-cp33m /opt/_internal/cpython-3.3.6
RUN rm -rf /opt/python/cp34-cp34m /opt/_internal/cpython-3.4.6

FROM base as cuda
ARG BASE_CUDA_VERSION
# Install CUDA
ADD ./common/install_cuda_aarch64.sh install_cuda_aarch64.sh
RUN bash ./install_cuda_aarch64.sh ${BASE_CUDA_VERSION} && rm install_cuda_aarch64.sh

FROM base as magma
ARG BASE_CUDA_VERSION
# Install magma
ADD ./common/install_magma.sh install_magma.sh
RUN bash ./install_magma.sh ${BASE_CUDA_VERSION} && rm install_magma.sh

FROM base as nvpl
# Install nvpl
ADD ./common/install_nvpl.sh install_nvpl.sh
RUN bash ./install_nvpl.sh && rm install_nvpl.sh

FROM final as cuda_final
ARG BASE_CUDA_VERSION
RUN rm -rf /usr/local/cuda-${BASE_CUDA_VERSION}
COPY --from=cuda     /usr/local/cuda-${BASE_CUDA_VERSION}  /usr/local/cuda-${BASE_CUDA_VERSION}
COPY --from=magma    /usr/local/cuda-${BASE_CUDA_VERSION}  /usr/local/cuda-${BASE_CUDA_VERSION}
COPY --from=nvpl /opt/nvpl/lib/  /usr/local/lib/
COPY --from=nvpl /opt/nvpl/include/  /usr/local/include/
RUN ln -sf /usr/local/cuda-${BASE_CUDA_VERSION} /usr/local/cuda
ENV PATH=/usr/local/cuda/bin:$PATH
