pai/common/git_utils.py (175 lines of code) (raw):
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import os
import subprocess
import tempfile
import warnings
from typing import Dict, Optional
import six
from six.moves import urllib
from .logging import get_logger
# Module-level logger for this module, created with the package's ``get_logger`` helper.
logger = get_logger(__name__)
def git_clone_repo(git_config: Dict[str, str], source_dir: Optional[str] = None):
    """Git clone the required repo and checkout the required branch and commit.

    This method clones the repo into a new temporary directory, checks out the
    required branch and commit, and returns a dict that contains the updated value
    of ``source_dir``. On success the temporary directory is kept (the returned path
    points into it); if cloning, checkout or the ``source_dir`` check fails, the
    temporary directory is removed before the error is re-raised.

    Example::

        git_config = {
            "repo": "https://github.com/your_repo.git",
            "branch": "master",
            "commit": "xxxxxxx",
            "username": "xxxxxxx",
            "password": "xxxxxxx",
            "token": "xxxxxxx",
        }
        updated_args = git_clone_repo(git_config, source_dir="./train/src/")

    Args:
        git_config (Dict[str, str]): Git configuration used to clone the repo. Including
            ``repo``, ``branch``, ``commit``, ``username``, ``password`` and ``token``.
            The ``repo`` is required. All other fields are optional. ``repo`` specifies
            the Git repository. If you don't provide ``branch``, the remote
            repository's default branch is cloned. If you provide ``commit``, that
            commit is checked out after cloning; otherwise the latest commit of the
            cloned branch is used. ``username``, ``password`` and ``token`` are for
            authentication purpose.
        source_dir (Optional[str], optional): A relative location to a directory in the git
            repo (default: None). If you don't provide this argument, the root directory of
            the git repo is used. If you provide this argument, the source directory must
            exist in the git repo.

    Returns:
        dict: ``{"source_dir": path}`` where ``path`` is ``source_dir`` joined onto the
        clone directory, or the clone directory itself when ``source_dir`` is not
        provided.

    Raises:
        ValueError: If ``git_config`` is invalid, the repo host is not supported, or
            ``source_dir`` does not exist in the cloned repo.
        subprocess.CalledProcessError: If a git command fails.
    """
    import shutil

    _validate_git_config(git_config)
    dest_dir = tempfile.mkdtemp()
    try:
        _build_and_run_clone_command(git_config, dest_dir)
        _checkout_commit(git_config, dest_dir)
        if source_dir:
            resolved_source_dir = os.path.join(dest_dir, source_dir)
            if not os.path.isdir(resolved_source_dir):
                raise ValueError("Source directory does not exist in the repo.")
        else:
            resolved_source_dir = dest_dir
    except BaseException:
        # mkdtemp() leaves cleanup to the caller; nobody receives the path when we
        # fail, so remove the partial clone here instead of leaking it.
        shutil.rmtree(dest_dir, ignore_errors=True)
        raise
    return {"source_dir": resolved_source_dir}
