in mm_action_prediction/loaders/loader_simmc.py [0:0]
def __init__(self, params):
    """Load the SIMMC dataset and prepare vocabularies and supervision.

    Args:
        params: Dict of loader options. Keys read here include
            "data_read_path", "metainfo_path", "attr_vocab_path",
            "encoder", "domain", "use_multimodal_state", and
            "use_action_output".

    Raises:
        ValueError: If params["domain"] is neither "furniture" nor
            "fashion".
    """
    self.params = params
    # Load the dataset (a pickled dict saved through numpy; the 0-d
    # object array is unwrapped with the [()] index).
    raw_data = np.load(self.params["data_read_path"], allow_pickle=True)
    self.raw_data = raw_data[()]
    if self.params["encoder"] != "pretrained_transformer":
        self.words = loaders.Vocabulary()
        self.words.set_vocabulary_state(self.raw_data["vocabulary"]["word"])
        # Aliases for the special-token ids.
        self.start_token = self.words.index("<start>")
        self.end_token = self.words.index("<end>")
        self.pad_token = self.words.index("<pad>")
        self.unk_token = self.words.index("<unk>")
    else:
        # Imported lazily so transformers is only required when the
        # pretrained-transformer encoder is actually used.
        from transformers import BertTokenizer

        self.words = BertTokenizer.from_pretrained(self.raw_data["vocabulary"])
        # Aliases for the special-token ids; [start]/[end] are added tokens.
        self.start_token = self.words.added_tokens_encoder["[start]"]
        self.end_token = self.words.added_tokens_encoder["[end]"]
        self.pad_token = self.words.pad_token_id
        self.unk_token = self.words.unk_token_id
        # Mirror the loaders.Vocabulary API (word/index) so downstream
        # code works with either vocabulary backend.
        self.words.word = self.words.convert_ids_to_tokens
        self.words.index = self.words.convert_tokens_to_ids
    # Read the metainfo for the dataset.
    with open(self.params["metainfo_path"], "r") as file_id:
        self.metainfo = json.load(file_id)
    self.action_map = {ii["name"]: ii["id"] for ii in self.metainfo["actions"]}
    # Read the attribute vocabulary for the dataset.
    with open(self.params["attr_vocab_path"], "r") as file_id:
        attribute_map = json.load(file_id)
    print("Loading attribute vocabularies..")
    self.attribute_map = {}
    for attr, attr_vocab in attribute_map.items():
        self.attribute_map[attr] = loaders.Vocabulary(
            immutable=True, verbose=False
        )
        self.attribute_map[attr].set_vocabulary_state(attr_vocab)
    # Canonical spellings for known inconsistencies in fashion attribute
    # annotations. Hoisted out of the per-round loop below: it is a
    # loop-invariant constant and was previously rebuilt every iteration.
    attribute_fixes = {
        "embellishment": "embellishments",
        "hemlength": "hemLength",
    }
    # Encode attribute supervision as indices into the attribute
    # vocabularies, rewriting self.raw_data["action_supervision"] in place.
    for d_id, super_datum in enumerate(self.raw_data["action_supervision"]):
        for r_id, round_datum in enumerate(super_datum):
            # Rounds without supervision stay None.
            if round_datum is None:
                continue
            if self.params["domain"] == "furniture":
                # Single-label attributes: one index per attribute.
                new_supervision = {
                    key: self.attribute_map[key].index(val)
                    for key, val in round_datum.items()
                    if key in self.attribute_map
                }
            elif self.params["domain"] == "fashion":
                new_supervision = {}
                for key, val in round_datum.items():
                    # No dictionary to map attributes to indices
                    # (non-classification/categorical fields).
                    if key not in self.attribute_map:
                        continue
                    # Encode each attribute -- multi-class classification;
                    # values missing from the vocabulary map to "other".
                    fixed_vals = [attribute_fixes.get(ii, ii) for ii in val]
                    new_supervision[key] = [
                        self.attribute_map[key].index(ii)
                        if ii in self.attribute_map[key]
                        else self.attribute_map[key].index("other")
                        for ii in fixed_vals
                    ]
            else:
                raise ValueError("Domain must be either furniture/fashion!")
            self.raw_data["action_supervision"][d_id][r_id] = new_supervision
    if self.params["domain"] == "furniture":
        if self.params["use_multimodal_state"]:
            # Read embeddings for furniture assets to model carousel state.
            self._prepare_carousel_states()
        if self.params["use_action_output"]:
            # Output for the actions.
            self._prepare_carousel_states(key="action_output_state")
    elif self.params["domain"] == "fashion":
        # Prepare embeddings for fashion items.
        self._prepare_asset_embeddings()
    else:
        raise ValueError("Domain must be either furniture/fashion!")
    # Additional data constructs (post-processing), encoder-specific.
    if self.params["encoder"] == "memory_network":
        self._construct_fact()
    elif self.params["encoder"] == "tf_idf":
        self.compute_idf_features()
    super(DataloaderSIMMC, self).__init__()