def prepare()

in paper/experiments/mturk/prepare_mturk.py [0:0]


def prepare(data_folder, out_folder, task, fluency_samples=150, fidelity_samples=24):
    """Select samples for MTurk annotation and write one CSV per dataset.

    For every dataset in the module-level ``datasets``, each system's outputs
    are loaded from ``<data_folder>/<dataset>/<system>.json`` and item indices
    are selected for annotation:

    * ``task == "fluency"``: ``fluency_samples`` indices drawn at random
      (without replacement) from the range of the SOTA outputs.
    * ``task == "fidelity_annotations"``: ``fidelity_samples`` distinct indices
      chosen by sampling per stratum of the heuristic semantic error rate
      checker (``ser_correct``, where available) and the semantic fidelity
      classifier (``sfc_correct``), for the systems human, sota and
      systemFcPost.

    Each selected index becomes one CSV row. The row holds every system's text
    (in shuffled order) with the matching system names, the original human
    text, the index, and the formatted input data. Rows are written to
    ``<out_folder>/<task>/mturk_<dataset>.csv``.

    Args:
        data_folder: Folder containing one sub-folder per dataset.
        out_folder: Root folder for the generated CSV files.
        task: Either ``"fluency"`` or ``"fidelity_annotations"``.
        fluency_samples: Number of items to sample for the fluency task.
        fidelity_samples: Number of items to sample for the fidelity task.

    Raises:
        ValueError: If ``task`` is not supported, or if the fidelity strata run
            out of candidates before ``fidelity_samples`` distinct items are found.
    """
    if task not in ("fluency", "fidelity_annotations"):
        # Without this check an unknown task surfaced later as a NameError on `sample_indices`.
        raise ValueError(f"unknown task {task!r}; expected 'fluency' or 'fidelity_annotations'")

    data_folder = Path(data_folder)
    out_folder = Path(out_folder)

    for dataset in datasets:

        print(f"processing {dataset}")

        text_key = dataset_fields[dataset]["text"]
        original_text_key = f"original_{text_key.strip()}"
        orig_data_key = dataset_fields[dataset]["original_data"]

        systems_data = _load_system_outputs(data_folder / dataset)
        sota_data = systems_data["sota"]

        print(task)
        if task == "fluency":
            print(fluency_samples)
            sample_indices = random.sample(list(range(len(sota_data))), fluency_samples)
            print(sample_indices)
        else:
            print(fidelity_samples)
            sample_indices = _sample_fidelity_indices(systems_data, dataset, fidelity_samples)

        mturk_data = [
            _build_mturk_row(
                i,
                systems_data,
                dataset,
                text_key,
                original_text_key,
                sota_data[i][orig_data_key],
            )
            for i in sample_indices
        ]

        mturk_df = pd.DataFrame(mturk_data)

        mturk_out_folder = out_folder / task
        print(f"writing files to {mturk_out_folder}")
        mturk_out_folder.mkdir(parents=True, exist_ok=True)
        mturk_df.to_csv(mturk_out_folder / f"mturk_{dataset}.csv", sep=",", index=False)


def _load_system_outputs(dataset_folder):
    """Load ``<dataset_folder>/<system>.json`` for every system in ``systems``, keyed by system name."""
    systems_data = {}
    for system in systems:
        # `with` closes each handle promptly; JSON files are read as UTF-8 regardless of locale.
        with open(dataset_folder / f"{system}.json", encoding="utf-8") as f:
            systems_data[system] = json.load(f)
    return systems_data


def _sample_fidelity_indices(systems_data, dataset, fidelity_samples, sample_per_system=1):
    """Pick ``fidelity_samples`` distinct item indices, sampled per classifier stratum.

    Each round visits the systems human, sota and systemFcPost. For every
    stratum of that system's items it draws ``sample_per_system`` new indices
    using the module-level ``sample`` helper. Rounds repeat until enough
    indices are collected. The pool is then randomly subsampled down to
    exactly ``fidelity_samples``.

    Raises:
        ValueError: If a full round adds no new index, because looping further
            could never reach ``fidelity_samples``.
    """
    sampling_systems = ["human", "sota", "systemFcPost"]

    # Each stratum is (label, required sfc_correct value, required ser_correct value or None).
    if dataset == "ldc":
        # For LDC there is no heuristic semantic error rate, only the semantic fidelity classifier
        strata = [("sfc_correct", 1, None), ("sfc_wrong", 0, None)]
    else:
        strata = [
            ("sfc_correct_ser_wrong", 1, 0),
            ("sfc_wrong_ser_correct", 0, 1),
            ("both_wrong", 0, 0),
            ("both_correct", 1, 1),
        ]

    sample_indices = []
    chosen = set()  # mirrors sample_indices for O(1) membership tests
    while len(sample_indices) < fidelity_samples:
        collected_before = len(sample_indices)

        for system in sampling_systems:
            print(system)
            items = systems_data[system]
            for label, sfc_value, ser_value in strata:
                candidates = [
                    i
                    for i, item in enumerate(items)
                    if item["sfc_correct"] == sfc_value
                    and (ser_value is None or item["ser_correct"] == ser_value)
                    and i not in chosen
                ]
                sampled = sample(candidates, sample_per_system)
                sample_indices.extend(sampled)
                chosen.update(sampled)
                print(f"{label}: {len(sampled)}")

        if len(sample_indices) == collected_before:
            # Every stratum is exhausted; the original code would spin forever here.
            raise ValueError(
                f"only {len(sample_indices)} distinct items available for {dataset!r}, "
                f"fewer than the requested {fidelity_samples}"
            )

    sample_indices = random.sample(sample_indices, fidelity_samples)
    assert len(sample_indices) == len(set(sample_indices))
    assert len(sample_indices) == fidelity_samples
    return sample_indices


def _preprocess_text(text, dataset):
    """Normalize a system text for display; LDC texts are lowercased."""
    if dataset == "ldc":
        # The SOTA results for LDC are lowercased, so we lowercase also for consistent annotations
        return text.lower()
    return text


def _format_mturk_data(raw_data):
    """Render the input data string for the MTurk template.

    Angle brackets become braces; ``;`` and ``|||`` become ``<br>``
    line breaks.
    """
    # "<" must be replaced before ";" -> "<br>" so the inserted tags survive.
    data = raw_data.replace("<", "{").replace(">", "}").replace(";", "<br>")
    data = data.replace("|||", "<br>")
    return data.strip()


def _build_mturk_row(i, systems_data, dataset, text_key, original_text_key, orig_data):
    """Build one CSV row for item ``i``.

    The row contains the systems' texts in shuffled order (``text1..textN``
    with matching ``system1..systemN``), the human text, the index and the
    formatted input data.
    """
    texts = []
    for system, items in systems_data.items():
        texts.append((items[i][text_key][0], system))
        # NOTE(review): the value kept is the last system's; presumably all systems share it — confirm.
        original_text = items[i][original_text_key]

    # Shuffle so annotators cannot infer the system from the column position.
    random.shuffle(texts)

    row = {f"text{j}": _preprocess_text(text, dataset) for j, (text, _) in enumerate(texts, start=1)}
    row.update({f"system{j}": system for j, (_, system) in enumerate(texts, start=1)})
    row["humantext"] = original_text
    row["index"] = i
    row["data"] = _format_mturk_data(orig_data)
    return row