def _process_query()

in modelling/src/neuraldb/dataset/instance_generator/spj_generator.py [0:0]


    def _process_query(self, query_obj, update_tokens):
        """Yield model instances built from a single query record.

        Two mutually exclusive modes are supported. The mode in use is stored
        in ``self.only_allow_predictions`` and asserted on later calls, so
        predicted and gold facts are never mixed:

        * Predictions (``self.test_mode`` is set and ``query_obj`` has a
          ``"predicted_facts"`` key): one instance per predicted fact group.
          The group's fact ids are deduplicated with ``set()``, so their order
          is not preserved. No ``"output"`` key is produced.
        * Gold facts (all other cases): one instance per
          ``(facts[i], derivations[i])`` pair. Its output is the tokenized
          derivation passed through ``self._prepend_prediction_type_answer``.
          When ``self.augment_training`` is set and ``self.test_mode`` is not,
          each pair may also produce extra instances before the gold one:

          - with probability ``self.sample_probability_add_negatives``, the
            gold group padded with 1-3 unrelated facts and shuffled. This
            simulates false positives from IR; the output is unchanged.
          - with probability ``self.sample_probability_add_nulls``, a context
            of 1-3 unrelated facts only, whose output is built from
            ``[self.null_answer_special]``.

          "Unrelated" facts are ids in ``range(query_obj["height"])`` that
          appear in none of the query's gold fact groups. The 1-3 count is
          capped by how many such ids exist. If there are none, no augmented
          instance is produced.

        Args:
            query_obj: query record. Keys read: ``"query"``, ``"type"``,
                ``"facts"``, ``"derivations"``, ``"height"`` and, optionally,
                ``"predicted_facts"``.
            update_tokens: tokenized facts, indexed by fact id.

        Yields:
            The result of ``self.maybe_decorate_with_metadata`` for each
            instance dict. Each dict has keys ``"query"``, ``"context"`` and,
            in gold mode, ``"output"``.
        """
        query_tokens = self.tokenizer.tokenize(query_obj["query"])

        if "predicted_facts" in query_obj and self.test_mode:
            # Never mix predicted and gold facts within one generator run.
            assert (
                self.only_allow_predictions is None
                or self.only_allow_predictions is True
            )
            if not self.only_allow_predictions:
                # Only true while the mode is still unset (None).
                logger.warning("Using predicted facts")
            self.only_allow_predictions = True

            for fact_group in query_obj["predicted_facts"]:
                context_tokens = [update_tokens[fact] for fact in set(fact_group)]
                yield self.maybe_decorate_with_metadata(
                    {"query": query_tokens, "context": context_tokens},
                    query_obj,
                )

        else:
            assert (
                self.only_allow_predictions is None
                or self.only_allow_predictions is False
            )
            self.only_allow_predictions = False

            # Candidate unrelated fact ids. They are identical for every fact
            # group of this query, so they are computed at most once, and only
            # if an augmentation branch actually fires.
            negative_population = None

            def sample_negatives():
                """Return 1-3 random unrelated fact ids, or [] if none exist."""
                nonlocal negative_population
                if negative_population is None:
                    gold_facts = list(
                        itertools.chain.from_iterable(query_obj["facts"])
                    )
                    negative_population = list(
                        set(range(query_obj["height"])).difference(gold_facts)
                    )
                if not negative_population:
                    return []
                return random.sample(
                    negative_population,
                    k=min(len(negative_population), random.randint(1, 3)),
                )

            for fact_group, derivation in zip(
                query_obj["facts"], query_obj["derivations"]
            ):
                derivation_tokens = self.tokenizer.tokenize(derivation)

                # Augment with randomly sampled facts to simulate false
                # positive instances from IR; the gold answer is kept.
                if (
                    self.augment_training
                    and not self.test_mode
                    and random.uniform(0, 1) < self.sample_probability_add_negatives
                ):
                    negative = sample_negatives()
                    if negative:
                        augmented_fact_group = list(fact_group) + negative
                        random.shuffle(augmented_fact_group)
                        context_tokens = [
                            update_tokens[fact] for fact in augmented_fact_group
                        ]
                        yield self.maybe_decorate_with_metadata(
                            {
                                "query": query_tokens,
                                "context": context_tokens,
                                "output": self._prepend_prediction_type_answer(
                                    derivation_tokens, query_obj["type"]
                                ),
                            },
                            query_obj,
                        )

                # A context made only of unrelated facts, labelled with the
                # null answer.
                if (
                    self.augment_training
                    and not self.test_mode
                    and random.uniform(0, 1) < self.sample_probability_add_nulls
                ):
                    negative = sample_negatives()
                    if negative:
                        random.shuffle(negative)
                        context_tokens = [update_tokens[fact] for fact in negative]
                        yield self.maybe_decorate_with_metadata(
                            {
                                "query": query_tokens,
                                "context": context_tokens,
                                "output": self._prepend_prediction_type_answer(
                                    [self.null_answer_special], query_obj["type"]
                                ),
                            },
                            query_obj,
                        )

                # The unaugmented gold instance is always produced.
                context_tokens = [update_tokens[fact] for fact in fact_group]
                yield self.maybe_decorate_with_metadata(
                    {
                        "query": query_tokens,
                        "context": context_tokens,
                        "output": self._prepend_prediction_type_answer(
                            derivation_tokens, query_obj["type"]
                        ),
                    },
                    query_obj,
                )