in src/fairseq/fairseq/data/language_pair_dataset.py [0:0]
def __getitem__(self, index):
    """Return the example at ``index`` as a dict of token tensors.

    The returned dict always contains ``'id'``, ``'source'`` and ``'target'``
    (``'target'`` is ``None`` when ``self.tgt`` is ``None``).  It also
    contains ``'alignment'`` / ``'extra_input'`` when the corresponding
    datasets are set, and ``'graph'`` when ``self.enable_graph_encoder`` is
    true.

    With the graph encoder enabled, the raw source sequence is split at the
    first occurrence of ``self.graph_split_index``.  Tokens before the
    separator form ``'source'``; the separator and everything after it form
    ``'graph'``.  A leading BOS / trailing EOS on the raw source is mirrored
    so both segments are framed the same way.

    The optional ``append_eos_to_target``, ``append_bos`` and
    ``remove_eos_from_source`` transforms run in that order.  Each one works
    on the output of the previous one, so the flags compose.  The
    "already has BOS/EOS" checks always look at the raw dataset items.
    """
    # Index each underlying dataset once; item access may be expensive.
    raw_src = self.src[index]
    raw_tgt = self.tgt[index] if self.tgt is not None else None
    src_item = raw_src
    tgt_item = raw_tgt
    graph_item = None

    if self.enable_graph_encoder:
        bos = self.src_dict.bos()
        eos = self.src_dict.eos()
        matches = torch.nonzero(raw_src == self.graph_split_index)
        if matches.numel() > 0:
            split_index = int(matches[0][0])
        else:
            # No separator (empty graph): the graph segment is just the last
            # token (EOS when the source is EOS-terminated).
            split_index = raw_src.size(0) - 1
        graph_item = raw_src[split_index:]
        src_item = raw_src[:split_index]
        if raw_src[0] == bos:
            # Give the graph segment the same leading BOS as the source.
            graph_item = torch.cat([torch.LongTensor([bos]), graph_item])
        if raw_src[-1] == eos:
            # The trailing EOS went to the graph segment; re-terminate the source.
            src_item = torch.cat([src_item, torch.LongTensor([eos])])

    # Append EOS to the end of the target sentence if it does not have one, and
    # remove EOS from the end of the source sentence if it exists. This is
    # useful when we reuse existing datasets for the opposite direction, i.e.
    # when we want to use tgt_dataset as src_dataset and vice versa.
    if self.append_eos_to_target:
        eos = self.tgt_dict.eos() if self.tgt_dict else self.src_dict.eos()
        if tgt_item is not None and raw_tgt[-1] != eos:
            tgt_item = torch.cat([tgt_item, torch.LongTensor([eos])])

    if self.append_bos:
        bos = self.tgt_dict.bos() if self.tgt_dict else self.src_dict.bos()
        if tgt_item is not None and raw_tgt[0] != bos:
            # Prepend to tgt_item (not raw_tgt) so an EOS appended above is kept.
            tgt_item = torch.cat([torch.LongTensor([bos]), tgt_item])

        bos = self.src_dict.bos()
        # Check the FIRST token: a BOS only ever appears at the start.
        if raw_src[0] != bos:
            src_item = torch.cat([torch.LongTensor([bos]), src_item])
            if self.enable_graph_encoder:
                graph_item = torch.cat([torch.LongTensor([bos]), graph_item])

    if self.remove_eos_from_source:
        eos = self.src_dict.eos()
        if raw_src[-1] == eos:
            # Trim the working item so a BOS prepended above is kept.
            src_item = src_item[:-1]
            if self.enable_graph_encoder:
                # In graph mode the raw trailing EOS also ends the graph segment.
                graph_item = graph_item[:-1]

    example = {
        'id': index,
        'source': src_item,
        'target': tgt_item,
    }
    if self.align_dataset is not None:
        example['alignment'] = self.align_dataset[index]
    if self.extra_input_dataset is not None:
        example['extra_input'] = self.extra_input_dataset[index]
    if self.enable_graph_encoder:
        example['graph'] = graph_item
    return example