def _validate_git_config(git_config):
"""Validate the git configuration.
Check if ``repo`` is provided and if the values in ``git_config`` are strings.
Args:
git_config (Dict[str, str]): Git configuration to be validated.
Raises:
ValueError: If ``repo`` is not provided or the values in ``git_config`` are not strings.
"""
if "repo" not in git_config:
raise ValueError(
"repo not found in git_config. Please provide a repo for git_config."
)
for key in git_config:
if not isinstance(git_config[key], six.string_types):
raise ValueError(f"'{key}' must be a string.")
def _build_and_run_clone_command(git_config, dest_dir):
"""Build and run the clone command.
If ``repo`` in git_config is valid, build and run the clone command. Otherwise, raise an error.
Args:
git_config (Dict[str, str]): Git configuration used to clone the repo.
dest_dir (str): The destination directory to clone the repo to.
Raises:
ValueError: If ``repo`` provided is not supported.
"""
if git_config["repo"].startswith("https://codeup.aliyun.com/") or git_config[
"repo"
].startswith("git@codeup.aliyun.com"):
_clone_command_for_codeup(git_config, dest_dir)
elif git_config["repo"].startswith("https://github.com/") or git_config[
"repo"
].startswith("git@github.com"):
_clone_command_for_github(git_config, dest_dir)
else:
raise ValueError("repo provided is not supported.")
def _clone_command_for_codeup(git_config, dest_dir):
"""Build and run the clone command for Alibaba Codeup.
If ``repo`` starts with ``https://``, use https to clone the repo. If ``repo`` starts with ``git@``,
use ssh to clone the repo. Otherwise, raise an error.
Args:
git_config (Dict[str, str]): Git configuration used to clone the repo.
dest_dir (str): The destination directory to clone the repo to.
"""
is_ssh = git_config["repo"].startswith("git@")
is_https = git_config["repo"].startswith("https://")
if is_ssh:
_clone_command_for_ssh(git_config, dest_dir)
elif is_https:
_clone_command_for_codeup_https(git_config, dest_dir)
else:
raise ValueError("repo must start with 'https://' or 'git@'.")
def _clone_command_for_github(git_config, dest_dir):
"""Build and run the clone command for GitHub.
If ``repo`` starts with ``https://``, use https to clone the repo. If ``repo`` starts with ``git@``,
use ssh to clone the repo. Otherwise, raise an error.
Args:
git_config (Dict[str, str]): Git configuration used to clone the repo.
dest_dir (str): The destination directory to clone the repo to.
"""
is_ssh = git_config["repo"].startswith("git@")
is_https = git_config["repo"].startswith("https://")
if is_ssh:
_clone_command_for_ssh(git_config, dest_dir)
elif is_https:
_clone_command_for_github_https(git_config, dest_dir)
else:
raise ValueError("repo must start with 'https://' or 'git@'.")
def _clone_command_for_ssh(git_config, dest_dir):
    """Build and run the clone command for a repo addressed by an SSH (``git@``) URL.

    Used for both GitHub and Codeup SSH URLs. Authentication is handled by ``ssh``
    itself, so ``username``, ``password`` and ``token`` in ``git_config`` are not
    used; a warning is emitted if any of them is present.

    When ``branch`` is given without ``commit``, only that branch is shallow-cloned.
    When ``commit`` is given, the full repository is cloned so the commit can be
    checked out afterwards.

    Args:
        git_config (Dict[str, str]): Git configuration used to clone the repo.
        dest_dir (str): The destination directory to clone the repo to.
    """
    if "username" in git_config or "password" in git_config or "token" in git_config:
        warnings.warn(
            "``username``, ``password``, and ``token`` are not used when cloning via SSH."
        )
    # A shallow (--depth 1) clone of ``branch`` only contains the branch tip, so a
    # pinned ``commit`` that is not the tip could not be checked out afterwards.
    # Only shallow-clone the branch when no commit is pinned.
    branch = None if "commit" in git_config else git_config.get("branch")
    _clone_command(git_config["repo"], dest_dir, branch=branch)
def _clone_command_for_github_https(git_config, dest_dir):
    """Build and run the clone command for GitHub via HTTPS.

    Credentials are embedded in the clone URL with this precedence:

    * ``token``: used on its own; ``username``/``password`` are ignored with a warning.
    * ``username`` and ``password``: used together.
    * only one of ``username``/``password``: a warning is emitted and the repo is
      cloned without credentials.
    * no credentials: a warning is emitted and the repo is cloned anonymously.

    When ``branch`` is given without ``commit``, only that branch is shallow-cloned.
    When ``commit`` is given, the full repository is cloned so the commit can be
    checked out afterwards.

    Args:
        git_config (Dict[str, str]): Git configuration used to clone the repo.
        dest_dir (str): The destination directory to clone the repo to.
    """
    repo_url = git_config["repo"]
    updated_url = repo_url
    if "token" in git_config:
        if "username" in git_config or "password" in git_config:
            warnings.warn(
                "``username`` and ``password`` are not used when ``token`` is provided."
            )
        updated_url = _update_url_with_token(repo_url, git_config["token"])
    elif "username" in git_config and "password" in git_config:
        updated_url = _update_url_with_username_and_password(
            repo_url, git_config["username"], git_config["password"]
        )
    elif "username" in git_config or "password" in git_config:
        warnings.warn(
            "``username`` and ``password`` need to be provided together. Credentials provided in git config will be ignored"
        )
    else:
        warnings.warn(
            "No credentials provided. If the repo is private, cloning will fail."
        )
    # A shallow (--depth 1) clone of ``branch`` only contains the branch tip, so a
    # pinned ``commit`` that is not the tip could not be checked out afterwards.
    # Only shallow-clone the branch when no commit is pinned.
    branch = None if "commit" in git_config else git_config.get("branch")
    _clone_command(updated_url, dest_dir, branch=branch)
