# data_import(args)
#
# Excerpt from assets/training/distillation/src/generate_data.py


def _validate_unit_interval(name: str, value: float) -> None:
    """Raise if ``value`` falls outside the closed interval [0, 1]."""
    if not 0 <= value <= 1:
        raise Exception(
            f"Invalid {name}. Value should be 0<=val<=1, but it is {value}"
        )


def data_import(args: Namespace):
    """Generate synthetic training data from a teacher model endpoint.

    Validates the user-supplied train/validation files and the generation
    parameters, resolves the teacher endpoint URL/key (from the workspace
    when an endpoint name is given), then invokes ``generate_synthetic_data``
    to write the generated datasets to the configured output paths.

    Args:
        args: Parsed command-line arguments carrying input/output file
            paths, teacher endpoint details, and inference/sampling
            parameters.

    Raises:
        Exception: If the workspace MLClient cannot be created, the
            endpoint URL/key are missing, or a numeric parameter is out
            of its valid range.
    """
    train_file_path = args.train_file_path
    validation_file_path = args.validation_file_path
    generated_train_file_path = args.generated_train_file_path
    generated_validation_file_path = args.generated_validation_file_path

    # Optional data-generator parameters.
    teacher_model_endpoint_name = args.teacher_model_endpoint_name
    teacher_model_endpoint_url = args.teacher_model_endpoint_url
    teacher_model_endpoint_key = args.teacher_model_endpoint_key
    teacher_model_max_new_tokens = args.teacher_model_max_new_tokens
    teacher_model_temperature = args.teacher_model_temperature
    teacher_model_top_p = args.teacher_model_top_p
    teacher_model_frequency_penalty = args.teacher_model_frequency_penalty
    teacher_model_presence_penalty = args.teacher_model_presence_penalty
    teacher_model_stop = args.teacher_model_stop
    request_batch_size = args.request_batch_size
    min_endpoint_success_ratio = args.min_endpoint_success_ratio
    max_len_summary = args.max_len_summary
    data_generation_task_type = args.data_generation_task_type
    model_asset_id = args.model_asset_id

    # Validate file formats before doing any endpoint work.
    validate_file_paths_with_supported_formats(
        [train_file_path, validation_file_path]
    )
    logger.info("File format validation successful.")

    # Flags arrive as strings (presumably from argparse); compare
    # case-insensitively against "true".
    enable_cot = args.enable_chain_of_thought.lower() == "true"
    enable_cod = args.enable_chain_of_density.lower() == "true"

    mlclient_ws = get_workspace_mlclient()
    if not mlclient_ws:
        raise Exception("Could not create MLClient for current workspace")

    # When an endpoint *name* is given, resolve the key/URL from the
    # workspace — this overrides any explicitly supplied key/URL.
    if teacher_model_endpoint_name:
        endpoint_details = get_endpoint_details(
            mlclient_ws, teacher_model_endpoint_name
        )
        teacher_model_endpoint_key = endpoint_details.get_endpoint_key()
        teacher_model_endpoint_url = endpoint_details.get_endpoint_url()
        teacher_model_asset_id = endpoint_details.get_deployed_model_id()
        validate_teacher_model_details(teacher_model_asset_id)

    if not teacher_model_endpoint_url:
        raise Exception("Endpoint URL is a required parameter for data generation")

    if not teacher_model_endpoint_key:
        raise Exception("Endpoint key is a required parameter for data generation")

    _validate_unit_interval("teacher_model_top_p", teacher_model_top_p)
    _validate_unit_interval("teacher_model_temperature", teacher_model_temperature)
    _validate_unit_interval("min_endpoint_success_ratio", min_endpoint_success_ratio)

    # Batch size must be strictly positive and capped at MAX_BATCH_SIZE.
    if not 0 < request_batch_size <= MAX_BATCH_SIZE:
        raise Exception(
            f"Invalid request_batch_size. Value should be 0<val<={MAX_BATCH_SIZE}, but it is {request_batch_size}"
        )

    # Summarization gets a larger default token budget, but only when the
    # user did not override max_new_tokens from its default.
    inference_params = {
        MAX_NEW_TOKENS: (
            DEFAULT_SUMMARY_MAX_NEW_TOKENS
            if data_generation_task_type == "SUMMARIZATION"
            and teacher_model_max_new_tokens == DEFAULT_MAX_NEW_TOKENS
            else teacher_model_max_new_tokens
        ),
        TEMPERATURE: teacher_model_temperature,
        TOP_P: teacher_model_top_p,
    }

    # Penalties/stop tokens are forwarded only when truthy; note that an
    # explicit 0 penalty is treated as "unset" here — presumably intentional.
    if teacher_model_frequency_penalty:
        inference_params[FREQUENCY_PENALTY] = teacher_model_frequency_penalty

    if teacher_model_presence_penalty:
        inference_params[PRESENCE_PENALTY] = teacher_model_presence_penalty

    if teacher_model_stop:
        inference_params[STOP_TOKEN] = teacher_model_stop

    # Make sure the URL targets the vLLM chat scoring route.
    if VLLM_CHAT_SCORE_PATH not in teacher_model_endpoint_url:
        teacher_model_endpoint_url += VLLM_CHAT_SCORE_PATH

    logger.info(f"Teacher Endpoint : {teacher_model_endpoint_url}")

    logger.info("Running data generation")
    generate_synthetic_data(
        teacher_model_endpoint_url=teacher_model_endpoint_url,
        teacher_model_endpoint_key=teacher_model_endpoint_key,
        inference_params=inference_params,
        request_batch_size=request_batch_size,
        min_endpoint_success_ratio=min_endpoint_success_ratio,
        enable_cot=enable_cot,
        enable_cod=enable_cod,
        max_len_summary=max_len_summary,
        generated_train_file_path=generated_train_file_path,
        generated_validation_file_path=generated_validation_file_path,
        train_file_path=train_file_path,
        data_generation_task_type=data_generation_task_type,
        student_model=StudentModels.parse_model_asset_id(model_asset_id),
        validation_file_path=validation_file_path,
    )