dags/sparsity_diffusion_devx/configs/project_bite_config.py (114 lines of code) (raw):
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to construct configs for solutionsteam_jax_bite DAG."""
import datetime
from typing import Tuple, Optional
from dags.common import test_owner
from xlml.apis import gcp_config, metric_config, task, test_config
from dags import gcs_bucket
from dags.sparsity_diffusion_devx.configs import common
from dags.common.vm_resource import TpuVersion, Project
from airflow.models.taskmixin import DAGNode
# Value of the SPARSITY_DIFFUSION_DEVX team enum member; used as the first
# path component of the `gcs_subfolder` given to the test configs built in
# this module.
GCS_SUBFOLDER_PREFIX = test_owner.Team.SPARSITY_DIFFUSION_DEVX.value
def set_up_axlearn(pinned_version: Optional[str]) -> Tuple[str, ...]:
  """Build the shell commands that clone and install AXLearn.

  Args:
    pinned_version: Git revision of the AXLearn repository to hard-reset to
      after cloning. When falsy (`None` or an empty string) no reset is
      issued and the clone stays at the revision `git clone` checked out.

  Returns:
    A tuple of shell command strings, in order: `common.UPGRADE_PIP`,
    `common.UPGRADE_SETUPTOOLS`, `common.UPGRADE_PACKAGING`, the AXLearn
    clone, the optional reset to `pinned_version`, the AXLearn install, and
    finally every command returned by `common.set_up_nightly_jax()`.
  """
  # When no version is pinned this slot stays an empty string, so the
  # returned tuple has the same layout in both cases.
  reset_version = (
      f"cd axlearn && git reset --hard {pinned_version} && cd .."
      if pinned_version
      else ""
  )

  return (
      common.UPGRADE_PIP,
      common.UPGRADE_SETUPTOOLS,
      common.UPGRADE_PACKAGING,
      "git clone https://github.com/apple/axlearn.git",
      reset_version,
      "python -m pip install ./axlearn[core]",
      *common.set_up_nightly_jax(),
  )
def get_bite_tpu_config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    runtime_version: str,
    model_config: str,
    time_out_in_min: int,
    task_owner: str,
    is_tpu_reserved: bool = False,
    pinned_version: Optional[str] = None,
    project_name: Optional[Project] = Project.CLOUD_ML_AUTO_SOLUTIONS.value,
    network: str = "default",
    subnetwork: str = "default",
):
  """Assemble the queued-resource task that runs an AXLearn trainer on TPU.

  Args:
    tpu_version: Forwarded as `version` to `test_config.Tpu`.
    tpu_cores: Forwarded as `cores` to `test_config.Tpu`.
    tpu_zone: Zone used for the GCP config.
    runtime_version: Forwarded as `runtime_version` to `test_config.Tpu`.
    model_config: Value of the trainer's `--config` flag; also the suffix of
      the test name.
    time_out_in_min: Test timeout, in minutes.
    task_owner: Forwarded as `task_owner` to `test_config.TpuVmTest`.
    is_tpu_reserved: Forwarded as `reserved` to `test_config.Tpu`.
    pinned_version: Optional AXLearn git revision handed to
      `set_up_axlearn`. When truthy, the test name gets a `pinned_` infix.
    project_name: Project for the GCP config.
      NOTE(review): annotated as `Optional[Project]`, but the default is the
      enum member's `.value` — confirm which form callers should pass.
    network: Forwarded as `network` to `test_config.Tpu`.
    subnetwork: Forwarded as `subnetwork` to `test_config.Tpu`.

  Returns:
    The result of `task.run_queued_resource_test` for the assembled test and
    GCP configs.
  """
  name_prefix = "bite_pinned_" if pinned_version else "bite_"

  # One shell command: enter the AXLearn checkout and launch the trainer
  # with its flags separated by single spaces.
  trainer_cmd = " ".join((
      "cd axlearn && python -m axlearn.common.launch_trainer_main",
      "--module=text.gpt.c4_trainer",
      f"--config={model_config}",
      f"--trainer_dir={metric_config.SshEnvVars.GCS_OUTPUT.value}",
      f"--data_dir={gcs_bucket.AXLEARN_DIR}",
      "--jax_backend=tpu",
  ))

  tpu_spec = test_config.Tpu(
      version=tpu_version,
      cores=tpu_cores,
      runtime_version=runtime_version,
      reserved=is_tpu_reserved,
      network=network,
      subnetwork=subnetwork,
  )
  bite_test = test_config.TpuVmTest(
      tpu_spec,
      test_name=f"{name_prefix}{model_config}",
      set_up_cmds=set_up_axlearn(pinned_version),
      run_model_cmds=(trainer_cmd,),
      timeout=datetime.timedelta(minutes=time_out_in_min),
      task_owner=task_owner,
      gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/jax",
  )

  bite_gcp = gcp_config.GCPConfig(
      project_name=project_name,
      zone=tpu_zone,
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )

  return task.run_queued_resource_test(
      task_test_config=bite_test,
      task_gcp_config=bite_gcp,
  )
def get_bite_tpu_unittests_config(
    tpu_version: TpuVersion,
    tpu_cores: int,
    tpu_zone: str,
    runtime_version: str,
    time_out_in_min: int,
    task_owner: str,
    is_tpu_reserved: bool = False,
    pinned_version: Optional[str] = None,
):
  """Assemble the queued-resource task that runs AXLearn unit tests on TPU.

  The set-up commands write a Dockerfile and a test script, then build the
  `ml-auto-solutions/tpu_unittests` image. The image clones AXLearn,
  installs it with the `core,dev,gcp` extras and installs pre-release
  libtpu/jaxlib plus JAX from GitHub. The run command starts that image and
  executes the test script, which prints JAX environment info and runs
  pytest on `axlearn/common/flash_attention/`.

  Args:
    tpu_version: Forwarded as `version` to `test_config.Tpu`.
    tpu_cores: Forwarded as `cores` to `test_config.Tpu`.
    tpu_zone: Zone used for the GCP config.
    runtime_version: Forwarded as `runtime_version` to `test_config.Tpu`.
    time_out_in_min: Test timeout, in minutes.
    task_owner: Forwarded as `task_owner` to `test_config.TpuVmTest`.
    is_tpu_reserved: Forwarded as `reserved` to `test_config.Tpu`.
    pinned_version: Optional AXLearn git revision. When truthy, the image
      hard-resets the cloned repository to it before installing; when falsy,
      the Dockerfile contains no reset step.

  Returns:
    The result of `task.run_queued_resource_test` for the assembled test and
    GCP configs.
  """
  # Optional Dockerfile step that pins the AXLearn checkout. It carries its
  # own trailing newline and is empty when nothing is pinned, so the
  # Dockerfile gains no blank line in the unpinned case.
  pin_axlearn_step = (
      f"RUN git reset --hard {pinned_version}\n" if pinned_version else ""
  )

  unittest_setupcmds = (
      # Create the Dockerfile for the test image.
      f"""cat > Dockerfile_CI <<EOF
