# vision/data/datasets_processing_scripts/create_evaluation_datasets/create_visdial.py
# Download images and dialogs
"""
mkdir /scratch/coco
cd /scratch/coco
wget http://images.cocodataset.org/zips/train2014.zip
unzip train2014.zip
rm train2014.zip
wget http://images.cocodataset.org/zips/val2014.zip
unzip val2014.zip
rm val2014.zip
mkdir /scratch/visdial
cd /scratch/visdial
aws s3 sync s3://m4-datasets/trash/visdial/ ./
unzip VisualDialog_val2018.zip
unzip visdial_1.0_train.zip
unzip visdial_1.0_val.zip
rm *.zip
"""
import copy
import json
import os
import random
import datasets
from datasets import Dataset, load_dataset
from PIL import Image
from tqdm import tqdm
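# Load the raw VisDial v1.0 dialog annotations and the dense relevance annotations of the val split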
with open("/scratch/visdial/visdial_1.0_train.json") as f:
data_train = json.load(f)["data"]
with open("/scratch/visdial/visdial_1.0_val.json") as f:
data_val = json.load(f)["data"]
with open("/scratch/visdial/visdial_1.0_val_dense_annotations.json") as f:
data_relevance_scores = json.load(f)
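# Keep, for each split, the dialogs together with the shared question and answer vocabularies they index into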
data = {
"train": {
"dialogs": data_train["dialogs"],
"questions": data_train["questions"],
"answers": data_train["answers"],
},
"validation": {
"dialogs": data_val["dialogs"],
"questions": data_val["questions"],
"answers": data_val["answers"],
},
}
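# The jxu124/visdial dataset from the Hub is only used to recover the local image path of each dialog
# (with a sanity check that its example ordering matches the raw annotations)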
incomplete_ds = load_dataset("jxu124/visdial")
# Sort for a deterministic ClassLabel mapping of answers across runs
all_answers = sorted(set(data["train"]["answers"] + data["validation"]["answers"]))
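# Rebuild one record per dialog: image path, caption, and the 10 decoded question/answer rounds with their 100 answer options each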
ds_data = {split: [] for split in ["train", "validation"]}
for split in ["train", "validation"]:
assert len(data[split]["dialogs"]) == len(incomplete_ds[split])
questions = data[split]["questions"]
answers = data[split]["answers"]
for idx_ex, example in enumerate(tqdm(data[split]["dialogs"])):
assert incomplete_ds[split][idx_ex]["image_path"].replace(".jpg", "").endswith(str(example["image_id"]))
image_path = os.path.join("/scratch/", incomplete_ds[split][idx_ex]["image_path"])
        # Discard train examples whose image comes from the COCO val2014 set
        # (in practice this filter does not remove anything)
        if "COCO_val2014" in image_path:
            continue
caption = example["caption"]
questions_example = []
answers_example = []
possible_answers_example = []
assert len(example["dialog"]) == 10
for d_ex in example["dialog"]:
questions_example.append(questions[d_ex["question"]])
answers_example.append(answers[d_ex["answer"]])
assert len(d_ex["answer_options"]) == 100
possible_answers_example.append([answers[ans_id] for ans_id in d_ex["answer_options"]])
ds_data[split].append(
{
"image_path": image_path,
"caption": caption,
"questions": questions_example,
"answers": answers_example,
"answers_options": possible_answers_example,
}
)
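# Sizes of the few-shot support sets and of the validation query set, all sampled from the VisDial train split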
number_of_examples_support_sets = 2048
number_of_examples_qa_validation_query_sets = 1024
repo_id = "HuggingFaceM4/VisDial_modif_support_query_sets"
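# Carve three disjoint subsets out of the shuffled train indices:
# a validation support set, a test support set, and a validation query set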
indices_train_set = list(range(0, len(ds_data["train"])))
random.shuffle(indices_train_set)
remaining_indices_train_set = indices_train_set
indices_validation_support_set, remaining_indices_train_set = (
remaining_indices_train_set[:number_of_examples_support_sets],
remaining_indices_train_set[number_of_examples_support_sets:],
)
indices_test_support_set, remaining_indices_train_set = (
remaining_indices_train_set[:number_of_examples_support_sets],
remaining_indices_train_set[number_of_examples_support_sets:],
)
indices_validation_query_set, remaining_indices_train_set = (
remaining_indices_train_set[:number_of_examples_qa_validation_query_sets],
remaining_indices_train_set[number_of_examples_qa_validation_query_sets:],
)
# print lengths
print(
f"Lengths of the sets:\nvalidation query set: {len(indices_validation_query_set)}\nvalidation support set:"
f" {len(indices_validation_support_set)}\ntest support set: {len(indices_test_support_set)}"
)
# Check that there is no overlap between the query set and the two support sets
assert not set(indices_validation_query_set).intersection(indices_validation_support_set)
assert not set(indices_validation_query_set).intersection(indices_test_support_set)
assert not set(indices_validation_support_set).intersection(indices_test_support_set)
# Validation support set
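# The first 9 question-answer rounds plus the 10th question form the context; the 10th answer is the target.
# Relevance scores are dummy zeros since dense annotations only exist for the official val split.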
new_ds_image_path = []
new_ds_caption = []
new_ds_context = []
new_ds_answer = []
new_ds_answer_options = []
new_ds_relevance_scores = []
set_indices_validation_support_set = set(indices_validation_support_set)
for idx_ex, example in enumerate(ds_data["train"]):
if idx_ex in set_indices_validation_support_set:
new_ds_image_path.append(example["image_path"])
caption = example["caption"]
new_ds_caption.append(caption)
context = ""
for idx_q_a, (ques, ans) in enumerate(zip(example["questions"], example["answers"])):
if idx_q_a < len(example["questions"]) - 1:
context += f"Question: {ques}? Answer: {ans}. "
else:
context += f"Question: {ques}? Answer: "
new_ds_context.append(context)
new_ds_answer.append(example["answers"][-1])
new_ds_answer_options.append(example["answers_options"][-1])
new_ds_relevance_scores.append([0.0] * 100)
ds_val_support_set = Dataset.from_dict(
{
"image_path": new_ds_image_path,
"caption": new_ds_caption,
"context": new_ds_context,
"answer": new_ds_answer,
"answer_options": new_ds_answer_options,
"relevance_scores": new_ds_relevance_scores,
}
)
# Test support set
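# Built exactly like the validation support set, from a disjoint subset of the train indices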
new_ds_image_path = []
new_ds_caption = []
new_ds_context = []
new_ds_answer = []
new_ds_answer_options = []
new_ds_relevance_scores = []
set_indices_test_support_set = set(indices_test_support_set)
for idx_ex, example in enumerate(ds_data["train"]):
if idx_ex in set_indices_test_support_set:
new_ds_image_path.append(example["image_path"])
caption = example["caption"]
new_ds_caption.append(caption)
context = ""
for idx_q_a, (ques, ans) in enumerate(zip(example["questions"], example["answers"])):
if idx_q_a < len(example["questions"]) - 1:
context += f"Question: {ques}? Answer: {ans}. "
else:
context += f"Question: {ques}? Answer: "
new_ds_context.append(context)
new_ds_answer.append(example["answers"][-1])
new_ds_answer_options.append(example["answers_options"][-1])
new_ds_relevance_scores.append([0.0] * 100)
ds_test_support_set = Dataset.from_dict(
{
"image_path": new_ds_image_path,
"caption": new_ds_caption,
"context": new_ds_context,
"answer": new_ds_answer,
"answer_options": new_ds_answer_options,
"relevance_scores": new_ds_relevance_scores,
}
)
# Validation query set
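# For query examples, one round is picked at random: the preceding rounds form the dialog history
# and the chosen round's answer is the target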
new_ds_image_path = []
new_ds_caption = []
new_ds_context = []
new_ds_answer = []
new_ds_answer_options = []
new_ds_relevance_scores = []
set_indices_validation_query_set = set(indices_validation_query_set)
for idx_ex, example in enumerate(ds_data["train"]):
if idx_ex in set_indices_validation_query_set:
new_ds_image_path.append(example["image_path"])
caption = example["caption"]
new_ds_caption.append(caption)
context = ""
idx_q_a_chosen = random.randint(
0, 9
) # We choose one question-answer pair and the previous ones will form the dialog history
for idx_q_a, (ques, ans) in enumerate(zip(example["questions"], example["answers"])):
if idx_q_a < idx_q_a_chosen:
context += f"Question: {ques}? Answer: {ans}. "
elif idx_q_a == idx_q_a_chosen:
context += f"Question: {ques}? Answer: "
new_ds_context.append(context)
new_ds_answer.append(example["answers"][idx_q_a_chosen])
new_ds_answer_options.append(example["answers_options"][idx_q_a_chosen])
        # Artificial relevance scores: dense annotations are not provided for the train split,
        # so only the ground-truth answer gets a non-zero score
        relevance_scores = [0.0] * len(example["answers_options"][idx_q_a_chosen])
        relevance_scores[example["answers_options"][idx_q_a_chosen].index(example["answers"][idx_q_a_chosen])] = 1.0
new_ds_relevance_scores.append(relevance_scores)
ds_val_query_set = Dataset.from_dict(
{
"image_path": new_ds_image_path,
"caption": new_ds_caption,
"context": new_ds_context,
"answer": new_ds_answer,
"answer_options": new_ds_answer_options,
"relevance_scores": new_ds_relevance_scores,
}
)
# Test query set
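# Built from the official VisDial val split, keeping for each example the single round
# that comes with dense relevance annotations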
new_ds_image_path = []
new_ds_caption = []
new_ds_context = []
new_ds_answer = []
new_ds_answer_options = []
new_ds_relevance_scores = []
for idx_ex, example in enumerate(ds_data["validation"]): # We now consider the true validation for the test
info_relevance_scores = data_relevance_scores[idx_ex]
assert str(data_relevance_scores[idx_ex]["image_id"]) in example["image_path"]
    # Dense relevance scores are only provided for one of the 10 rounds per example,
    # so that round is the one kept for the test query set
    idx_q_a_chosen = info_relevance_scores["round_id"] - 1
new_ds_image_path.append(example["image_path"])
caption = example["caption"]
new_ds_caption.append(caption)
context = ""
for idx_q_a, (ques, ans) in enumerate(zip(example["questions"], example["answers"])):
if idx_q_a < idx_q_a_chosen:
context += f"Question: {ques}? Answer: {ans}. "
elif idx_q_a == idx_q_a_chosen:
context += f"Question: {ques}? Answer: "
new_ds_context.append(context)
new_ds_answer.append(example["answers"][idx_q_a_chosen])
new_ds_answer_options.append(example["answers_options"][idx_q_a_chosen])
new_ds_relevance_scores.append(info_relevance_scores["gt_relevance"])
ds_test_query_set = Dataset.from_dict(
{
"image_path": new_ds_image_path,
"caption": new_ds_caption,
"context": new_ds_context,
"answer": new_ds_answer,
"answer_options": new_ds_answer_options,
"relevance_scores": new_ds_relevance_scores,
}
)
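# Open the image file of each example so that the pushed datasets embed the actual images, not just their paths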
def func_map_add_images(example):
    # "image_path" was built as an absolute path under /scratch/, so it can be opened directly
    example["image"] = Image.open(example["image_path"])
    return example
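# Cast the new columns: "image" becomes an Image feature and "answer" a ClassLabel over the full answer vocabulary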
new_features = copy.deepcopy(ds_val_support_set.features)
new_features["image"] = datasets.Image()
new_features["answer"] = datasets.ClassLabel(num_classes=len(all_answers), names=all_answers)
ds_val_support_set = ds_val_support_set.map(func_map_add_images, num_proc=20, features=copy.deepcopy(new_features))
ds_test_support_set = ds_test_support_set.map(func_map_add_images, num_proc=20, features=copy.deepcopy(new_features))
ds_val_query_set = ds_val_query_set.map(func_map_add_images, num_proc=20, features=copy.deepcopy(new_features))
ds_test_query_set = ds_test_query_set.map(func_map_add_images, num_proc=20, features=copy.deepcopy(new_features))
# Save and push to hub newly created splits
ds_val_support_set.push_to_hub(repo_id, "validation_support_set", private=True)
ds_val_query_set.push_to_hub(repo_id, "validation_query_set", private=True)
ds_test_support_set.push_to_hub(repo_id, "test_support_set", private=True)
ds_test_query_set.push_to_hub(repo_id, "test_query_set", private=True)
# Load the newly created dataset from hub
ds_final = load_dataset(repo_id, use_auth_token=True)
# Print the final composition of the dataset
print(f"Composition of the final dataset: {ds_final}")