in mm_action_prediction/tools/build_multimodal_inputs.py [0:0]
# Imports inferred from usage in this function; FLAGS, support, and the
# convert_pool_matrices* helpers are assumed to be defined elsewhere in this
# module (e.g., absl flags and a local support utility).
import collections
import copy
import json

import numpy as np

def build_multimodal_inputs(input_json_file):
"""Convert splits into format injested by the dataloader.
Args:
input_json_file: Path to the JSON file to injest
Returns:
mm_inputs: Dictionary of multimodal inputs to train/evaluate
"""
# Read the raw data.
print("Reading: {}".format(input_json_file))
with open(input_json_file, "r") as file_id:
data = json.load(file_id)
# Read action supervision.
print("Reading action supervision: {}".format(FLAGS.action_json_path))
with open(FLAGS.action_json_path, "r") as file_id:
extracted_actions = json.load(file_id)
# Index extracted actions by dialog id for direct lookup.
extracted_actions = {ii["dialog_id"]: ii for ii in extracted_actions}
# Obtain maximum dialog length.
dialog_lens = np.array(
[len(ii["dialogue"]) for ii in data["dialogue_data"]], dtype="int32"
)
max_dialog_len = np.max(dialog_lens)
num_dialogs = len(data["dialogue_data"])
# Set up data structures for recording utterances, actions, action supervision,
# carousel states, and outputs.
encoded_dialogs = {
"user": np.full((num_dialogs, max_dialog_len), -1, dtype="int32"),
"assistant": np.full((num_dialogs, max_dialog_len), -1, dtype="int32")
}
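# Per-dialog, per-round placeholders (None) for action supervision, carousel
# states, and output states.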
empty_action_list = [
[None for _ in range(max_dialog_len)] for _ in range(num_dialogs)
]
action_info = {
"action": np.full((num_dialogs, max_dialog_len), "None", dtype="object_"),
"action_supervision": copy.deepcopy(empty_action_list),
"carousel_state": copy.deepcopy(empty_action_list),
"action_output_state": copy.deepcopy(empty_action_list)
}
dialog_ids = np.zeros(num_dialogs, dtype="int32")
action_counts = collections.defaultdict(int)
# Compile dictionaries for user and assistant utterances separately.
utterance_dict = {"user": {}, "assistant": {}}
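# Domain-specific bookkeeping: furniture also records carousel/output states;
# fashion records memory/focus/database image ids per dialog.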
action_keys = ("action",)
if FLAGS.domain == "furniture":
action_keys += ("carousel_state", "action_output_state")
elif FLAGS.domain == "fashion":
task_mapping = {ii["task_id"]: ii for ii in data["task_mapping"]}
dialog_image_ids = {
"memory_images": [], "focus_images": [], "database_images": []
}
# If retrieval candidates file is available, encode the candidates.
if FLAGS.retrieval_candidate_file:
print("Reading retrieval candidates: {}".format(
FLAGS.retrieval_candidate_file)
)
with open(FLAGS.retrieval_candidate_file, "r") as file_id:
candidates_data = json.load(file_id)
candidate_pool = candidates_data["system_transcript_pool"]
candidate_ids = candidates_data["retrieval_candidates"]
candidate_ids = {ii["dialogue_idx"]: ii for ii in candidate_ids}
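# candidate_pool holds system utterance strings; candidate_ids maps dialog id
# to per-round candidate indices into that pool.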
def get_candidate_ids(dialog_id, round_id):
"""Given the dialog_id and round_id, get the candidates.
Args:
candidate_ids: Dictionary of candidate ids
dialog_id: Dialog id
round_id: Round id
Returns:
candidates: List of candidates, indexed by the pool
"""
candidates = candidate_ids[dialog_id]["retrieval_candidates"]
candidates = candidates[round_id]["retrieval_candidates"]
return candidates
# Peek at one dialog to infer the number of candidates per round.
random_dialog_id = list(candidate_ids.keys())[0]
num_candidates = len(get_candidate_ids(random_dialog_id, 0))
encoded_candidates = np.full(
(num_dialogs, max_dialog_len, num_candidates), -1, dtype=np.int32
)
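# Encode each dialog: map utterances to pool indices and record per-round
# action supervision.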
for datum_id, datum in enumerate(data["dialogue_data"]):
dialog_id = datum["dialogue_idx"]
dialog_ids[datum_id] = dialog_id
# Get action supervision.
dialog_action_data = extracted_actions[dialog_id]["actions"]
# Record images for fashion.
if FLAGS.domain == "fashion":
# Fall back to a fixed task id if missing (affects 1-2 dialogs).
if "dialogue_task_id" not in datum:
print("Dialog task id not found, using 1874 as fallback!")
task_info = task_mapping[datum.get("dialogue_task_id", 1874)]
for key in ("memory_images", "database_images"):
dialog_image_ids[key].append(task_info[key])
dialog_image_ids["focus_images"].append(
extracted_actions[dialog_id]["focus_images"]
)
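# Deduplicate utterances into per-speaker pools; dialogs store pool indices.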
for round_id, round_datum in enumerate(datum["dialogue"]):
for key, speaker in (
("transcript", "user"), ("system_transcript", "assistant")
):
utterance_clean = round_datum[key].lower().strip(" ")
speaker_pool = utterance_dict[speaker]
if utterance_clean not in speaker_pool:
speaker_pool[utterance_clean] = len(speaker_pool)
encoded_dialogs[speaker][datum_id, round_id] = (
speaker_pool[utterance_clean]
)
# Record action related keys.
action_datum = dialog_action_data[round_id]
cur_action_supervision = action_datum["action_supervision"]
if FLAGS.domain == "furniture":
if cur_action_supervision is not None:
# Retain only the args of supervision.
cur_action_supervision = cur_action_supervision["args"]
action_info["action_supervision"][datum_id][round_id] = (
cur_action_supervision
)
for key in action_keys:
action_info[key][datum_id][round_id] = action_datum[key]
action_counts[action_datum["action"]] += 1
support.print_distribution(action_counts, "Action distribution:")
# Record retrieval candidates, if path is provided.
if FLAGS.retrieval_candidate_file:
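# Candidate utterances index into the assistant pool so they share its encoding.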
for datum_id, datum in enumerate(data["dialogue_data"]):
dialog_id = datum["dialogue_idx"]
for round_id, _ in enumerate(datum["dialogue"]):
round_candidates = get_candidate_ids(dialog_id, round_id)
encoded_round_candidates = []
for cand_ind in round_candidates:
cand_str = candidate_pool[cand_ind].lower().strip(" ")
# If candidate is not in the dict, add it.
if cand_str not in utterance_dict["assistant"]:
utterance_dict["assistant"][cand_str] = (
len(utterance_dict["assistant"])
)
pool_ind = utterance_dict["assistant"][cand_str]
encoded_round_candidates.append(pool_ind)
encoded_candidates[datum_id, round_id] = encoded_round_candidates
# Order utterances by pool index so list position matches the encoded ids.
utterance_list = {
key: sorted(value.keys(), key=lambda x: value[x])
for key, value in utterance_dict.items()
}
# Convert the pools into matrices.
mm_inputs = {}
mm_inputs.update(action_info)
print("Vocabulary: {}".format(FLAGS.vocab_file))
# Token-wise encoding with a JSON vocabulary, unless a pretrained tokenizer is used.
if not FLAGS.pretrained_tokenizer:
with open(FLAGS.vocab_file, "r") as file_id:
vocabulary = json.load(file_id)
mm_inputs["vocabulary"] = vocabulary
word2ind = {word: index for index, word in enumerate(vocabulary["word"])}
mm_inputs["user_sent"], mm_inputs["user_sent_len"] = convert_pool_matrices(
utterance_list["user"], word2ind
)
mm_inputs["assist_sent"], mm_inputs["assist_sent_len"] = convert_pool_matrices(
utterance_list["assistant"], word2ind
)
# Token aliases.
pad_token = word2ind["<pad>"]
start_token = word2ind["<start>"]
end_token = word2ind["<end>"]
else:
# Use pretrained BERT tokenizer.
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(FLAGS.vocab_file)
mm_inputs["vocabulary"] = FLAGS.vocab_file
mm_inputs["user_sent"], mm_inputs["user_sent_len"] = (
convert_pool_matrices_pretrained_tokenizer(
utterance_list["user"], tokenizer
)
)
mm_inputs["assist_sent"], mm_inputs["assist_sent_len"] = (
convert_pool_matrices_pretrained_tokenizer(
utterance_list["assistant"], tokenizer
)
)
# Token aliases.
pad_token = tokenizer.pad_token_id
start_token = tokenizer.added_tokens_encoder["[start]"]
end_token = tokenizer.added_tokens_encoder["[end]"]
# Build decoder input/target pairs for the assistant RNN: prepend <start> to
# inputs, append <end> to targets.
extra_slice = np.full((len(mm_inputs["assist_sent"]), 1), start_token, np.int32)
mm_inputs["assist_in"] = np.concatenate(
[extra_slice, mm_inputs["assist_sent"]], axis=1
)
extra_slice.fill(pad_token)
mm_inputs["assist_out"] = np.concatenate(
[mm_inputs["assist_sent"], extra_slice], axis=1
)
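# Write <end> at each target's true length, then grow lengths by one to cover it.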
for ii in range(len(mm_inputs["assist_out"])):
mm_inputs["assist_out"][ii, mm_inputs["assist_sent_len"][ii]] = end_token
mm_inputs["assist_sent_len"] += 1
# Save the memory, focus, and database image ids for each dialog (fashion only).
if FLAGS.domain == "fashion":
mm_inputs.update(dialog_image_ids)
# Save the retrieval candidates.
if FLAGS.retrieval_candidate_file:
mm_inputs["retrieval_candidates"] = encoded_candidates
# Save the dialogs by user/assistant utterances.
mm_inputs["user_utt_id"] = encoded_dialogs["user"]
mm_inputs["assist_utt_id"] = encoded_dialogs["assistant"]
mm_inputs["dialog_len"] = dialog_lens
mm_inputs["dialog_id"] = dialog_ids
mm_inputs["paths"] = {
"data": FLAGS.json_path,
"action": FLAGS.action_json_path,
"retrieval": FLAGS.retrieval_candidate_file,
"vocabulary": FLAGS.vocab_file
}
return mm_inputs
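
# Usage sketch (not part of the original file): a minimal entry point showing
# how build_multimodal_inputs might be invoked and its output persisted.
# It assumes FLAGS is absl's flags.FLAGS with json_path defined elsewhere in
# this module; the "save_path" flag is hypothetical, added only for illustration.
if __name__ == "__main__":
    from absl import app, flags

    flags.DEFINE_string(
        "save_path", "mm_inputs.npy", "Hypothetical path for the encoded inputs."
    )

    def main(_):
        mm_inputs = build_multimodal_inputs(FLAGS.json_path)
        # np.save pickles the dict of arrays/lists; load with allow_pickle=True.
        np.save(FLAGS.save_path, mm_inputs)

    app.run(main)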