in src/deep_learning_container.py [0:0]
def parse_args():
    """
    Parse the command-line arguments that identify this container image.

    Only the options declared below are parsed; any other arguments on the
    command line are ignored (``parse_known_args``).

    Return: args, an ``argparse.Namespace`` which contains the parsed input
    arguments ``framework``, ``framework_version`` and ``container_type``.
    For PyTorch, a trailing ``+cpu``, ``+cu<3 digits>`` or ``a0+git<7 chars>``
    suffix is stripped from ``framework_version``. For TensorFlow inference,
    a ``framework_version`` of ``2.12.0`` is rewritten to ``2.12.1``.

    Raises:
        SystemExit: raised by argparse when a required option is missing or a
            value is not one of the allowed choices.
        AssertionError: when ``framework_version`` (after any PyTorch suffix
            stripping) is not of the form X.Y or X.Y.Z, optionally followed by
            ``-rc<digit>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--framework",
        choices=["tensorflow", "mxnet", "pytorch"],
        help="framework of container image.",
        required=True,
    )
    parser.add_argument(
        "--framework-version", help="framework version of container image.", required=True
    )
    parser.add_argument(
        "--container-type",
        choices=["training", "inference"],
        help="What kind of jobs you want to run on container. Either training or inference.",
        required=True,
    )
    args, _unknown = parser.parse_known_args()

    fw_version_pattern = r"\d+(\.\d+){1,2}(-rc\d)?"
    # PT 1.10 and above has +cpu or +cu113 string, so handle accordingly
    if args.framework == "pytorch":
        pt_fw_version_pattern = r"(\d+(\.\d+){1,2}(-rc\d)?)((\+cpu)|(\+cu\d{3})|(a0\+git\w{7}))"
        pt_fw_version_match = re.fullmatch(pt_fw_version_pattern, args.framework_version)
        if pt_fw_version_match:
            # group(1) is the bare X.Y[.Z][-rcN] prefix, without the build suffix.
            args.framework_version = pt_fw_version_match.group(1)

    # Raised explicitly instead of via `assert` so the check is not stripped
    # when Python runs with -O; AssertionError is kept so existing callers
    # still see the same exception type.
    if not re.fullmatch(fw_version_pattern, args.framework_version):
        raise AssertionError(
            f"args.framework_version = {args.framework_version} does not match {fw_version_pattern}\n"
            f"Please specify framework version as X.Y.Z or X.Y."
        )

    # TFS 2.12.1 still uses TF 2.12.0 and breaks the telemetry check as it is checking TF version
    # instead of TFS version. We are forcing the version we want.
    if (
        args.framework == "tensorflow"
        and args.container_type == "inference"
        and args.framework_version == "2.12.0"
    ):
        args.framework_version = "2.12.1"
    return args