def predict_sequence()

in Synthesis_incorporation/models/prediction_model.py [0:0]


    def predict_sequence(self, example_sequence, top_n, beam_n, threshold, is_example, settings):
        """Predict the most likely 3-step operation sequences for an example.

        The example is embedded, passed through ``self.multi_ffn_model`` and then
        ``self.multi_model``. For each output step of ``self.multi_model`` the
        ``beam_n`` most probable classes are kept, and every combination of the
        first three steps is scored by the product of its per-step probabilities.

        Args:
            example_sequence: When ``is_example`` is true, passed whole to
                ``self.embed_benchmark_example``. Otherwise a non-empty sequence of
                values, each embedded with ``self.embed_benchmark_value``. If there
                are fewer than three, zero embeddings are appended; only the first
                three embeddings are stacked.
            top_n: Maximum number of sequences to return.
            beam_n: Number of candidate classes kept per step (capped at the
                number of classes in that step's output).
            threshold: Only sequences whose combined probability is strictly
                greater than this are returned.
            is_example: Selects which embedding path is used (see above).
            settings: Object whose ``settings.printing.predicted_operations``
                flag enables debug printing of the per-step candidates.

        Returns:
            A tuple ``(operations, confidences)``. ``operations`` is a list of
            3-tuples of entries looked up in ``self.indx2api``. ``confidences`` is
            the matching list of combined probabilities. Both lists are sorted by
            descending confidence and have at most ``top_n`` entries.
        """
        verbose = settings.printing.predicted_operations

        # Inference only: the embeddings do not need an autograd graph either.
        with torch.no_grad():
            if is_example:
                domain_embedding = self.embed_benchmark_example(example_sequence)
            else:
                embeddings = [self.embed_benchmark_value(example) for example in example_sequence]
                # Pad to three steps. zeros_like keeps the dtype/device of the real
                # embeddings; torch.zeros(shape) would always be float32 on the CPU.
                for _ in range(len(embeddings), 3):
                    embeddings.append(torch.zeros_like(embeddings[0]))
                domain_embedding = torch.stack(embeddings[:3])

            _predicts, z3, _z2, _z1 = self.multi_ffn_model(domain_embedding)
            model_output, _hidden, _int_output = self.multi_model(torch.unsqueeze(z3, 0))

        topn_list = []
        topn_prob_list = []
        for i, m in enumerate(model_output):
            prob = torch.nn.functional.softmax(m, dim=0)
            # Keep the beam_n most probable classes for this step. topk raises if
            # k exceeds the number of classes, so clamp it.
            k = min(beam_n, prob.shape[0])
            topn_ops = torch.topk(prob, k, dim=0)[1]
            if verbose:
                print(i, topn_ops)
            topn = []
            topn_prob = []
            for op in topn_ops.cpu().numpy():
                if verbose:
                    print(self.indx2api[op])
                topn.append(self.indx2api[op])
                topn_prob.append(prob[op].item())
            topn_list.append(topn)
            topn_prob_list.append(topn_prob)
            if verbose:
                print('====')

        # Score every combination of the first three steps by its joint probability.
        candidates = [
            (c0 * c1 * c2, ops)
            for ops, (c0, c1, c2) in zip(
                product(topn_list[0], topn_list[1], topn_list[2]),
                product(topn_prob_list[0], topn_prob_list[1], topn_prob_list[2]),
            )
        ]
        # Stable descending sort on confidence only; ties keep their product order.
        candidates.sort(key=lambda pair: pair[0], reverse=True)

        # The list is sorted descending, so the entries above threshold form a
        # prefix; cap that prefix at top_n.
        num_gt_threshold = min(sum(conf > threshold for conf, _ in candidates), top_n)
        selected = candidates[:num_gt_threshold]
        return [ops for _, ops in selected], [conf for conf, _ in selected]