def __init__()

in mm_action_prediction/loaders/loader_simmc.py [0:0]
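
The excerpt relies on a few module-level imports that are not shown; a minimal sketch inferred from the call sites in this constructor (the exact form of the loaders import is an assumption):

    import json

    import numpy as np

    import loaders  # assumed: package exposing loaders.Vocabulary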


    def __init__(self, params):
        self.params = params
        # Load the dataset.
        raw_data = np.load(params["data_read_path"], allow_pickle=True)
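        # np.save stores the pickled dict as a 0-d object array; [()] unwraps it.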
        self.raw_data = raw_data[()]
        if self.params["encoder"] != "pretrained_transformer":
            self.words = loaders.Vocabulary()
            self.words.set_vocabulary_state(self.raw_data["vocabulary"]["word"])
            # Aliases.
            self.start_token = self.words.index("<start>")
            self.end_token = self.words.index("<end>")
            self.pad_token = self.words.index("<pad>")
            self.unk_token = self.words.index("<unk>")
        else:
            from transformers import BertTokenizer
            self.words = BertTokenizer.from_pretrained(self.raw_data["vocabulary"])
            # Aliases.
            self.start_token = self.words.added_tokens_encoder["[start]"]
            self.end_token = self.words.added_tokens_encoder["[end]"]
            self.pad_token = self.words.pad_token_id
            self.unk_token = self.words.unk_token_id
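            # Alias tokenizer methods so self.words matches the Vocabulary API above.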
            self.words.word = self.words.convert_ids_to_tokens
            self.words.index = self.words.convert_tokens_to_ids

        # Read the metainfo for the dataset.
        with open(params["metainfo_path"], "r") as file_id:
            self.metainfo = json.load(file_id)
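        # Map action names to integer action ids.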
        self.action_map = {ii["name"]: ii["id"] for ii in self.metainfo["actions"]}

        # Read the attribute vocabulary for the dataset.
        with open(params["attr_vocab_path"], "r") as file_id:
            attribute_map = json.load(file_id)
        print("Loading attribute vocabularies..")
        self.attribute_map = {}
        for attr, attr_vocab in attribute_map.items():
            self.attribute_map[attr] = loaders.Vocabulary(
                immutable=True, verbose=False
            )
            self.attribute_map[attr].set_vocabulary_state(attr_vocab)
        # Encode attribute supervision.
        for d_id, super_datum in enumerate(self.raw_data["action_supervision"]):
            for r_id, round_datum in enumerate(super_datum):
                if round_datum is None:
                    continue
                if self.params["domain"] == "furniture":
                    new_supervision = {
                        key: self.attribute_map[key].index(val)
                        for key, val in round_datum.items()
                        if key in self.attribute_map
                    }
                elif self.params["domain"] == "fashion":
                    ATTRIBUTE_FIXES = {
                        "embellishment": "embellishments", "hemlength": "hemLength"
                    }
                    new_supervision = {}
                    for key, val in round_datum.items():
                        # Skip attributes without a vocabulary
                        # (non-categorical fields have no index map).
                        if key not in self.attribute_map:
                            continue
                        # Encode each attribute -- multi-class classification.
                        fixed_keys = [ATTRIBUTE_FIXES.get(ii, ii) for ii in val]
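                        # Unseen values fall back to the catch-all "other" index.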
                        new_supervision[key] = [
                            self.attribute_map[key].index(ii)
                            if ii in self.attribute_map[key]
                            else self.attribute_map[key].index("other")
                            for ii in fixed_keys
                        ]
                else:
                    raise ValueError("Domain must be either furniture/fashion!")
                self.raw_data["action_supervision"][d_id][r_id] = new_supervision

        if self.params["domain"] == "furniture":
            if self.params["use_multimodal_state"]:
                # Read embeddings for furniture assets to model carousel state.
                self._prepare_carousel_states()
            if self.params["use_action_output"]:
                # Output for the actions.
                self._prepare_carousel_states(key="action_output_state")
        elif self.params["domain"] == "fashion":
            # Prepare embeddings for fashion items.
            self._prepare_asset_embeddings()
        else:
            raise ValueError("Domain must be either furniture/fashion!")

        # Additional data constructs (post-processing).
        if params["encoder"] == "memory_network":
            self._construct_fact()
        elif params["encoder"] == "tf_idf":
            self.compute_idf_features()
        super(DataloaderSIMMC, self).__init__()
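
A hypothetical construction of this loader, for illustration only; the keys shown are exactly the ones read in this constructor (other parts of the class may require more), and the file paths and encoder choice are placeholders rather than the repository's actual configuration:

    loader = DataloaderSIMMC({
        "data_read_path": "data/furniture_train.npy",    # pickled dict saved via np.save
        "metainfo_path": "data/furniture_metainfo.json",
        "attr_vocab_path": "data/furniture_attr_vocab.json",
        "domain": "furniture",
        "encoder": "memory_network",
        "use_multimodal_state": False,
        "use_action_output": False,
    })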