in Models/exprsynth/seq2graphmodel.py [0:0]
def _load_data_from_sample(hyperparameters: Dict[str, Any],
                           metadata: Dict[str, Any],
                           raw_sample: Dict[str, Any],
                           result_holder: Dict[str, Any],
                           is_train: bool=True) -> bool:
    """Tensorise the usage contexts of every in-scope variable of one raw sample.

    First delegates to the parent class's loader; if the parent rejects the
    sample, returns False immediately. Otherwise fills ``result_holder`` with,
    per variable (sorted alphabetically by name):
      * 'cx_sorted_variables_in_scope':   the variable names,
      * 'cx_var_usage_context_tokens':    int32 [num_contexts, full_cx_size],
      * 'cx_var_usage_context_types':     int32 [num_contexts, full_cx_size, max_num_types],
      * 'cx_var_usage_context_types_mask': bool, same shape as the types array.
    Each context row is laid out as [before tokens | variable token | after tokens].

    Args:
        hyperparameters: Model hyperparameters; reads 'num_cx_tokens_per_side',
            'cx_max_num_types' and 'max_num_contexts_per_variable'.
        metadata: Holds the vocabularies 'cx_token_vocab' and 'cx_type_vocab'.
        raw_sample: Raw sample dict; reads 'VariableUsageContexts' and
            'ContextGraph'. NOTE: the per-variable 'TokenContexts' lists are
            shuffled in place (mutates raw_sample).
        result_holder: Output dict, populated in place.
        is_train: Passed through to the parent/decoder loaders.

    Returns:
        True if the sample should be kept, False otherwise (decided by the
        parent loader and finally by NAGDecoder.load_data_from_sample).
    """
    keep_sample = super(Seq2GraphModel, Seq2GraphModel)._load_data_from_sample(
        hyperparameters, metadata, raw_sample, result_holder, is_train)
    if not keep_sample:
        return False

    num_cx_tokens_per_side = hyperparameters['num_cx_tokens_per_side']
    max_num_types = hyperparameters['cx_max_num_types']
    # Context window: tokens before + the variable occurrence itself + tokens after.
    full_cx_size = 2 * num_cx_tokens_per_side + 1

    # Variables are sorted alphabetically, giving a deterministic ordering.
    result_holder['cx_sorted_variables_in_scope'] = []
    result_holder['cx_var_usage_context_tokens'] = []
    result_holder['cx_var_usage_context_types'] = []
    result_holder['cx_var_usage_context_types_mask'] = []
    for variable_usage_contexts in sorted(raw_sample['VariableUsageContexts'], key=lambda cx: cx['Name']):
        variable_token_idx = metadata['cx_token_vocab'].get_id_or_unk(variable_usage_contexts['Name'])
        variable_node_id = variable_usage_contexts['NodeId']
        variable_type = 'type:' + raw_sample['ContextGraph']['NodeTypes'][str(variable_node_id)]
        # Presumably the type vocab can return several ids per type; cap at max_num_types.
        variable_type_idxs = metadata['cx_type_vocab'].get_id_or_unk(variable_type)[:max_num_types]

        # Subsample contexts if there are more than the configured maximum;
        # shuffling first makes the subsample random (in place, mutates raw_sample).
        num_var_usage_contexts = min(len(variable_usage_contexts['TokenContexts']),
                                     hyperparameters['max_num_contexts_per_variable'])
        var_usage_context_tokens = np.zeros((num_var_usage_contexts, full_cx_size), dtype=np.int32)
        var_usage_context_types = np.zeros((num_var_usage_contexts, full_cx_size, max_num_types), dtype=np.int32)
        # Use builtin bool: np.bool was deprecated in NumPy 1.20 and removed in 1.24.
        var_usage_context_type_mask = np.zeros((num_var_usage_contexts, full_cx_size, max_num_types), dtype=bool)
        assert len(variable_usage_contexts['TokenContexts']) > 0
        random.shuffle(variable_usage_contexts['TokenContexts'])
        for context_idx, usage_context in enumerate(variable_usage_contexts['TokenContexts'][:num_var_usage_contexts]):
            before_context, after_context = usage_context
            # Left context is padded/truncated from the left so it ends right before the variable.
            before_tokens, before_token_types, before_token_type_masks = \
                _convert_and_pad_token_sequence(hyperparameters, metadata, before_context,
                                                num_cx_tokens_per_side, start_from_left=False)
            after_tokens, after_token_types, after_token_type_masks = \
                _convert_and_pad_token_sequence(hyperparameters, metadata, after_context,
                                                num_cx_tokens_per_side)
            var_usage_context_tokens[context_idx, :num_cx_tokens_per_side] = before_tokens
            var_usage_context_types[context_idx, :num_cx_tokens_per_side] = before_token_types
            var_usage_context_type_mask[context_idx, :num_cx_tokens_per_side] = before_token_type_masks
            # Middle slot holds the variable occurrence itself.
            var_usage_context_tokens[context_idx, num_cx_tokens_per_side] = variable_token_idx
            var_usage_context_types[context_idx, num_cx_tokens_per_side, :len(variable_type_idxs)] = variable_type_idxs
            var_usage_context_type_mask[context_idx, num_cx_tokens_per_side, :len(variable_type_idxs)] = True
            var_usage_context_tokens[context_idx, num_cx_tokens_per_side + 1:] = after_tokens
            var_usage_context_types[context_idx, num_cx_tokens_per_side + 1:] = after_token_types
            var_usage_context_type_mask[context_idx, num_cx_tokens_per_side + 1:] = after_token_type_masks

        result_holder['cx_sorted_variables_in_scope'].append(variable_usage_contexts['Name'])
        result_holder['cx_var_usage_context_tokens'].append(var_usage_context_tokens)
        result_holder['cx_var_usage_context_types'].append(var_usage_context_types)
        result_holder['cx_var_usage_context_types_mask'].append(var_usage_context_type_mask)
    return NAGDecoder.load_data_from_sample(hyperparameters, metadata, raw_sample, result_holder, is_train)