def parse_args()

in src/deep_learning_container.py [0:0]


def parse_args():
    """
    Parsing function to parse input arguments describing the container image.

    Arguments this parser does not recognize are ignored (``parse_known_args``).

    The framework version is normalized before it is returned:
      * For PyTorch, a trailing build suffix (``+cpu``, ``+cuNNN`` or ``a0+git<7 chars>``)
        is stripped, e.g. ``1.13.1+cu117`` -> ``1.13.1``.
      * For TensorFlow inference, ``2.12.0`` is rewritten to ``2.12.1`` (see comment below).

    Return: args, an ``argparse.Namespace`` which contains the parsed input arguments
        (``framework``, ``framework_version``, ``container_type``).
    Raises:
        AssertionError: if the (normalized) framework version is not of the form
            X.Y or X.Y.Z, optionally followed by a ``-rcN`` suffix.
        SystemExit: raised by argparse when a required argument is missing or a
            value is not among the allowed choices.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--framework",
        choices=["tensorflow", "mxnet", "pytorch"],
        help="framework of container image.",
        required=True,
    )
    parser.add_argument(
        "--framework-version", help="framework version of container image.", required=True
    )
    parser.add_argument(
        "--container-type",
        choices=["training", "inference"],
        help="What kind of jobs you want to run on container. Either training or inference.",
        required=True,
    )

    args, _unknown = parser.parse_known_args()

    # X.Y or X.Y.Z, with an optional release-candidate suffix of any length (-rc1, -rc12, ...).
    fw_version_pattern = r"\d+(\.\d+){1,2}(-rc\d+)?"

    # PT 1.10 and above has +cpu or +cu113 string, so handle accordingly
    if args.framework == "pytorch":
        pt_fw_version_pattern = r"(\d+(\.\d+){1,2}(-rc\d+)?)((\+cpu)|(\+cu\d{3})|(a0\+git\w{7}))"
        pt_fw_version_match = re.fullmatch(pt_fw_version_pattern, args.framework_version)
        if pt_fw_version_match:
            # Group 1 is the base version without the build suffix.
            args.framework_version = pt_fw_version_match.group(1)

    # Explicit check rather than an `assert` statement: assertions are removed when
    # Python runs with -O, which would silently skip this validation. AssertionError
    # is still raised so callers see the same exception type as before.
    if not re.fullmatch(fw_version_pattern, args.framework_version):
        raise AssertionError(
            f"args.framework_version = {args.framework_version} does not match {fw_version_pattern}\n"
            f"Please specify framework version as X.Y.Z or X.Y."
        )

    # TFS 2.12.1 still uses TF 2.12.0 and breaks the telemetry check as it is checking TF version
    # instead of TFS version. WE are forcing the version we want.
    if (
        args.framework == "tensorflow"
        and args.container_type == "inference"
        and args.framework_version == "2.12.0"
    ):
        args.framework_version = "2.12.1"

    return args