def _clone_command_for_codeup_https(git_config, dest_dir):
    """Build and run the clone command for Codeup via HTTPS.

    Credentials are embedded in the clone URL with this precedence:

    * ``username`` and ``token``: the token is used as the password; ``password`` is
      ignored with a warning.
    * ``username`` and ``password``: used together.
    * any other partial combination: a warning is emitted and the repo is cloned
      without credentials.
    * no credentials: a warning is emitted and the repo is cloned anonymously.

    The repo is cloned exactly once into ``dest_dir``. When ``branch`` is given
    without ``commit``, only that branch is shallow-cloned. Otherwise the full
    repository is cloned so a pinned commit can be checked out afterwards.

    Args:
        git_config (Dict[str, str]): Git configuration used to clone the repo.
        dest_dir (str): The destination directory to clone the repo to.
    """
    repo_url = git_config["repo"]
    updated_url = repo_url
    if "username" in git_config and "token" in git_config:
        if "password" in git_config:
            warnings.warn("``password`` is not used when ``token`` is provided.")
        updated_url = _update_url_with_username_and_password(
            repo_url, git_config["username"], git_config["token"]
        )
    elif "username" in git_config and "password" in git_config:
        updated_url = _update_url_with_username_and_password(
            repo_url, git_config["username"], git_config["password"]
        )
    elif "username" in git_config or "password" in git_config or "token" in git_config:
        warnings.warn(
            "``username`` and ``password``/``token`` of Codeup account need to be "
            "provided together. Credentials provided in git config will be ignored."
        )
    else:
        warnings.warn(
            "No credentials provided. If the repo is private, cloning will fail."
        )
    # Previously the branch clone ran and then a second, unconditional clone into
    # the same (now non-empty) dest_dir followed, which git rejects. Decide the
    # branch first and clone once.
    if "commit" not in git_config and "branch" in git_config:
        # Shallow clone of the requested branch only.
        branch = git_config["branch"]
    else:
        # Full clone so that a pinned commit is reachable for checkout.
        branch = None
    _clone_command(updated_url, dest_dir, branch=branch)
def _clone_command(repo_url, dest_dir, branch=None):
"""Build and run the clone command.
Clone the repo to ``dest_dir``.
Args:
repo_url (str): The URL of the repo to be cloned.
dest_dir (str): The destination directory to clone the repo to.
branch (str): The specific branch to be cloned.
Raises:
ValueError: If ``repo_url`` does not start with ``https://`` or ``git@``.
"""
my_env = os.environ.copy()
if branch:
# shallow clone the specific branch/tag
git_command = [
"git",
"clone",
"-c",
"advice.detachedHead=false", # disable detached head warning
"--depth",
"1",
"--branch",
branch,
repo_url,
dest_dir,
]
else:
git_command = ["git", "clone", repo_url, dest_dir]
if repo_url.startswith("git@"):
with tempfile.NamedTemporaryFile() as sshnoprompt:
with open(sshnoprompt.name, "w") as write_pipe:
write_pipe.write("ssh -oBatchMode=yes $@")
os.chmod(sshnoprompt.name, 0o511)
my_env["GIT_SSH"] = sshnoprompt.name
subprocess.check_call(git_command, env=my_env)
elif repo_url.startswith("https://"):
my_env["GIT_TERMINAL_PROMPT"] = "0"
subprocess.check_call(git_command, env=my_env)
else:
raise ValueError("repo must start with 'https://' or 'git@'.")
def _update_url_with_token(repo_url, token):
"""Update the URL with token.
Update the URL with token. If the URL already contains token, return the URL as is.
Args:
repo_url (str): The URL of the repo to be cloned.
token (str): The token used to clone the repo.
Returns:
str: The updated URL for the git clone command.
"""
index = len("https://")
if repo_url.find(token) == index:
return repo_url
updated_url = repo_url[:index] + token + "@" + repo_url[index:]
return updated_url
def _update_url_with_username_and_password(repo_url, username, password):
"""Update the URL with username and password.
Update the URL with username and password.
Args:
repo_url (str): The URL of the repo to be cloned.
username (str): The username used to clone the repo.
password (str): The password used to clone the repo.
Returns:
str: The updated URL for the git clone command.
"""
index = len("https://")
password = urllib.parse.quote(password)
updated_url = repo_url[:index] + username + ":" + password + "@" + repo_url[index:]
return updated_url
def _checkout_commit(git_config, dest_dir):
"""Checkout the commit specified in ``git_config``.
If ``commit`` is specified in ``git_config``, checkout the commit.
Args:
git_config (Dict[str, str]): Git configuration used to clone the repo.
dest_dir (str): The destination directory to clone the repo to.
"""
if "branch" in git_config and "commit" in git_config:
logger.warning(
"commit and branch are both specified in git config, ignore branch."
)
if "commit" in git_config:
subprocess.check_call(
args=[
"git",
"-c",
"advice.detachedHead=false", # disable detached head warning
"checkout",
git_config["commit"],
],
cwd=str(dest_dir),
)