def clean_dataset()

in community-efforts/image_preferences/02_image_prefernces_cleaned_filtered_sfw.py [0:0]


def _top_label(prediction):
    """Return the label of the first (top-ranked) entry of a prediction.

    Batched text classification yields one ``{"label": ...}`` dict per
    input, while image classification yields a list of such dicts per
    input.  Classification pipelines (which ``pipe_*`` presumably are --
    confirm) also wrap the result for a *single* string in a one-element
    list, so both shapes are accepted here.
    """
    if isinstance(prediction, list):
        prediction = prediction[0]
    return prediction["label"]


def _flag_batch(classifier, prompts):
    """Classify all *prompts* with one call to *classifier*.

    Returns a list of booleans (True = labelled "NSFW"), or None when the
    call or the label extraction raises, so the caller can fall back.
    """
    try:
        return [_top_label(res) == "NSFW" for res in classifier(prompts)]
    except Exception:  # deliberate best-effort: any failure -> fallback path
        return None


def _flag_each(classifier, prompts):
    """Classify *prompts* one at a time with *classifier*.

    A prompt whose classification raises is conservatively flagged True,
    so one problematic prompt cannot slip through as safe.
    """
    flags = []
    for prompt in prompts:
        try:
            flags.append(_top_label(classifier(prompt)) == "NSFW")
        except Exception:
            flags.append(True)
    return flags


def clean_dataset(batch):
    """Add ``nsfw_text`` and ``nsfw_image`` boolean columns to *batch*.

    ``nsfw_image`` is True for a row when any of the four image columns is
    labelled UNSAFE or QUESTIONABLE by ``pipe_image``.  ``nsfw_text`` is
    True when either ``pipe_text`` or ``pipe_text_2`` labels the prompt
    "NSFW".

    Text classification falls back gracefully:
      * both batched calls succeed -> use both results;
      * exactly one batched call fails -> that classifier contributes
        all-False flags and the other one decides;
      * both batched calls fail -> every prompt is classified individually
        by both classifiers, and a prompt that errors is flagged True.

    Args:
        batch: column mapping with a ``"prompt"`` column and the image
            columns listed below; columns are assumed to be equal-length
            lists (TODO confirm against the caller).

    Returns:
        The same *batch* object, mutated in place with the two new columns.

    Raises:
        Errors from ``pipe_image`` or a missing column propagate unchanged;
        text-classifier errors are absorbed by the fallback above.
    """
    image_columns = (
        "image_quality_dev",
        "image_simplified_dev",
        "image_quality_sd",
        "image_simplified_sd",
    )
    # One flag list per image column, in column order.
    image_flags = [
        [
            _top_label(res) in ("UNSAFE", "QUESTIONABLE")
            for res in pipe_image(batch[column])
        ]
        for column in image_columns
    ]

    prompts = batch["prompt"]
    # Each classifier runs at most once in batch mode; no redundant retries.
    text_flags = _flag_batch(pipe_text, prompts)
    text_2_flags = _flag_batch(pipe_text_2, prompts)
    if text_flags is None and text_2_flags is None:
        # Both batched calls failed: go prompt by prompt so only the
        # prompts that actually error get the conservative True flag.
        text_flags = _flag_each(pipe_text, prompts)
        text_2_flags = _flag_each(pipe_text_2, prompts)
    elif text_flags is None:
        text_flags = [False] * len(text_2_flags)
    elif text_2_flags is None:
        text_2_flags = [False] * len(text_flags)

    batch["nsfw_text"] = [a or b for a, b in zip(text_flags, text_2_flags)]
    batch["nsfw_image"] = [any(row) for row in zip(*image_flags)]
    return batch