def predict_fn()

in src/inference.py [0:0]


def predict_fn(input_data, model):
    """Run ``model`` over every batch in ``input_data`` and collect the outputs.

    Args:
        input_data: Iterable of mappings, each providing a ``'nodes'`` entry
            (converted to a float tensor and also iterated element-wise to
            build one neighbour graph per element) and a ``'neighbors'`` entry
            forwarded to ``nearest_neighbor_graph``.
            NOTE(review): the exact layout of ``'nodes'`` is not visible here —
            presumably (batch, num_nodes, coords); confirm against the caller.
        model: Callable invoked as ``model(x, graph, return_pi=True)`` that
            returns a ``(cost, ll, pi)`` triple; only ``cost`` (printed) and
            ``pi`` (returned) are used.

    Returns:
        A list with one entry per batch, each being ``pi.tolist()`` for
        that batch.
    """
    # The target device cannot change between batches, so resolve it once.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    prediction = []
    # Pure inference: disable autograd so no computation graph / activations
    # are retained for a backward pass that never happens.
    with torch.no_grad():
        for batch in tqdm.tqdm(input_data):
            nodes = batch['nodes']
            neighbors = batch['neighbors']
            x = move_to(torch.FloatTensor(nodes), device)
            # One neighbour graph per element of ``nodes``, stacked below into
            # a single uint8 tensor for the model.
            graphs = [
                nearest_neighbor_graph(
                    instance,
                    neighbors=neighbors,
                    knn_strat='percentage',
                )
                for instance in nodes
            ]
            graph = move_to(torch.ByteTensor(graphs), device)
            cost, _, pi = model(x, graph, return_pi=True)
            print(f'cost:{cost}')
            prediction.append(pi.tolist())
    return prediction