in ts_scripts/marsgen.py [0:0]
def generate_mars(mar_config=MAR_CONFIG_FILE_PATH, model_store_dir=MODEL_STORE_DIR):
    """
    Generate a .mar file for every model listed in a mar config file.

    By default generate_mars reads ts_scripts/mar_config.json and outputs mar files in dir model_store_gen
    - mar_config.json defines a list of models' mar file parameters. They are:
        - "model_name": model name
        - "version": model version
        - "model_file": the path of file model.py
        - "serialized_file_remote": file name of the .pth or .pt file; it is downloaded from
          https://download.pytorch.org/models/<name> into model_store_dir, unless
          "gen_scripted_file_path" is given, in which case that script is run instead
        - "serialized_file_local": the path of file .pth or .pt
        - "gen_scripted_file_path": the python script path of building .pt file
        - "handler": handler can be either default handler or handler path
        - "extra_files": the paths of extra files
        - "runtime", "archive_format", "requirements_file": optional model-archiver options
        - "export_path": output directory for the mar file (defaults to model_store_dir)
    Note: To generate .pt file, "serialized_file_remote" and "gen_scripted_file_path" must be provided

    The working directory is switched to REPO_ROOT while the config is read and the archives
    are built. It is restored before returning, including when an error is raised.
    ``mar_set`` is cleared first and then receives the file name of every mar built
    successfully. A failed model-archiver run is reported and skipped rather than raised.

    Args:
        mar_config: path of the JSON config; a relative path is resolved against REPO_ROOT.
        model_store_dir: where remote serialized files are placed; also the default export path.
    """
    print(f"## Starting generate_mars, mar_config:{mar_config}, model_store_dir:{model_store_dir}\n")
    mar_set.clear()
    cwd = os.getcwd()
    os.chdir(REPO_ROOT)
    try:
        with open(mar_config) as f:
            models = json.load(f)
        for model in models:
            _generate_single_mar(model, model_store_dir)
    finally:
        # Restore the caller's cwd even if a download, script run or config lookup raises.
        os.chdir(cwd)


def _prepare_serialized_file(model, model_store_dir):
    """Return the serialized-weights path for one config entry, producing the remote file first."""
    remote_name = model.get("serialized_file_remote")
    if not remote_name:
        # Local (or absent) weights: use the configured path as-is, None when unset/empty.
        return model.get("serialized_file_local") or None

    gen_script = model.get("gen_scripted_file_path")
    if gen_script:
        import sys

        # Run the script with the current interpreter rather than whatever "python" is on PATH.
        # The returned path below assumes the script writes <model_store_dir>/<serialized_file_remote>.
        result = subprocess.run([sys.executable, gen_script])
        if result.returncode != 0:
            print(f"## {gen_script} exited with code {result.returncode}\n")
    else:
        url = "https://download.pytorch.org/models/{}".format(remote_name)
        urllib.request.urlretrieve(url, f"{model_store_dir}/{remote_name}")
    return os.path.join(model_store_dir, remote_name)


def _generate_single_mar(model, model_store_dir):
    """Archive one config entry into a mar file, then delete any serialized file fetched for it."""
    serialized_file_path = _prepare_serialized_file(model, model_store_dir)
    try:
        # Optional keys that are missing or empty are passed as None, matching the builder's defaults.
        cmd = model_archiver_command_builder(
            model["model_name"],
            model["version"],
            model["model_file"],
            serialized_file_path,
            model.get("handler") or None,
            model.get("extra_files") or None,
            model.get("runtime") or None,
            model.get("archive_format") or None,
            model.get("requirements_file") or None,
            model.get("export_path") or model_store_dir,
        )
        print(f"## In directory: {os.getcwd()} | Executing command: {cmd}\n")
        try:
            subprocess.check_call(cmd, shell=True)
        except subprocess.CalledProcessError as exc:
            print("## {} creation failed !, error: {}\n".format(model["model_name"], exc))
        else:
            marfile = "{}.mar".format(model["model_name"])
            print("## {} is generated.\n".format(marfile))
            mar_set.add(marfile)
    finally:
        # Remote serialized files are temporary inputs; remove them even if archiving raised.
        if model.get("serialized_file_remote") and os.path.exists(serialized_file_path):
            os.remove(serialized_file_path)