FROM python:3.10-slim
WORKDIR /workspace
COPY run_tpu_tests.sh /workspace/
RUN apt update -y
RUN apt install -y git
RUN git clone https://github.com/apple/axlearn.git
WORKDIR /workspace/axlearn
{pin_axlearn_step}RUN pip install --upgrade pip
RUN pip install -e '.[core,dev,gcp]'
RUN pip install grain
RUN pip install google-cloud-aiplatform
RUN pip install -U --pre libtpu-nightly -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
RUN pip install --pre -U jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
RUN pip install git+https://github.com/google/jax
RUN pip freeze
EOF
""",
      # Create the script that runs the tests inside the container. It
      # includes a basic sanity-check Python snippet that prints TPU
      # environment info for reference.
      """cat > run_tpu_tests.sh <<EOF
set -x
echo '#### Starting TPU JAX Tests'
pip freeze
JAX_PLATFORMS='tpu' python -c 'import jax; jax.print_environment_info() ; print(f"Global device count: {jax.device_count()}")'
cd /workspace/axlearn
pytest --no-header -v axlearn/common/flash_attention/
EOF
""",
      "chmod +x run_tpu_tests.sh",
      "sudo docker build -f Dockerfile_CI -t ml-auto-solutions/tpu_unittests .",
  )

  # Run the unit tests inside the container. The memlock ulimit is lifted
  # because it is required to mmap TPUs inside docker (default limit is
  # 8192).
  # NOTE(review): the container is started with `sudo docker run
  # --privileged` and no `--user` flag; confirm whether the tests are meant
  # to run as a non-root user.
  unittest_runcmds = (
      "echo '#### Start docker image - tpu_unittests'",
      "sudo docker run --network=host --privileged --ulimit memlock=-1:-1 ml-auto-solutions/tpu_unittests /bin/bash -c '/workspace/run_tpu_tests.sh'",
  )

  job_gcp_config = gcp_config.GCPConfig(
      project_name=Project.CLOUD_ML_AUTO_SOLUTIONS.value,
      zone=tpu_zone,
      dataset_name=metric_config.DatasetOption.XLML_DATASET,
  )

  tpu_unittests_test_config = test_config.TpuVmTest(
      test_config.Tpu(
          version=tpu_version,
          cores=tpu_cores,
          runtime_version=runtime_version,
          reserved=is_tpu_reserved,
      ),
      test_name="bite_unittests",
      set_up_cmds=unittest_setupcmds,
      run_model_cmds=unittest_runcmds,
      timeout=datetime.timedelta(minutes=time_out_in_min),
      task_owner=task_owner,
      gcs_subfolder=f"{GCS_SUBFOLDER_PREFIX}/jax",
  )

  return task.run_queued_resource_test(
      task_test_config=tpu_unittests_test_config,
      task_gcp_config=job_gcp_config,
  )