def get_task_params()

in src/autotrain/app/params.py [0:0]


# Parameters that do not apply to a given LLM trainer and are hidden for it
# regardless of ``param_type``. Trainers not listed here hide nothing extra.
_LLM_TRAINER_HIDDEN_PARAMS = {
    "sft": frozenset(
        {
            "model_ref",
            "dpo_beta",
            "add_eos_token",
            "max_prompt_length",
            "max_completion_length",
        }
    ),
    "reward": frozenset(
        {
            "model_ref",
            "dpo_beta",
            "add_eos_token",
            "max_prompt_length",
            "max_completion_length",
            "unsloth",
        }
    ),
    "orpo": frozenset(
        {
            "model_ref",
            "dpo_beta",
            "add_eos_token",
            "unsloth",
        }
    ),
    "generic": frozenset(
        {
            "model_ref",
            "dpo_beta",
            "max_prompt_length",
            "max_completion_length",
        }
    ),
    "dpo": frozenset(
        {
            "add_eos_token",
            "unsloth",
        }
    ),
}

# Additional parameters hidden from the "basic" view of the "llm" task (any trainer).
_LLM_BASIC_HIDDEN_PARAMS = frozenset(
    {
        "padding",
        "use_flash_attention_2",
        "disable_gradient_checkpointing",
        "logging_steps",
        "eval_strategy",
        "save_total_limit",
        "auto_find_batch_size",
        "warmup_ratio",
        "weight_decay",
        "max_grad_norm",
        "seed",
        "quantization",
        "merge_adapter",
        "lora_r",
        "lora_alpha",
        "lora_dropout",
        "max_completion_length",
    }
)

# Parameters hidden from the "basic" view of every non-LLM task listed in
# ``_BASIC_HIDDEN_PARAMS_BY_TASK``.
_COMMON_BASIC_HIDDEN_PARAMS = frozenset(
    {
        "warmup_ratio",
        "weight_decay",
        "max_grad_norm",
        "seed",
        "logging_steps",
        "auto_find_batch_size",
        "save_total_limit",
        "eval_strategy",
        "early_stopping_patience",
        "early_stopping_threshold",
    }
)

# Quantization/LoRA parameters additionally hidden from the "basic" view of
# the "vlm" and "seq2seq" tasks.
_QUANT_LORA_BASIC_HIDDEN_PARAMS = frozenset(
    {
        "quantization",
        "lora_r",
        "lora_alpha",
        "lora_dropout",
    }
)

# Per-task parameters hidden when ``param_type == "basic"`` (non-LLM tasks only).
# Tasks not listed here hide nothing beyond ``HIDDEN_PARAMS``.
_BASIC_HIDDEN_PARAMS_BY_TASK = {
    "text-classification": _COMMON_BASIC_HIDDEN_PARAMS,
    "extractive-qa": _COMMON_BASIC_HIDDEN_PARAMS,
    "st": _COMMON_BASIC_HIDDEN_PARAMS,
    "vlm": _COMMON_BASIC_HIDDEN_PARAMS | _QUANT_LORA_BASIC_HIDDEN_PARAMS,
    "text-regression": _COMMON_BASIC_HIDDEN_PARAMS,
    "image-classification": _COMMON_BASIC_HIDDEN_PARAMS,
    "image-regression": _COMMON_BASIC_HIDDEN_PARAMS,
    "image-object-detection": _COMMON_BASIC_HIDDEN_PARAMS,
    "seq2seq": _COMMON_BASIC_HIDDEN_PARAMS | _QUANT_LORA_BASIC_HIDDEN_PARAMS | {"target_modules"},
    "token-classification": _COMMON_BASIC_HIDDEN_PARAMS,
}


def _split_task_and_trainer(task):
    """Split ``"<task>:<trainer>"`` into a lowercased ``(task, trainer)`` pair.

    Only the first two ``:``-separated fields are used. ``trainer`` is ``""``
    when ``task`` contains no ``:`` (previously this raised ``IndexError``).
    """
    parts = task.split(":")
    trainer = parts[1].lower() if len(parts) > 1 else ""
    return parts[0].lower(), trainer


def _extra_hidden_params(task, trainer, param_type):
    """Return the task/trainer-specific names to hide on top of ``HIDDEN_PARAMS``."""
    if task == "llm":
        hidden = _LLM_TRAINER_HIDDEN_PARAMS.get(trainer, frozenset())
        if param_type == "basic":
            hidden = hidden | _LLM_BASIC_HIDDEN_PARAMS
        return hidden
    if param_type == "basic":
        return _BASIC_HIDDEN_PARAMS_BY_TASK.get(task, frozenset())
    return frozenset()


def get_task_params(task, param_type):
    """
    Retrieve task-specific parameters while filtering out hidden parameters based on the task and parameter type.

    Args:
        task (str): The task identifier. Identifiers starting with "llm", "st:" or
            "vlm:" may carry a ":<trainer>" suffix (e.g. "llm:sft"); only the part
            before the first ":" (lowercased) selects the entry in ``PARAMS``, and
            the lowercased trainer is used for "llm" tasks. Any identifier starting
            with "tabular" is mapped to "tabular".
        param_type (str): The type of parameters to retrieve. "basic" hides
            additional advanced parameters; any other value hides only the
            always-hidden and trainer-specific ones.

    Returns:
        dict: A new dictionary with the entries of ``PARAMS[task]`` (in their
        original order) minus every name in ``HIDDEN_PARAMS`` and the
        task/trainer/param_type-specific hidden names. Returns ``{}`` when the
        resolved task is not in ``PARAMS``.

    Notes:
        - An "llm" task without a ":<trainer>" suffix hides no trainer-specific
          parameters (only the "basic" set when ``param_type == "basic"``).
        - Trainers other than sft/reward/orpo/generic/dpo hide no
          trainer-specific parameters.
    """
    trainer = ""
    # These prefixes may carry a ":<trainer>" suffix; only the part before it
    # selects the PARAMS entry. The trainer is only consulted for "llm".
    if task.startswith(("llm", "st:", "vlm:")):
        task, trainer = _split_task_and_trainer(task)
    elif task.startswith("tabular"):
        task = "tabular"

    if task not in PARAMS:
        return {}

    extra_hidden = _extra_hidden_params(task, trainer, param_type)
    # Always build a new dict so callers never receive PARAMS[task] itself.
    return {
        k: v
        for k, v in PARAMS[task].items()
        if k not in HIDDEN_PARAMS and k not in extra_hidden
    }