in parlai/core/torch_ranker_agent.py [0:0]
def _build_candidates(self, batch, source, mode):
    """
    Build a candidate set for this batch.

    :param batch:
        a Batch object (defined in torch_agent.py)
    :param source:
        the source from which candidates should be built, one of
        ['batch', 'batch-all-cands', 'inline', 'fixed', 'vocab']
    :param mode:
        'train' or 'eval'

    :return: tuple of (cands, cand_vecs, label_inds)

        cands: A [num_cands] list of (text) candidates
            OR a [batchsize] list of such lists if source=='inline'
        cand_vecs: A padded [num_cands, seqlen] LongTensor of vectorized candidates
            OR a [batchsize, num_cands, seqlen] LongTensor if source=='inline'
        label_inds: A [bsz] LongTensor of the indices of the labels for each
            example from its respective candidate set, or None when the label
            cannot be located in the candidate set (e.g. source=='vocab', or
            bad candidates are being ignored at eval time)

    Possible sources of candidates:

    * batch: the set of all labels in this batch
        Use all labels in the batch as the candidate set (with all but the
        example's label being treated as negatives).
        Note: with this setting, the candidate set is identical for all
        examples in a batch. This option may be undesirable if it is possible
        for duplicate labels to occur in a batch, since the second instance of
        the correct label will be treated as a negative.
    * batch-all-cands: the set of all candidates in this batch
        Use all candidates in the batch as candidate set.
        Note 1: This can result in a very large number of candidates.
        Note 2: In this case we will deduplicate candidates.
        Note 3: just like with 'batch' the candidate set is identical
        for all examples in a batch.
    * inline: batch_size lists, one list per example
        If each example comes with a list of possible candidates, use those.
        Note: With this setting, each example will have its own candidate set.
    * fixed: one global candidate list, provided in a file from the user
        If self.fixed_candidates is not None, use a set of fixed candidates for
        all examples.
        Note: this setting is not recommended for training unless the
        universe of possible candidates is very small.
    * vocab: one global candidate list, extracted from the vocabulary with the
        exception of self.NULL_IDX.
    """
    # 2-D LongTensor [bsz, seqlen] of vectorized labels (asserted below),
    # or None if this batch carries no labels (e.g. at inference time)
    label_vecs = batch.label_vec
    label_inds = None
    # prefer text for sizing the batch; fall back to images for image-only
    # tasks where text_vec is absent
    batchsize = (
        batch.text_vec.size(0)
        if batch.text_vec is not None
        else batch.image.size(0)
    )
    if label_vecs is not None:
        assert label_vecs.dim() == 2
    if source == 'batch':
        warn_once(
            '[ Executing {} mode with batch labels as set of candidates. ]'
            ''.format(mode)
        )
        if batchsize == 1:
            # a single-example batch means the model trains with zero
            # negatives for this step, which is usually unintended
            warn_once(
                "[ Warning: using candidate source 'batch' and observed a "
                "batch of size 1. This may be due to uneven batch sizes at "
                "the end of an epoch. ]"
            )
        if label_vecs is None:
            raise ValueError(
                "If using candidate source 'batch', then batch.label_vec cannot be "
                "None."
            )
        # the candidate set IS the label set: example i's label sits at
        # index i, so label_inds is simply 0..bsz-1
        cands = batch.labels
        cand_vecs = label_vecs
        label_inds = label_vecs.new_tensor(range(batchsize))
    elif source == 'batch-all-cands':
        warn_once(
            '[ Executing {} mode with all candidates provided in the batch ]'
            ''.format(mode)
        )
        if batch.candidate_vecs is None:
            raise ValueError(
                "If using candidate source 'batch-all-cands', then batch."
                "candidate_vecs cannot be None. If your task does not have "
                "inline candidates, consider using one of "
                "--{m}={{'batch','fixed','vocab'}}."
                "".format(m='candidates' if mode == 'train' else 'eval-candidates')
            )
        # initialize the list of cands with the labels
        cands = []
        all_cands_vecs = []
        # dictionary used for deduplication
        cands_to_id = {}
        # flatten every example's candidate list into one deduplicated pool,
        # keeping text and vectors aligned by insertion order
        for i, cands_for_sample in enumerate(batch.candidates):
            for j, cand in enumerate(cands_for_sample):
                if cand not in cands_to_id:
                    cands.append(cand)
                    cands_to_id[cand] = len(cands_to_id)
                    all_cands_vecs.append(batch.candidate_vecs[i][j])
        cand_vecs, _ = self._pad_tensor(all_cands_vecs)
        cand_vecs = cand_vecs.to(batch.label_vec.device)
        # look each label up by text; assumes every label appears among the
        # batch candidates (a missing label would raise KeyError here)
        label_inds = label_vecs.new_tensor(
            [cands_to_id[label] for label in batch.labels]
        )
    elif source == 'inline':
        warn_once(
            '[ Executing {} mode with provided inline set of candidates ]'
            ''.format(mode)
        )
        if batch.candidate_vecs is None:
            raise ValueError(
                "If using candidate source 'inline', then batch.candidate_vecs "
                "cannot be None. If your task does not have inline candidates, "
                "consider using one of --{m}={{'batch','fixed','vocab'}}."
                "".format(m='candidates' if mode == 'train' else 'eval-candidates')
            )
        # per-example candidate sets: cand_vecs is 3-D [bsz, num_cands, seqlen]
        cands = batch.candidates
        cand_vecs = padded_3d(
            batch.candidate_vecs, self.NULL_IDX, fp16friendly=self.fp16
        )
        if self.use_cuda:
            cand_vecs = cand_vecs.to(
                0 if self.opt['gpu'] == -1 else self.opt['gpu']
            )
        if label_vecs is not None:
            label_inds = label_vecs.new_empty((batchsize))
            bad_batch = False
            for i, label_vec in enumerate(label_vecs):
                # pad (or truncate) the label to this example's candidate
                # seqlen so an exact row match against cand_vecs[i] is possible
                label_vec_pad = label_vec.new_zeros(cand_vecs[i].size(1)).fill_(
                    self.NULL_IDX
                )
                if cand_vecs[i].size(1) < len(label_vec):
                    label_vec = label_vec[0 : cand_vecs[i].size(1)]
                label_vec_pad[0 : label_vec.size(0)] = label_vec
                # _find_match returns -1 when the label row is not among the
                # candidates (presumably; verify against its definition)
                label_inds[i] = self._find_match(cand_vecs[i], label_vec_pad)
                if label_inds[i] == -1:
                    bad_batch = True
            if bad_batch:
                if self.ignore_bad_candidates and not self.is_training:
                    # eval-time escape hatch: drop label supervision rather
                    # than crash on a candidate set missing its label
                    label_inds = None
                else:
                    raise RuntimeError(
                        'At least one of your examples has a set of label candidates '
                        'that does not contain the label. To ignore this error '
                        'set `--ignore-bad-candidates True`.'
                    )
    elif source == 'fixed':
        if self.fixed_candidates is None:
            raise ValueError(
                "If using candidate source 'fixed', then you must provide the path "
                "to a file of candidates with the flag --fixed-candidates-path or "
                "the name of a task with --fixed-candidates-task."
            )
        warn_once(
            "[ Executing {} mode with a common set of fixed candidates "
            "(n = {}). ]".format(mode, len(self.fixed_candidates))
        )
        # one global candidate list shared by every example in the batch
        cands = self.fixed_candidates
        cand_vecs = self.fixed_candidate_vecs
        if label_vecs is not None:
            label_inds = label_vecs.new_empty((batchsize))
            bad_batch = False
            for batch_idx, label_vec in enumerate(label_vecs):
                # same pad-to-candidate-length matching as the 'inline' case,
                # but against the single shared cand_vecs matrix
                max_c_len = cand_vecs.size(1)
                label_vec_pad = label_vec.new_zeros(max_c_len).fill_(self.NULL_IDX)
                if max_c_len < len(label_vec):
                    label_vec = label_vec[0:max_c_len]
                label_vec_pad[0 : label_vec.size(0)] = label_vec
                label_inds[batch_idx] = self._find_match(cand_vecs, label_vec_pad)
                if label_inds[batch_idx] == -1:
                    bad_batch = True
            if bad_batch:
                if self.ignore_bad_candidates and not self.is_training:
                    label_inds = None
                else:
                    raise RuntimeError(
                        'At least one of your examples has a set of label candidates '
                        'that does not contain the label. To ignore this error '
                        'set `--ignore-bad-candidates True`.'
                    )
    elif source == 'vocab':
        warn_once(
            '[ Executing {} mode with tokens from vocabulary as candidates. ]'
            ''.format(mode)
        )
        cands = self.vocab_candidates
        cand_vecs = self.vocab_candidate_vecs
        # NOTE: label_inds is None here, as we will not find the label in
        # the set of vocab candidates
    else:
        raise Exception("Unrecognized source: %s" % source)
    return (cands, cand_vecs, label_inds)