in parlai/agents/starspace/starspace.py [0:0]
def vectorize(self, observations):
    """
    Convert a list of observations into input & target tensors.

    Returns a tuple ``(xs, ys, cands, cands_txt)``:

    - ``xs``: LongTensor of padded input token ids, sorted by descending
      length (so downstream code can use ``pack_padded_sequence``), or
      ``None`` if no observation in the batch is valid.
    - ``ys``: LongTensor of padded label token ids, or ``None`` when no
      observation carries a 'labels' field (eval mode).
    - ``cands`` / ``cands_txt``: per-observation candidate tensors and the
      corresponding candidate strings; populated only in eval mode
      (``ys is None``), with ``None`` entries for observations that have
      no 'label_candidates'.
    """

    def valid(obs):
        # Only process observations that were successfully tokenized.
        return 'text2vec' in obs and len(obs['text2vec']) > 0

    try:
        # valid examples and their original batch indices
        valid_inds, exs = zip(
            *[(i, ex) for i, ex in enumerate(observations) if valid(ex)]
        )
    except ValueError:
        # zero examples to process in this batch, so zip failed to unpack
        return None, None, None, None

    # `x` text is already tokenized and truncated; sort by descending
    # length so we can use pack_padded downstream.
    parsed_x = [ex['text2vec'] for ex in exs]
    x_lens = [len(x) for x in parsed_x]
    ind_sorted = sorted(range(len(x_lens)), key=lambda k: -x_lens[k])

    exs = [exs[k] for k in ind_sorted]
    valid_inds = [valid_inds[k] for k in ind_sorted]
    parsed_x = [parsed_x[k] for k in ind_sorted]

    labels_avail = any('labels' in ex for ex in exs)

    # Pad every input row out to the longest one; after the descending
    # sort the first row is the longest, so no recomputation is needed.
    max_x_len = len(parsed_x[0])
    for x in parsed_x:
        x += [self.NULL_IDX] * (max_x_len - len(x))
    xs = torch.LongTensor(parsed_x)

    # set up the target tensors
    ys = None
    labels = None
    if labels_avail:
        # randomly select one of the labels to update on, if multiple
        labels = [random.choice(ex.get('labels', [''])) for ex in exs]
        # parse each label; extendleft(reversed(...)) preserves token
        # order while deque(maxlen=truncate) drops tokens from the right,
        # i.e. over-long labels keep only their first `truncate` tokens
        parsed_y = [deque(maxlen=self.truncate) for _ in labels]
        for dq, y in zip(parsed_y, labels):
            dq.extendleft(reversed(self.parse(y)))
        max_y_len = max(len(y) for y in parsed_y)
        for y in parsed_y:
            # safe on a bounded deque: max_y_len is itself <= maxlen
            y += [self.NULL_IDX] * (max_y_len - len(y))
        # materialize the deques as lists before tensor construction
        ys = torch.LongTensor([list(y) for y in parsed_y])

    cands = []
    cands_txt = []
    if ys is None:
        # only build candidates in eval mode (no labels in the batch).
        # NOTE(review): candidates follow the *original* observation
        # order while xs is length-sorted — callers appear to rely on
        # this pairing via valid_inds; preserved as-is.
        for o in observations:
            if o.get('label_candidates'):
                cs = []
                ct = []
                for c in o['label_candidates']:
                    cs.append(torch.LongTensor(self.parse(c)).unsqueeze(0))
                    ct.append(c)
                cands.append(cs)
                cands_txt.append(ct)
            else:
                cands.append(None)
                cands_txt.append(None)
    return xs, ys, cands, cands_txt