in assets/training/distillation/src/generate_data.py [0:0]
def data_import(args: Namespace):
    """Validate data-generation arguments and run teacher-model data generation.

    Reads the user's train/validation file paths and the teacher-model
    settings from ``args``. Checks the input file formats and resolves the
    teacher endpoint URL/key (from ``teacher_model_endpoint_name`` when one is
    given). Validates the numeric inference/batching parameters, then calls
    ``generate_synthetic_data`` with the input paths and the
    ``generated_*_file_path`` output paths.

    Args:
        args: Parsed command-line arguments for the data-generation step.

    Raises:
        Exception: If the workspace MLClient cannot be created, the teacher
            endpoint URL or key is missing, or ``teacher_model_top_p``,
            ``teacher_model_temperature``, ``min_endpoint_success_ratio`` or
            ``request_batch_size`` is outside its allowed range.
    """
    train_file_path = args.train_file_path
    validation_file_path = args.validation_file_path
    generated_train_file_path = args.generated_train_file_path
    generated_validation_file_path = args.generated_validation_file_path

    # Optional data-generator params.
    teacher_model_endpoint_name = args.teacher_model_endpoint_name
    teacher_model_endpoint_url = args.teacher_model_endpoint_url
    teacher_model_endpoint_key = args.teacher_model_endpoint_key
    teacher_model_max_new_tokens = args.teacher_model_max_new_tokens
    teacher_model_temperature = args.teacher_model_temperature
    teacher_model_top_p = args.teacher_model_top_p
    teacher_model_frequency_penalty = args.teacher_model_frequency_penalty
    teacher_model_presence_penalty = args.teacher_model_presence_penalty
    teacher_model_stop = args.teacher_model_stop
    request_batch_size = args.request_batch_size
    min_endpoint_success_ratio = args.min_endpoint_success_ratio
    max_len_summary = args.max_len_summary
    data_generation_task_type = args.data_generation_task_type
    model_asset_id = args.model_asset_id

    # The chain-of-thought / chain-of-density flags are compared as strings.
    # Wrapping in str() means a None (or an actual bool) no longer raises
    # AttributeError on .lower(); anything other than "true" means disabled.
    enable_cot = str(args.enable_chain_of_thought).lower() == "true"
    enable_cod = str(args.enable_chain_of_density).lower() == "true"

    # Validate file formats.
    validate_file_paths_with_supported_formats(
        [train_file_path, validation_file_path]
    )
    logger.info("File format validation successful.")

    mlclient_ws = get_workspace_mlclient()
    if not mlclient_ws:
        raise Exception("Could not create MLClient for current workspace")

    # When an endpoint name is supplied, its key/URL (looked up through the
    # workspace client) replace any values passed explicitly in args.
    if teacher_model_endpoint_name:
        endpoint_details = get_endpoint_details(
            mlclient_ws, teacher_model_endpoint_name
        )
        teacher_model_endpoint_key = endpoint_details.get_endpoint_key()
        teacher_model_endpoint_url = endpoint_details.get_endpoint_url()
        teacher_model_asset_id = endpoint_details.get_deployed_model_id()
        validate_teacher_model_details(teacher_model_asset_id)

    if not teacher_model_endpoint_url:
        raise Exception("Endpoint URL is a required parameter for data generation")
    if not teacher_model_endpoint_key:
        raise Exception("Endpoint key is a required parameter for data generation")

    if teacher_model_top_p < 0 or teacher_model_top_p > 1:
        raise Exception(
            f"Invalid teacher_model_top_p. Value should be 0<=val<=1, but it is {teacher_model_top_p}"
        )
    if teacher_model_temperature < 0 or teacher_model_temperature > 1:
        raise Exception(
            f"Invalid teacher_model_temperature. Value should be 0<=val<=1, but it is {teacher_model_temperature}"
        )
    if min_endpoint_success_ratio < 0 or min_endpoint_success_ratio > 1:
        raise Exception(
            f"Invalid min_endpoint_success_ratio. Value should be 0<=val<=1, but it is {min_endpoint_success_ratio}"
        )
    # A batch size of 0 is rejected, so the lower bound is exclusive.
    if request_batch_size <= 0 or request_batch_size > MAX_BATCH_SIZE:
        raise Exception(
            f"Invalid request_batch_size. Value should be 0<val<={MAX_BATCH_SIZE}, but it is {request_batch_size}"
        )

    # Summarization switches to DEFAULT_SUMMARY_MAX_NEW_TOKENS only when the
    # user left max_new_tokens at its default; an explicit value is kept.
    if (
        data_generation_task_type == "SUMMARIZATION"
        and teacher_model_max_new_tokens == DEFAULT_MAX_NEW_TOKENS
    ):
        max_new_tokens = DEFAULT_SUMMARY_MAX_NEW_TOKENS
    else:
        max_new_tokens = teacher_model_max_new_tokens

    inference_params = {
        MAX_NEW_TOKENS: max_new_tokens,
        TEMPERATURE: teacher_model_temperature,
        TOP_P: teacher_model_top_p,
    }
    # Optional params are only sent when truthy (so e.g. a 0 penalty is omitted).
    if teacher_model_frequency_penalty:
        inference_params[FREQUENCY_PENALTY] = teacher_model_frequency_penalty
    if teacher_model_presence_penalty:
        inference_params[PRESENCE_PENALTY] = teacher_model_presence_penalty
    if teacher_model_stop:
        inference_params[STOP_TOKEN] = teacher_model_stop

    # Ensure requests target the chat scoring route.
    if VLLM_CHAT_SCORE_PATH not in teacher_model_endpoint_url:
        teacher_model_endpoint_url += VLLM_CHAT_SCORE_PATH
    logger.info(f"Teacher Endpoint : {teacher_model_endpoint_url}")

    logger.info("Running data generation")
    generate_synthetic_data(
        teacher_model_endpoint_url=teacher_model_endpoint_url,
        teacher_model_endpoint_key=teacher_model_endpoint_key,
        inference_params=inference_params,
        request_batch_size=request_batch_size,
        min_endpoint_success_ratio=min_endpoint_success_ratio,
        enable_cot=enable_cot,
        enable_cod=enable_cod,
        max_len_summary=max_len_summary,
        generated_train_file_path=generated_train_file_path,
        generated_validation_file_path=generated_validation_file_path,
        train_file_path=train_file_path,
        data_generation_task_type=data_generation_task_type,
        student_model=StudentModels.parse_model_asset_id(model_asset_id),
        validation_file_path=validation_file_path,
    )