# mm_action_prediction/loaders/loader_simmc.py
def load_one_batch(self, sample_ids):
    """Loads a batch of dialog instances, given the sample ids.

    Gathers dialog metadata, user/assistant utterances, actions, masks,
    optional retrieval candidates, action supervision, encoder-specific
    features (memory-network facts, TF-IDF), and domain-specific state
    (furniture carousel / fashion image embeddings), then trims all
    round-indexed tensors to the longest dialog in the batch.

    Args:
        sample_ids: List (or array usable as a numpy index) of instance
            ids to load data for.

    Returns:
        batch: Dictionary with relevant fields for training/evaluation,
            shipped through self._ship_torch_batch (presumably converts
            numpy arrays to torch tensors — confirm in that helper).
    """
    # Seed the batch with special-token ids so downstream code can build
    # masks without reaching back into the loader.
    batch = {
        "pad_token": self.pad_token,
        "start_token": self.start_token,
        "sample_ids": sample_ids,
    }
    batch["dialog_len"] = self.raw_data["dialog_len"][sample_ids]
    batch["dialog_id"] = self.raw_data["dialog_id"][sample_ids]
    # Longest dialog in this batch; used below to trim padded round dims.
    max_dialog_len = max(batch["dialog_len"])
    # Indices into the shared user-sentence pool; -1 marks absent rounds
    # (see dialog_mask below).
    user_utt_id = self.raw_data["user_utt_id"][sample_ids]
    batch["user_utt"], batch["user_utt_len"] = self._sample_utterance_pool(
        user_utt_id,
        self.raw_data["user_sent"],
        self.raw_data["user_sent_len"],
        self.params["max_encoder_len"],
    )
    # Assistant utterances come in decoder-input / decoder-output pairs
    # (assist_in / assist_out), both indexed by the same utterance ids
    # and sharing one length table.
    for key in ("assist_in", "assist_out"):
        batch[key], batch[key + "_len"] = self._sample_utterance_pool(
            self.raw_data["assist_utt_id"][sample_ids],
            self.raw_data[key],
            self.raw_data["assist_sent_len"],
            self.params["max_decoder_len"],
        )
    # Map raw action labels to contiguous model indices via action_map.
    actions = self.raw_data["action"][sample_ids]
    batch["action"] = np.vectorize(lambda x: self.action_map[x])(actions)
    # Construct user, assistant, and dialog masks.
    # dialog_mask: True for real rounds (-1 utterance id = padding round).
    batch["dialog_mask"] = user_utt_id != -1
    # Token-level masks: True on pad/start positions (i.e. positions to
    # EXCLUDE, not keep — note the inverted-sounding name).
    batch["user_mask"] = (batch["user_utt"] == batch["pad_token"]) | (
        batch["user_utt"] == batch["start_token"]
    )
    batch["assist_mask"] = (batch["assist_out"] == batch["pad_token"]) | (
        batch["assist_out"] == batch["start_token"]
    )
    # Get retrieval candidates if needed.
    if self.params["get_retrieval_candidates"]:
        retrieval_inds = self.raw_data["retrieval_candidates"][sample_ids]
        # Assumes (batch, rounds, num_candidates) — TODO confirm upstream.
        batch_size, num_rounds, _ = retrieval_inds.shape
        # Flatten (batch, rounds) so the shared utterance-pool sampler can
        # be reused, then unflatten the results back per round.
        flat_inds = torch_support.flatten(
            retrieval_inds, batch_size, num_rounds
        )
        for key in ("assist_in", "assist_out"):
            # Candidates reuse assistant sentence pools: candidate_in/out.
            new_key = key.replace("assist", "candidate")
            cands, cands_len = self._sample_utterance_pool(
                flat_inds,
                self.raw_data[key],
                self.raw_data["assist_sent_len"],
                self.params["max_decoder_len"],
            )
            batch[new_key] = torch_support.unflatten(
                cands, batch_size, num_rounds
            )
            batch[new_key + "_len"] = torch_support.unflatten(
                cands_len, batch_size, num_rounds
            )
        # Same pad/start-token masking convention as assist_mask above.
        batch["candidate_mask"] = (
            (batch["candidate_out"] == batch["pad_token"])
            | (batch["candidate_out"] == batch["start_token"])
        )
    # Action supervision.
    # Kept as a Python list (not stacked) — entries are presumably
    # variable-length/None per round; verify against consumers.
    batch["action_super"] = [
        self.raw_data["action_supervision"][ii] for ii in sample_ids
    ]
    # Fetch facts if required.
    if self.params["encoder"] == "memory_network":
        batch["fact"] = self.raw_data["fact"][sample_ids]
        batch["fact_len"] = self.raw_data["fact_len"][sample_ids]
    # Trim to the maximum dialog length.
    # All of these are round-indexed along axis 1; keys absent from the
    # batch (e.g. candidates/facts when disabled) are skipped.
    for key in (
        "assist_in",
        "assist_out",
        "candidate_in",
        "candidate_out",
        "user_utt",
        "fact",
        "user_mask",
        "assist_mask",
        "candidate_mask"
    ):
        if key in batch:
            batch[key] = batch[key][:, :max_dialog_len]
    # Per-round scalar fields trimmed the same way (axis 1 = rounds).
    for key in (
        "action",
        "assist_in_len",
        "assist_out_len",
        "candidate_in_len",
        "candidate_out_len",
        "user_utt_len",
        "dialog_mask",
        "fact_len",
    ):
        if key in batch:
            batch[key] = batch[key][:, :max_dialog_len]
    # TF-IDF features.
    # Computed after trimming so features align with the trimmed rounds.
    if self.params["encoder"] == "tf_idf":
        batch["user_tf_idf"] = self.compute_tf_features(
            batch["user_utt"], batch["user_utt_len"]
        )
    # Domain-specific processing.
    if self.params["domain"] == "furniture":
        # Carousel states.
        if self.params["use_multimodal_state"]:
            batch["carousel_state"] = [
                self.raw_data["carousel_state"][ii] for ii in sample_ids
            ]
        # Action output.
        if self.params["use_action_output"]:
            batch["action_output"] = [
                self.raw_data["action_output_state"][ii] for ii in sample_ids
            ]
    elif self.params["domain"] == "fashion":
        # Asset embeddings -- memory, database, focus images.
        # Indices select rows from a shared embedding matrix; NOTE(review):
        # these are not trimmed to max_dialog_len above — confirm intended.
        for dtype in ["memory", "database", "focus"]:
            indices = self.raw_data["{}_inds".format(dtype)][sample_ids]
            image_embeds = self.embed_data["embedding"][indices]
            batch["{}_images".format(dtype)] = image_embeds
    else:
        raise ValueError("Domain must be either furniture/fashion!")
    # Convert to torch tensors (and move devices, presumably) before return.
    return self._ship_torch_batch(batch)