def colab_app()

in src/autotrain/app/colab.py [0:0]


def colab_app():
    if not os.path.exists("data"):
        os.makedirs("data")
    MODEL_CHOICES = fetch_models()
    TASK_NAMES = [
        "LLM SFT",
        "LLM ORPO",
        "LLM Generic",
        "LLM DPO",
        "LLM Reward",
        "Text Classification",
        "Text Regression",
        "Sequence to Sequence",
        "Token Classification",
        "Image Classification",
        "Image Regression",
        "Object Detection",
        "Tabular Classification",
        "Tabular Regression",
        "ST Pair",
        "ST Pair Classification",
        "ST Pair Scoring",
        "ST Triplet",
        "ST Question Answering",
    ]

    TASK_MAP = {
        "LLM SFT": "llm:sft",
        "LLM ORPO": "llm:orpo",
        "LLM Generic": "llm:generic",
        "LLM DPO": "llm:dpo",
        "LLM Reward": "llm:reward",
        "Text Classification": "text-classification",
        "Text Regression": "text-regression",
        "Sequence to Sequence": "seq2seq",
        "Token Classification": "token-classification",
        "Image Classification": "image-classification",
        "Image Regression": "image-regression",
        "Object Detection": "image-object-detection",
        "Tabular Classification": "tabular:classification",
        "Tabular Regression": "tabular:regression",
        "ST Pair": "st:pair",
        "ST Pair Classification": "st:pair_class",
        "ST Pair Scoring": "st:pair_score",
        "ST Triplet": "st:triplet",
        "ST Question Answering": "st:qa",
    }

    def _get_params(task, param_type):
        _p = get_task_params(task, param_type=param_type)
        _p["push_to_hub"] = True
        _p = json.dumps(_p, indent=4)
        return _p

    hf_token_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Hugging Face Write Token</h5>")
    hf_token = widgets.Password(
        value="", description="", disabled=False, layout=widgets.Layout(margin="0 0 0 0", width="200px")
    )

    hf_user_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Hugging Face Username</h5>")
    hf_user = widgets.Text(
        value="", description="", disabled=False, layout=widgets.Layout(margin="0 0 0 0", width="200px")
    )

    base_model_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Base Model</h5>")
    base_model = widgets.Text(value=MODEL_CHOICES["llm"][0], disabled=False, layout=widgets.Layout(width="420px"))

    project_name_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Project Name</h5>")
    project_name = widgets.Text(
        value=generate_random_string(),
        description="",
        disabled=False,
        layout=widgets.Layout(margin="0 0 0 0", width="200px"),
    )

    task_dropdown_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Task</h5>")
    task_dropdown = widgets.Dropdown(
        options=TASK_NAMES,
        value=TASK_NAMES[0],
        description="",
        disabled=False,
        layout=widgets.Layout(margin="0 0 0 0", width="200px"),
    )

    dataset_path_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Path</h5>")
    dataset_path = widgets.Text(
        value="", description="", disabled=False, layout=widgets.Layout(margin="0 0 0 0", width="200px")
    )

    train_split_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Train Split</h5>")
    train_split = widgets.Text(
        value="", description="", disabled=False, layout=widgets.Layout(margin="0 0 0 0", width="200px")
    )

    valid_split_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Valid Split</h5>")
    valid_split = widgets.Text(
        value="",
        placeholder="optional",
        description="",
        disabled=False,
        layout=widgets.Layout(margin="0 0 0 0", width="200px"),
    )

    dataset_source_dropdown_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Source</h5>")
    dataset_source_dropdown = widgets.Dropdown(
        options=["Hugging Face Hub", "Local"],
        value="Hugging Face Hub",
        description="",
        disabled=False,
        layout=widgets.Layout(margin="0 0 0 0", width="200px"),
    )

    col_mapping_label = widgets.HTML("<h5 style='margin-bottom: 0; margin-top: 0;'>Column Mapping</h5>")
    col_mapping = widgets.Text(
        value='{"text": "text"}',
        placeholder="",
        description="",
        disabled=False,
        layout=widgets.Layout(margin="0 0 0 0", width="420px"),
    )

    parameters_dropdown = widgets.Dropdown(
        options=["Basic", "Full"], value="Basic", description="", disabled=False, layout=widgets.Layout(width="400px")
    )

    parameters = widgets.Textarea(
        value=_get_params("llm:sft", "basic"),
        description="",
        disabled=False,
        layout=widgets.Layout(height="400px", width="400px"),
    )

    start_training_button = widgets.Button(
        description="Start Training",
        layout=widgets.Layout(width="1000px"),
        disabled=False,
        button_style="",  # 'success', 'info', 'warning', 'danger' or ''
        tooltip="Click to start training",
        icon="check",  # (FontAwesome names without the `fa-` prefix)
    )

    spacer = widgets.Box(layout=widgets.Layout(width="20px"))
    title_hbox0 = widgets.HTML("<h3>Hugging Face Credentials</h3>")
    title_hbox1 = widgets.HTML("<h3>Project Details</h3>")
    title_hbox2 = widgets.HTML("<h3>Dataset Details</h3>")
    title_hbox3 = widgets.HTML("<h3>Parameters</h3>")

    hbox0 = widgets.HBox(
        [
            widgets.VBox([hf_token_label, hf_token]),
            spacer,
            widgets.VBox([hf_user_label, hf_user]),
        ]
    )
    hbox1 = widgets.HBox(
        [
            widgets.VBox([project_name_label, project_name]),
            spacer,
            widgets.VBox([task_dropdown_label, task_dropdown]),
        ]
    )
    hbox2_1 = widgets.HBox(
        [
            widgets.VBox([dataset_source_dropdown_label, dataset_source_dropdown]),
            spacer,
            widgets.VBox([dataset_path_label, dataset_path]),
        ]
    )
    hbox2_2 = widgets.HBox(
        [
            widgets.VBox([train_split_label, train_split]),
            spacer,
            widgets.VBox([valid_split_label, valid_split]),
        ]
    )
    hbox2_3 = widgets.HBox(
        [
            widgets.VBox([col_mapping_label, col_mapping]),
        ]
    )
    hbox3 = widgets.VBox([parameters_dropdown, parameters])

    vbox0 = widgets.VBox([title_hbox0, hbox0])
    vbox1 = widgets.VBox([title_hbox1, base_model_label, base_model, hbox1])
    vbox2 = widgets.VBox([title_hbox2, hbox2_1, hbox2_2, hbox2_3])
    vbox3 = widgets.VBox([title_hbox3, hbox3])

    left_column = widgets.VBox([vbox0, vbox1, vbox2], layout=widgets.Layout(width="500px"))
    right_column = widgets.VBox([vbox3], layout=widgets.Layout(width="500px", align_items="flex-end"))

    separator = widgets.HTML('<div style="border-left: 1px solid black; height: 100%;"></div>')

    _main_layout = widgets.HBox([left_column, separator, right_column])
    main_layout = widgets.VBox([_main_layout, start_training_button])

    def on_dataset_change(change):
        if change["new"] == "Local":
            dataset_path.value = "data/"
            train_split.value = "train"
            valid_split.value = ""
        else:
            dataset_path.value = ""
            train_split.value = ""
            valid_split.value = ""

    def update_parameters(*args):
        task = TASK_MAP[task_dropdown.value]
        param_type = parameters_dropdown.value.lower()
        parameters.value = _get_params(task, param_type)

    def update_col_mapping(*args):
        task = TASK_MAP[task_dropdown.value]
        if task in ["llm:sft", "llm:generic"]:
            col_mapping.value = '{"text": "text"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = True
        elif task in ["llm:dpo", "llm:orpo"]:
            col_mapping.value = '{"prompt": "prompt", "text": "text", "rejected_text": "rejected_text"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = True
        elif task == "llm:reward":
            col_mapping.value = '{"text": "text", "rejected_text": "rejected_text"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = True
        elif task == "text-classification":
            col_mapping.value = '{"text": "text", "label": "target"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "text-regression":
            col_mapping.value = '{"text": "text", "label": "target"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "token-classification":
            col_mapping.value = '{"text": "tokens", "label": "tags"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "seq2seq":
            col_mapping.value = '{"text": "text", "label": "target"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "image-classification":
            col_mapping.value = '{"image": "image", "label": "label"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "image-regression":
            col_mapping.value = '{"image": "image", "label": "target"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "image-object-detection":
            col_mapping.value = '{"image": "image", "objects": "objects"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "tabular:classification":
            col_mapping.value = '{"id": "id", "label": ["target"]}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "tabular:regression":
            col_mapping.value = '{"id": "id", "label": ["target"]}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "st:pair":
            col_mapping.value = '{"sentence1": "anchor", "sentence2": "positive"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "st:pair_class":
            col_mapping.value = '{"sentence1": "premise", "sentence2": "hypothesis", "target": "label"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "st:pair_score":
            col_mapping.value = '{"sentence1": "sentence1", "sentence2": "sentence2", "target": "score"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "st:triplet":
            col_mapping.value = '{"sentence1": "anchor", "sentence2": "positive", "sentence3": "negative"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        elif task == "st:qa":
            col_mapping.value = '{"sentence1": "query", "sentence1": "answer"}'
            dataset_source_dropdown.disabled = False
            valid_split.disabled = False
        else:
            col_mapping.value = "Enter column mapping..."

    def update_base_model(*args):
        if TASK_MAP[task_dropdown.value] == "text-classification":
            base_model.value = MODEL_CHOICES["text-classification"][0]
        elif TASK_MAP[task_dropdown.value].startswith("llm"):
            base_model.value = MODEL_CHOICES["llm"][0]
        elif TASK_MAP[task_dropdown.value] == "image-classification":
            base_model.value = MODEL_CHOICES["image-classification"][0]
        elif TASK_MAP[task_dropdown.value] == "seq2seq":
            base_model.value = MODEL_CHOICES["seq2seq"][0]
        elif TASK_MAP[task_dropdown.value] == "tabular:classification":
            base_model.value = MODEL_CHOICES["tabular-classification"][0]
        elif TASK_MAP[task_dropdown.value] == "tabular:regression":
            base_model.value = MODEL_CHOICES["tabular-regression"][0]
        elif TASK_MAP[task_dropdown.value] == "token-classification":
            base_model.value = MODEL_CHOICES["token-classification"][0]
        elif TASK_MAP[task_dropdown.value] == "text-regression":
            base_model.value = MODEL_CHOICES["text-regression"][0]
        elif TASK_MAP[task_dropdown.value] == "image-object-detection":
            base_model.value = MODEL_CHOICES["image-object-detection"][0]
        elif TASK_MAP[task_dropdown.value].startswith("st:"):
            base_model.value = MODEL_CHOICES["sentence-transformers"][0]
        else:
            base_model.value = "Enter base model..."

    def start_training(b):
        start_training_button.disabled = True
        try:
            print("Training is starting... Please wait!")
            os.environ["HF_USERNAME"] = hf_user.value
            os.environ["HF_TOKEN"] = hf_token.value
            train_split_value = train_split.value.strip() if train_split.value.strip() != "" else None
            valid_split_value = valid_split.value.strip() if valid_split.value.strip() != "" else None
            params_val = json.loads(parameters.value)
            if task_dropdown.value.startswith("llm") or task_dropdown.value.startswith("sentence-transformers"):
                params_val["trainer"] = task_dropdown.value.split(":")[1]
                # params_val = {k: v for k, v in params_val.items() if k != "trainer"}

            chat_template = params_val.get("chat_template")
            if chat_template is not None:
                params_val = {k: v for k, v in params_val.items() if k != "chat_template"}

            push_to_hub = params_val.get("push_to_hub", True)
            if "push_to_hub" in params_val:
                params_val = {k: v for k, v in params_val.items() if k != "push_to_hub"}

            config = {
                "task": TASK_MAP[task_dropdown.value].split(":")[0],
                "base_model": base_model.value,
                "project_name": project_name.value,
                "log": "tensorboard",
                "backend": "local",
                "data": {
                    "path": dataset_path.value,
                    "train_split": train_split_value,
                    "valid_split": valid_split_value,
                    "column_mapping": json.loads(col_mapping.value),
                },
                "params": params_val,
                "hub": {
                    "username": "${{HF_USERNAME}}",
                    "token": "${{HF_TOKEN}}",
                    "push_to_hub": push_to_hub,
                },
            }
            if TASK_MAP[task_dropdown.value].startswith("llm"):
                config["data"]["chat_template"] = chat_template
                if config["data"]["chat_template"] == "none":
                    config["data"]["chat_template"] = None

            with open("config.yml", "w") as f:
                yaml.dump(config, f)

            cmd = "autotrain --config config.yml"
            process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
            while True:
                output = process.stdout.readline()
                if output == "" and process.poll() is not None:
                    break
                if output:
                    print(output.strip())

            poll_res = process.poll()
            if poll_res != 0:
                start_training_button.disabled = False
                raise Exception(f"Training failed with exit code: {poll_res}")
            print("Training completed successfully!")
            start_training_button.disabled = False
        except Exception as e:
            print("An error occurred while starting training!")
            print(f"Error: {e}")
            start_training_button.disabled = False

    start_training_button.on_click(start_training)
    dataset_source_dropdown.observe(on_dataset_change, names="value")
    task_dropdown.observe(update_col_mapping, names="value")
    task_dropdown.observe(update_parameters, names="value")
    task_dropdown.observe(update_base_model, names="value")
    parameters_dropdown.observe(update_parameters, names="value")
    return main_layout