def _load_data_from_sample()

in Models/exprsynth/seq2graphmodel.py [0:0]


    def _load_data_from_sample(hyperparameters: Dict[str, Any],
                               metadata: Dict[str, Any],
                               raw_sample: Dict[str, Any],
                               result_holder: Dict[str, Any],
                               is_train: bool=True) -> bool:
        """Add variable-usage-context features for one sample to ``result_holder``.

        First delegates to the superclass loader and returns False if it rejects the
        sample. Otherwise, for every entry of ``raw_sample['VariableUsageContexts']``
        (processed in alphabetical order of ``'Name'``), builds token, type and
        type-mask arrays for up to ``hyperparameters['max_num_contexts_per_variable']``
        randomly chosen usage contexts. Each context is a window of
        ``2 * num_cx_tokens_per_side + 1`` positions: the before-context, the variable
        itself in the middle, then the after-context. Finally delegates to
        ``NAGDecoder.load_data_from_sample`` and returns its result.

        Args:
            hyperparameters: reads 'num_cx_tokens_per_side', 'cx_max_num_types' and
                'max_num_contexts_per_variable'.
            metadata: reads the 'cx_token_vocab' and 'cx_type_vocab' vocabularies.
            raw_sample: reads 'VariableUsageContexts' (entries with 'Name', 'NodeId'
                and 'TokenContexts') and 'ContextGraph' -> 'NodeTypes'. The
                'TokenContexts' lists are shuffled as copies, so this method does not
                reorder them in place.
            result_holder: output dict. It receives four lists with one entry per
                variable:
                - 'cx_sorted_variables_in_scope'
                - 'cx_var_usage_context_tokens'
                - 'cx_var_usage_context_types'
                - 'cx_var_usage_context_types_mask'
                These are in addition to whatever the superclass and NAGDecoder write.
            is_train: forwarded to the superclass and NAGDecoder loaders.

        Returns:
            False if the superclass loader rejects the sample; otherwise the value
            returned by ``NAGDecoder.load_data_from_sample``.
        """
        keep_sample = super(Seq2GraphModel, Seq2GraphModel)._load_data_from_sample(hyperparameters, metadata, raw_sample, result_holder, is_train)
        if not keep_sample:
            return False
        num_cx_tokens_per_side = hyperparameters['num_cx_tokens_per_side']
        max_num_types = hyperparameters['cx_max_num_types']
        # Window layout: [before tokens | variable | after tokens].
        full_cx_size = 2 * num_cx_tokens_per_side + 1
        var_pos = num_cx_tokens_per_side  # index of the variable inside the window

        # Variables are sorted alphabetically
        result_holder['cx_sorted_variables_in_scope'] = []
        result_holder['cx_var_usage_context_tokens'] = []
        result_holder['cx_var_usage_context_types'] = []
        result_holder['cx_var_usage_context_types_mask'] = []
        for variable_usage_contexts in sorted(raw_sample['VariableUsageContexts'], key=lambda cx: cx['Name']):
            assert len(variable_usage_contexts['TokenContexts']) > 0

            variable_token_idx = metadata['cx_token_vocab'].get_id_or_unk(variable_usage_contexts['Name'])
            variable_node_id = variable_usage_contexts['NodeId']
            variable_type = 'type:' + raw_sample['ContextGraph']['NodeTypes'][str(variable_node_id)]
            # NOTE(review): this vocab's get_id_or_unk is assumed to return a sliceable
            # sequence of type ids (it is truncated to max_num_types) — confirm.
            variable_type_idxs = metadata['cx_type_vocab'].get_id_or_unk(variable_type)[:max_num_types]

            num_var_usage_contexts = min(len(variable_usage_contexts['TokenContexts']), hyperparameters['max_num_contexts_per_variable'])
            var_usage_context_tokens = np.zeros((num_var_usage_contexts, full_cx_size), dtype=np.int32)
            var_usage_context_types = np.zeros((num_var_usage_contexts, full_cx_size, max_num_types), dtype=np.int32)
            # Builtin bool: the np.bool alias was removed in NumPy 1.24.
            var_usage_context_type_mask = np.zeros((num_var_usage_contexts, full_cx_size, max_num_types), dtype=bool)

            # Shuffle a copy so the caller's raw sample is not reordered; the first
            # num_var_usage_contexts entries form a random subset of the contexts.
            # NOTE(review): shuffling also happens when is_train is False, so
            # evaluation inputs depend on RNG state — confirm this is intended.
            token_contexts = list(variable_usage_contexts['TokenContexts'])
            random.shuffle(token_contexts)
            for context_idx, (before_context, after_context) in enumerate(token_contexts[:num_var_usage_contexts]):
                # The before-context is converted with start_from_left=False; the
                # after-context uses the helper's default.
                before_tokens, before_token_types, before_token_type_masks = \
                    _convert_and_pad_token_sequence(hyperparameters, metadata, before_context, num_cx_tokens_per_side, start_from_left=False)
                after_tokens, after_token_types, after_token_type_masks = \
                    _convert_and_pad_token_sequence(hyperparameters, metadata, after_context, num_cx_tokens_per_side)

                var_usage_context_tokens[context_idx, :var_pos] = before_tokens
                var_usage_context_types[context_idx, :var_pos] = before_token_types
                var_usage_context_type_mask[context_idx, :var_pos] = before_token_type_masks

                var_usage_context_tokens[context_idx, var_pos] = variable_token_idx
                var_usage_context_types[context_idx, var_pos, :len(variable_type_idxs)] = variable_type_idxs
                var_usage_context_type_mask[context_idx, var_pos, :len(variable_type_idxs)] = True

                var_usage_context_tokens[context_idx, var_pos + 1:] = after_tokens
                var_usage_context_types[context_idx, var_pos + 1:] = after_token_types
                var_usage_context_type_mask[context_idx, var_pos + 1:] = after_token_type_masks

            result_holder['cx_sorted_variables_in_scope'].append(variable_usage_contexts['Name'])
            result_holder['cx_var_usage_context_tokens'].append(var_usage_context_tokens)
            result_holder['cx_var_usage_context_types'].append(var_usage_context_types)
            result_holder['cx_var_usage_context_types_mask'].append(var_usage_context_type_mask)

        return NAGDecoder.load_data_from_sample(hyperparameters, metadata, raw_sample, result_holder, is_train)