in assets/training/distillation/src/generate_data_preprocess.py [0:0]
def data_import(args: Namespace):
    """Validate inputs, register a workspace connection for the teacher endpoint, and preprocess data.

    Steps performed (in order):
      1. Validate train/validation file formats.
      2. Resolve the teacher endpoint URL/key (from the named endpoint if given).
      3. Validate inference parameters (top_p, temperature in [0, 1]).
      4. Create a workspace ServerlessConnection for batch scoring and write
         its details to ``args.batch_config_connection`` as JSON.
      5. Run data preprocessing to emit generated payload files.

    Args:
        args: Parsed command-line arguments carrying file paths, teacher-model
            endpoint details, optional inference parameters, and output paths.

    Raises:
        Exception: if the workspace MLClient cannot be created, the endpoint
            URL/key are missing, inference parameters are out of range, or the
            workspace connection cannot be created.
    """
    train_file_path = args.train_file_path
    validation_file_path = args.validation_file_path
    generated_train_payload_path = args.generated_train_payload_path
    generated_validation_payload_path = args.generated_validation_payload_path
    teacher_model_endpoint_name = args.teacher_model_endpoint_name
    teacher_model_endpoint_url = args.teacher_model_endpoint_url
    teacher_model_endpoint_key = args.teacher_model_endpoint_key

    # Optional data-generator parameters.
    teacher_model_max_new_tokens = args.teacher_model_max_new_tokens
    teacher_model_temperature = args.teacher_model_temperature
    teacher_model_top_p = args.teacher_model_top_p
    teacher_model_frequency_penalty = args.teacher_model_frequency_penalty
    teacher_model_presence_penalty = args.teacher_model_presence_penalty
    teacher_model_stop = args.teacher_model_stop
    enable_cot_str = args.enable_chain_of_thought
    enable_cod_str = args.enable_chain_of_density
    max_len_summary = args.max_len_summary
    data_generation_task_type = args.data_generation_task_type
    hash_train_data = args.hash_train_data
    hash_validation_data = args.hash_validation_data
    batch_config_connection = args.batch_config_connection

    # Fail fast on unsupported file formats before doing any workspace work.
    validate_file_paths_with_supported_formats(
        [args.train_file_path, args.validation_file_path]
    )
    logger.info("File format validation successful.")

    # Flags arrive as strings; treat only (case-insensitive) "true" as enabled.
    enable_cot = enable_cot_str.lower() == "true"
    enable_cod = enable_cod_str.lower() == "true"

    mlclient_ws = get_workspace_mlclient()
    if not mlclient_ws:
        raise Exception("Could not create MLClient for current workspace")

    # If an endpoint name is given, it is authoritative: its details override
    # any explicitly passed URL/key, and the deployed model is validated.
    if teacher_model_endpoint_name:
        endpoint_details = get_endpoint_details(
            mlclient_ws, teacher_model_endpoint_name
        )
        teacher_model_endpoint_key = endpoint_details.get_endpoint_key()
        teacher_model_endpoint_url = endpoint_details.get_endpoint_url()
        teacher_model_asset_id = endpoint_details.get_deployed_model_id()
        validate_teacher_model_details(teacher_model_asset_id)

    if not teacher_model_endpoint_url:
        raise Exception("Endpoint URL is a required parameter for data generation")
    if not teacher_model_endpoint_key:
        raise Exception("Endpoint key is a required parameter for data generation")

    if not 0 <= teacher_model_top_p <= 1:
        raise Exception(
            f"Invalid teacher_model_top_p. Value should be 0<=val<=1, but it is {teacher_model_top_p}"
        )
    if not 0 <= teacher_model_temperature <= 1:
        raise Exception(
            f"Invalid teacher_model_temperature. Value should be 0<=val<=1, but it is {teacher_model_temperature}"
        )

    inference_params = {
        # For summarization, a user who left max_new_tokens at the generic
        # default gets the summarization-specific default instead.
        MAX_NEW_TOKENS: (
            DEFAULT_SUMMARY_MAX_NEW_TOKENS
            if data_generation_task_type == "SUMMARIZATION"
            and teacher_model_max_new_tokens == DEFAULT_MAX_NEW_TOKENS
            else teacher_model_max_new_tokens
        ),
        TEMPERATURE: teacher_model_temperature,
        TOP_P: teacher_model_top_p,
    }
    # Only include optional penalties/stop tokens when actually provided.
    if teacher_model_frequency_penalty:
        inference_params[FREQUENCY_PENALTY] = teacher_model_frequency_penalty
    if teacher_model_presence_penalty:
        inference_params[PRESENCE_PENALTY] = teacher_model_presence_penalty
    if teacher_model_stop:
        inference_params[STOP_TOKEN] = teacher_model_stop

    # Ensure the URL targets the vLLM chat-completions scoring route.
    if VLLM_CHAT_SCORE_PATH not in teacher_model_endpoint_url:
        teacher_model_endpoint_url += VLLM_CHAT_SCORE_PATH

    logger.info(f"Teacher Endpoint : {teacher_model_endpoint_url}")

    try:
        # Short random suffix keeps repeated runs from colliding on the name.
        guid = uuid.uuid4()
        short_guid = str(guid)[:8]
        connection_name = f"distillation-ws-connection-{short_guid}"
        mlclient_ws.connections.create_or_update(
            ServerlessConnection(
                name=connection_name,
                endpoint=teacher_model_endpoint_url,
                api_key=teacher_model_endpoint_key,
            )
        )
        logger.info(f"Connection created with name: {connection_name}")
        # Persist connection details for the downstream batch-score component.
        config = {
            "scoring_url": teacher_model_endpoint_url,
            "connection_name": connection_name,
        }
        with open(batch_config_connection, "w") as f:
            json.dump(config, f)
    except Exception as e:
        logger.error(
            f"Failed to create connection for teacher model batch score invocation : {e}"
        )
        # Chain the original exception so the root cause stays in the traceback.
        raise Exception(
            "Failed to create workspace connection for teacher model batch score invocation"
        ) from e

    logger.info("Running data preprocessing")
    preprocess_data(
        inference_params=inference_params,
        enable_cot=enable_cot,
        enable_cod=enable_cod,
        max_len_summary=max_len_summary,
        generated_train_payload_path=generated_train_payload_path,
        generated_validation_payload_path=generated_validation_payload_path,
        train_file_path=train_file_path,
        data_generation_task_type=data_generation_task_type,
        validation_file_path=validation_file_path,
        hash_train_data=hash_train_data,
        hash_validation_data=hash_validation_data,
    )