in src/inference.py [0:0]
def predict_fn(input_data, model):
    """Run ``model`` on every batch in ``input_data`` and collect its ``pi`` outputs.

    NOTE(review): the name and ``(input_data, model)`` signature look like the
    SageMaker PyTorch serving ``predict_fn`` hook — confirm against the caller.

    Args:
        input_data: Iterable of batches. Each batch is indexed with the keys
            ``'nodes'`` and ``'neighbors'``. ``nodes`` must be convertible to a
            ``torch.FloatTensor`` and is iterated element by element (one
            ``nearest_neighbor_graph`` call per element). ``neighbors`` is
            forwarded unchanged to ``nearest_neighbor_graph`` together with
            ``knn_strat='percentage'`` (presumably a fraction of the node
            count — verify against that helper).
        model: Callable invoked as ``model(x, graph, return_pi=True)`` that
            returns a 3-tuple whose first item is printed and whose third item
            (``pi``) supports ``.tolist()``.

    Returns:
        A list with one entry per batch: ``pi.tolist()`` for that batch.
    """
    # The target device does not depend on the batch; resolve it once.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    prediction = []
    # Inference only: disable autograd so no gradient graph is recorded
    # for each forward pass.
    with torch.no_grad():
        for bat in tqdm.tqdm(input_data):
            nodes = bat['nodes']
            neighbors = bat['neighbors']
            x = move_to(torch.FloatTensor(nodes), device)
            # One nearest-neighbour graph per element of ``nodes``.
            g = [
                nearest_neighbor_graph(n, neighbors=neighbors, knn_strat='percentage')
                for n in nodes
            ]
            graph = move_to(torch.ByteTensor(g), device)
            # Only the first (printed) and third (returned) outputs are used.
            cost, _, pi = model(x, graph, return_pi=True)
            print(f'cost:{cost}')
            prediction.append(pi.tolist())
    return prediction