def data_import()

in assets/training/distillation/src/generate_data_preprocess.py [0:0]


def data_import(args: Namespace) -> None:
    """Validate inputs, register the teacher endpoint connection and preprocess data.

    The function performs these steps in order:

    1. Validates the formats of the train and validation files.
    2. Obtains an MLClient for the current workspace.
    3. If ``teacher_model_endpoint_name`` is given, looks the endpoint up and
       overrides the endpoint URL/key with the values it reports, then passes
       the deployed model id to ``validate_teacher_model_details``.
    4. Checks that an endpoint URL and key are available and that ``top_p`` and
       ``temperature`` are within ``[0, 1]``.
    5. Builds the inference parameter dict sent with every request.
    6. Creates a workspace connection for the teacher endpoint and writes the
       scoring URL and connection name as JSON to ``batch_config_connection``.
    7. Calls ``preprocess_data`` to generate the train/validation payloads.

    Args:
        args: Parsed command-line arguments of the component. Every attribute
            read below must be present on the namespace.

    Raises:
        Exception: If the workspace MLClient cannot be created, the endpoint
            URL or key is missing, ``top_p``/``temperature`` is out of range,
            or the connection/config step fails (the original error is chained
            as ``__cause__``).
    """
    train_file_path = args.train_file_path
    validation_file_path = args.validation_file_path
    generated_train_payload_path = args.generated_train_payload_path
    generated_validation_payload_path = args.generated_validation_payload_path
    teacher_model_endpoint_name = args.teacher_model_endpoint_name
    teacher_model_endpoint_url = args.teacher_model_endpoint_url
    teacher_model_endpoint_key = args.teacher_model_endpoint_key
    # Optional data-generator params.
    teacher_model_max_new_tokens = args.teacher_model_max_new_tokens
    teacher_model_temperature = args.teacher_model_temperature
    teacher_model_top_p = args.teacher_model_top_p
    teacher_model_frequency_penalty = args.teacher_model_frequency_penalty
    teacher_model_presence_penalty = args.teacher_model_presence_penalty
    teacher_model_stop = args.teacher_model_stop
    max_len_summary = args.max_len_summary
    data_generation_task_type = args.data_generation_task_type
    hash_train_data = args.hash_train_data
    hash_validation_data = args.hash_validation_data
    batch_config_connection = args.batch_config_connection

    # Validate file formats.
    validate_file_paths_with_supported_formats(
        [train_file_path, validation_file_path]
    )
    logger.info("File format validation successful.")

    # The flags arrive as strings ("true"/"false", any case). Going through
    # str() keeps this working if a caller hands over a real bool or None
    # instead of crashing on a missing .lower().
    enable_cot = str(args.enable_chain_of_thought).lower() == "true"
    enable_cod = str(args.enable_chain_of_density).lower() == "true"

    mlclient_ws = get_workspace_mlclient()
    if not mlclient_ws:
        raise Exception("Could not create MLClient for current workspace")

    if teacher_model_endpoint_name:
        # A named endpoint takes precedence over any explicitly passed
        # URL/key: both are replaced with what the endpoint reports.
        endpoint_details = get_endpoint_details(
            mlclient_ws, teacher_model_endpoint_name
        )
        teacher_model_endpoint_key = endpoint_details.get_endpoint_key()
        teacher_model_endpoint_url = endpoint_details.get_endpoint_url()
        teacher_model_asset_id = endpoint_details.get_deployed_model_id()
        validate_teacher_model_details(teacher_model_asset_id)

    if not teacher_model_endpoint_url:
        raise Exception("Endpoint URL is a required parameter for data generation")

    if not teacher_model_endpoint_key:
        raise Exception("Endpoint key is a required parameter for data generation")

    if teacher_model_top_p < 0 or teacher_model_top_p > 1:
        raise Exception(
            f"Invalid teacher_model_top_p. Value should be 0<=val<=1, but it is {teacher_model_top_p}"
        )
    if teacher_model_temperature < 0 or teacher_model_temperature > 1:
        raise Exception(
            f"Invalid teacher_model_temperature. Value should be 0<=val<=1, but it is {teacher_model_temperature}"
        )

    # For summarization, an untouched default token budget is swapped for the
    # summarization-specific default; an explicitly chosen value is kept.
    use_summary_default = (
        data_generation_task_type == "SUMMARIZATION"
        and teacher_model_max_new_tokens == DEFAULT_MAX_NEW_TOKENS
    )
    inference_params = {
        MAX_NEW_TOKENS: (
            DEFAULT_SUMMARY_MAX_NEW_TOKENS
            if use_summary_default
            else teacher_model_max_new_tokens
        ),
        TEMPERATURE: teacher_model_temperature,
        TOP_P: teacher_model_top_p,
    }

    # Optional parameters are only sent when set to a truthy value
    # (None, 0 and empty values are omitted from the request).
    if teacher_model_frequency_penalty:
        inference_params[FREQUENCY_PENALTY] = teacher_model_frequency_penalty

    if teacher_model_presence_penalty:
        inference_params[PRESENCE_PENALTY] = teacher_model_presence_penalty

    if teacher_model_stop:
        inference_params[STOP_TOKEN] = teacher_model_stop

    # Append the chat scoring path unless the URL already contains it.
    if VLLM_CHAT_SCORE_PATH not in teacher_model_endpoint_url:
        teacher_model_endpoint_url += VLLM_CHAT_SCORE_PATH

    logger.info(f"Teacher Endpoint : {teacher_model_endpoint_url}")

    try:
        # Short random suffix so each run registers its own connection name.
        short_guid = str(uuid.uuid4())[:8]
        connection_name = f"distillation-ws-connection-{short_guid}"
        mlclient_ws.connections.create_or_update(
            ServerlessConnection(
                name=connection_name,
                endpoint=teacher_model_endpoint_url,
                api_key=teacher_model_endpoint_key,
            )
        )
        logger.info(f"Connection created with name: {connection_name}")
        config = {
            "scoring_url": teacher_model_endpoint_url,
            "connection_name": connection_name,
        }
        with open(batch_config_connection, "w", encoding="utf-8") as f:
            json.dump(config, f)
    except Exception as e:
        logger.error(
            f"Failed to create connection for teacher model batch score invocation : {e}"
        )
        # Chain the original error so the root cause stays in the traceback.
        raise Exception(
            "Failed to create workspace connection for teacher model batch score invocation "
        ) from e

    logger.info("Running data preprocessing")
    preprocess_data(
        inference_params=inference_params,
        enable_cot=enable_cot,
        enable_cod=enable_cod,
        max_len_summary=max_len_summary,
        generated_train_payload_path=generated_train_payload_path,
        generated_validation_payload_path=generated_validation_payload_path,
        train_file_path=train_file_path,
        data_generation_task_type=data_generation_task_type,
        validation_file_path=validation_file_path,
        hash_train_data=hash_train_data,
        hash_validation_data=hash_validation_data,
    )