in assets/training/model_evaluation/src_distributed/model_prediction.py
def main():
    """Initialize the text-generation-inference server and client and run distributed prediction."""
    extra_params = {}
    logger.info("Init Start.")

    parser = ArgumentParser()
    # Inputs
    parser.add_argument("--mlflow_model", type=str, dest="mlflow_model", required=True)
    parser.add_argument("--parameters", type=str, dest="parameters", required=False, default="{}")
    parser.add_argument("--task", type=str, dest=ArgumentLiterals.TASK, required=True, choices=TEXT_TOKEN_TASKS)
    parser.add_argument("--data", type=str, dest=ArgumentLiterals.DATA, required=True)
    parser.add_argument("--label-column-name", type=lambda x: x.split(","),
                        dest=ArgumentLiterals.LABEL_COLUMN_NAME, required=False, default=None)
    parser.add_argument("--input-column-names",
                        type=lambda x: [i.strip() for i in x.split(",") if i and not i.isspace()],
                        dest=ArgumentLiterals.INPUT_COLUMN_NAMES, required=False, default=None)
    parser.add_argument("--batch-size", type=int, dest=ArgumentLiterals.BATCH_SIZE, required=False, default=None)
parser.add_argument("--predictions", type=str, dest=ArgumentLiterals.PREDICTIONS, required=True)
parser.add_argument("--ground-truth", type=str, dest=ArgumentLiterals.GROUND_TRUTHS, required=True)
parser.add_argument("--performance-metadata", type=str, dest=ArgumentLiterals.PERFORMANCE_METADATA,
required=False, default=None)
parser.add_argument("--prediction-probabilities", type=str, dest=ArgumentLiterals.PREDICTION_PROBABILITIES,
required=False, default=None)
args, unknown_args = parser.parse_known_args()
logger.info(f"Distributed Type: {distributed_state.distributed_type}")
    try:
        tensor_parallel, num_replicas = get_smart_defaults(args.mlflow_model)
    except Exception as e:
        exception = get_azureml_exception(ModelLoadingException, ModelPredictionInternalError, e,
                                          wrap_azureml_ex=False, error=repr(e))
        log_traceback(exception, logger)
        raise exception
    logger.info(f"Setting Num Replicas to: {num_replicas} and Tensor Parallel to {tensor_parallel}")
    os.environ["NUM_REPLICAS"] = str(num_replicas)
    os.environ["TENSOR_PARALLEL"] = str(tensor_parallel)
    data_path = args.data
    logger.info(f"Torch Current Device Count: {torch.cuda.device_count()}")
    logger.info(f"Got Params: {args.parameters}")
    extra_params.update(json.loads(args.parameters))
    logger.info(f"Got Model Path: {args.mlflow_model}")
    task_type = args.task
    input_column_names, label_column_name, extra_y_test_cols = validate_and_get_columns(vars(args))
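    # Resolve model, config and tokenizer paths inside the MLflow model directory
    # and bring up the inference engine.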
    try:
        _init_cuda_visible_devices()

        abs_mlmodel_path = os.path.join(args.mlflow_model, ModelPath.MLMODEL_PATH)
        mlmodel = {}
        if abs_mlmodel_path and os.path.exists(abs_mlmodel_path):
            with open(abs_mlmodel_path) as f:
                mlmodel = yaml.safe_load(f)
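        # Prefer the current MLflow model layout; fall back to the deprecated layout otherwise.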
        if os.path.exists(os.path.join(args.mlflow_model, ModelPath.DEFAULT_MLFLOW_MODEL_PATH)):
            model_path = os.path.join(
                args.mlflow_model,
                ModelPath.DEFAULT_MLFLOW_MODEL_PATH,
            )
            config_path = os.path.join(model_path, "config.json")
            tokenizer_path = os.path.join(
                args.mlflow_model, ModelPath.DEFAULT_TOKENIZER_PATH
            )
        else:
            model_path = os.path.join(args.mlflow_model, ModelPath.DEPRECATED_MLFLOW_MODEL_PATH)
            config_path = os.path.join(
                args.mlflow_model, ModelPath.DEPRECATED_MLFLOW_CONFIG_PATH, "config.json"
            )
            if not os.path.exists(config_path):
                config_path = os.path.join(model_path, "config.json")
            tokenizer_path = os.path.join(
                args.mlflow_model, ModelPath.DEPRECATED_MLFLOW_TOKENIZER_PATH
            )
            if not os.path.exists(tokenizer_path):
                tokenizer_path = model_path

        inference_config = None
        if os.path.exists(os.path.join(args.mlflow_model, ModelPath.INFERENCE_CONFIG_PATH)):
            inference_config = os.path.join(args.mlflow_model, ModelPath.INFERENCE_CONFIG_PATH)
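        # Build engine/task configs and default generator settings from the MLmodel metadata
        # and the resolved paths.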
        engine_config, task_config, default_generator_configs, task_type, model_info = build_configs_from_model(
            mlmodel,
            model_path,
            config_path,
            tokenizer_path,
            inference_config
        )

        config = {
            "engine": engine_config,
            "task": task_config,
        }
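        # Consume the per-sample token/character count flags and drop them from extra_params.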
        enable_character_counts, enable_token_counts = False, False
        if extra_params.get("token_count_per_sample", False):
            enable_token_counts = True
            extra_params.pop("token_count_per_sample")
        if extra_params.get("char_count_per_sample", False):
            enable_character_counts = True
            extra_params.pop("char_count_per_sample")
        tokenizer = None
        if (task_type in TEXT_TOKEN_TASKS and enable_token_counts) or (
                task_type in (SupportedTask.CHAT_COMPLETION, TaskType.CONVERSATIONAL)):
            tokenizer = load_tokenizer(
                engine_config["tokenizer"],
                engine_config["ml_model_info"].get("hf_tokenizer_class", "AutoTokenizer")
            )
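        # Start the scoring engine (server + client) with the assembled config.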
        g_fmscorer = FMScore(config)
        g_fmscorer.init()

        if os.environ.get("LOGGING_WORKER_ID", "") == str(os.getpid()):
            for k, v in os.environ.items():
                logger.info(f"env: {k} = {v}")
            logger.info(
                f"updated default_generator_configs: "
                f"{default_generator_configs}"
            )
    except Exception as e:
        exception = get_azureml_exception(ModelLoadingException, BadModel, e, error=repr(e))
        log_traceback(exception, logger)
        raise exception
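    # Load the evaluation data; failures are wrapped as BadInputData.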
    try:
        data = load_data(task_type, data_path, label_column_name, input_column_names,
                         extra_y_test_cols, args.batch_size)
    except Exception as e:
        exception = get_azureml_exception(DataLoaderException, BadInputData, e, error=repr(e))
        log_traceback(exception, logger)
        raise exception

    full_data = [(x, y) for x, y in data]
    logger.info(f"Dataset size: {len(full_data)}")
    predictor = Predictor(g_fmscorer, task_type, extra_params, num_replicas, label_column_name,
                          tokenizer, extra_y_test_cols)
    collated_res = [{} for _ in range(distributed_state.num_processes)]
    with distributed_state.split_between_processes(full_data) as proc_data:
        y_pred_proc, y_test_proc, y_perf_proc, y_pred_proba = predictor.predict(proc_data)
        proc_res = {
            "predictions": y_pred_proc,
            "ground_truth": y_test_proc,
            "perf": y_perf_proc,
            "pred_probas": y_pred_proba,
        }
        dist.all_gather_object(object_list=collated_res, obj=proc_res)

    logger.info("Waiting for all processes.....")
    distributed_state.wait_for_everyone()
    logger.info(f"Collated Results Lengths: {[len(i) for i in collated_res]}")
    y_pred_df, y_test_df, y_perf_df, y_pred_proba_df = _gather_predictions(collated_res)
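    # For non-chat tasks, give the prediction and ground-truth frames explicit column names.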
    if task_type not in (SupportedTask.CHAT_COMPLETION, TaskType.CONVERSATIONAL):
        y_pred_df.columns = ["predictions"]
        ground_truth_columns = [label_column_name]
        if extra_y_test_cols is not None:
            ground_truth_columns += extra_y_test_cols
        y_test_df.columns = ground_truth_columns[:]
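    # Only the main process writes the output artifacts.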
    if distributed_state.is_main_process:
        y_pred_df.to_json(args.predictions, orient="records", lines=True)
        y_test_df.to_json(args.ground_truths, orient="records", lines=True)
        y_perf_df.to_json(args.performance_metadata, orient="records", lines=True)
        y_pred_proba_df.to_json(args.prediction_probabilities, orient="records", lines=True)
    return