def load_one_batch()

in mm_action_prediction/loaders/loader_simmc.py
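
Collates the given sample ids into one padded batch: user and assistant utterances, action labels, masks, optional retrieval candidates, and domain-specific features. The method relies on numpy (imported as np) and the repo's torch_support.flatten/unflatten helpers, and hands the assembled dictionary to _ship_torch_batch before returning.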


    def load_one_batch(self, sample_ids):
        """Loads a batch, given the sample ids.

        Args:
            sample_ids: List of instance ids to load data for.

        Returns:
            batch: Dictionary with relevant fields for training/evaluation.
        """
        batch = {
            "pad_token": self.pad_token,
            "start_token": self.start_token,
            "sample_ids": sample_ids,
        }
        batch["dialog_len"] = self.raw_data["dialog_len"][sample_ids]
        batch["dialog_id"] = self.raw_data["dialog_id"][sample_ids]
        max_dialog_len = max(batch["dialog_len"])

        # Look up the user utterance for each round from the utterance pool,
        # clipped to max_encoder_len.
        user_utt_id = self.raw_data["user_utt_id"][sample_ids]
        batch["user_utt"], batch["user_utt_len"] = self._sample_utterance_pool(
            user_utt_id,
            self.raw_data["user_sent"],
            self.raw_data["user_sent_len"],
            self.params["max_encoder_len"],
        )

        # Assistant input and output sequences share the same utterance ids.
        assist_utt_id = self.raw_data["assist_utt_id"][sample_ids]
        for key in ("assist_in", "assist_out"):
            batch[key], batch[key + "_len"] = self._sample_utterance_pool(
                assist_utt_id,
                self.raw_data[key],
                self.raw_data["assist_sent_len"],
                self.params["max_decoder_len"],
            )
        # Map raw action labels to model indices through action_map.
        actions = self.raw_data["action"][sample_ids]
        batch["action"] = np.vectorize(lambda x: self.action_map[x])(actions)
        # Construct user, assistant, and dialog masks; rounds whose user
        # utterance id is -1 are padding and get masked out of the dialog.
        batch["dialog_mask"] = user_utt_id != -1
        batch["user_mask"] = (batch["user_utt"] == batch["pad_token"]) | (
            batch["user_utt"] == batch["start_token"]
        )
        batch["assist_mask"] = (batch["assist_out"] == batch["pad_token"]) | (
            batch["assist_out"] == batch["start_token"]
        )

        # Get retrieval candidates if needed.
        if self.params["get_retrieval_candidates"]:
            retrieval_inds = self.raw_data["retrieval_candidates"][sample_ids]
            batch_size, num_rounds, _ = retrieval_inds.shape
            flat_inds = torch_support.flatten(
                retrieval_inds, batch_size, num_rounds
            )
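            # retrieval_inds holds candidate utterance indices per round,
            # shaped (batch, rounds, num_candidates); flatten merges the
            # batch and round dims so the pool lookup runs in one pass,
            # and unflatten below restores them.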
            for key in ("assist_in", "assist_out"):
                new_key = key.replace("assist", "candidate")
                cands, cands_len = self._sample_utterance_pool(
                    flat_inds,
                    self.raw_data[key],
                    self.raw_data["assist_sent_len"],
                    self.params["max_decoder_len"],
                )
                batch[new_key] = torch_support.unflatten(
                    cands, batch_size, num_rounds
                )
                batch[new_key + "_len"] = torch_support.unflatten(
                    cands_len, batch_size, num_rounds
                )
            batch["candidate_mask"] = (
                (batch["candidate_out"] == batch["pad_token"])
                | (batch["candidate_out"] == batch["start_token"])
            )

        # Action supervision.
        batch["action_super"] = [
            self.raw_data["action_supervision"][ii] for ii in sample_ids
        ]

        # Fetch facts if required.
        if self.params["encoder"] == "memory_network":
            batch["fact"] = self.raw_data["fact"][sample_ids]
            batch["fact_len"] = self.raw_data["fact_len"][sample_ids]

        # Trim all per-round tensors to the maximum dialog length in the batch.
        for key in (
            "assist_in",
            "assist_out",
            "candidate_in",
            "candidate_out",
            "user_utt",
            "fact",
            "user_mask",
            "assist_mask",
            "candidate_mask",
            "action",
            "assist_in_len",
            "assist_out_len",
            "candidate_in_len",
            "candidate_out_len",
            "user_utt_len",
            "dialog_mask",
            "fact_len",
        ):
            if key in batch:
                batch[key] = batch[key][:, :max_dialog_len]
        # TF-IDF features.
        if self.params["encoder"] == "tf_idf":
            batch["user_tf_idf"] = self.compute_tf_features(
                batch["user_utt"], batch["user_utt_len"]
            )

        # Domain-specific processing.
        if self.params["domain"] == "furniture":
            # Carousel states.
            if self.params["use_multimodal_state"]:
                batch["carousel_state"] = [
                    self.raw_data["carousel_state"][ii] for ii in sample_ids
                ]
            # Action output.
            if self.params["use_action_output"]:
                batch["action_output"] = [
                    self.raw_data["action_output_state"][ii] for ii in sample_ids
                ]
        elif self.params["domain"] == "fashion":
            # Asset embeddings -- memory, database, focus images.
            for dtype in ["memory", "database", "focus"]:
                indices = self.raw_data["{}_inds".format(dtype)][sample_ids]
                image_embeds = self.embed_data["embedding"][indices]
                batch["{}_images".format(dtype)] = image_embeds
        else:
            raise ValueError("Domain must be either furniture or fashion!")
        return self._ship_torch_batch(batch)
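
For orientation, a minimal sketch of a driver loop for this method; it is not part of the repo, and loader (an instance of the class above), num_instances, and the shuffling policy are hypothetical:

    import numpy as np

    def iterate_batches(loader, num_instances, batch_size, shuffle=True):
        """Yields batches from a (hypothetical) SIMMC loader instance."""
        ids = np.arange(num_instances)
        if shuffle:
            # Hypothetical choice: reshuffle instance order every epoch.
            np.random.shuffle(ids)
        for start in range(0, num_instances, batch_size):
            # Each call returns a dict of per-round tensors, already trimmed
            # to the longest dialog in this chunk of sample ids.
            yield loader.load_one_batch(ids[start : start + batch_size])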