in data/data.py [0:0]
def batch_to_shape(batch):
    """Build a ``Shape`` from a batch dictionary, moving its tensors to ``device``.

    ``batch`` must contain ``"vert"`` and ``"triv"``. The optional keys
    ``"D"``, ``"sub"``, ``"idx"`` and ``"vert_full"`` are attached to the
    resulting shape as ``D``, ``sub``, ``samples`` and ``vert_full`` when
    present.

    Every tensor is ``squeeze()``-d. This presumably strips the leading batch
    dimension added by a DataLoader with batch size 1 (TODO confirm). Note
    that ``squeeze()`` removes *all* singleton dimensions, not only the
    first one.

    Args:
        batch: mapping of names to tensors. ``batch["sub"]``, if present, is
            a nested sequence of tensors indexed as ``sub[i][j]``; it is
            not modified by this function.

    Returns:
        A ``Shape`` whose tensors have been moved to the module-level
        ``device``.
    """
    vert = batch["vert"].squeeze().to(device)
    # Triangle indices are shifted down by one. Presumably they are stored
    # 1-based (e.g. exported from MATLAB) and Shape expects 0-based; TODO confirm.
    triv = batch["triv"].squeeze().to(device, torch.long) - 1
    shape = Shape(vert, triv)

    if "D" in batch:
        shape.D = batch["D"].squeeze().to(device)
    if "sub" in batch:
        # Build fresh nested lists rather than overwriting batch["sub"] in
        # place. This leaves the caller's batch untouched, and immutable
        # containers (tuples) no longer raise TypeError on item assignment.
        shape.sub = [[part.to(device) for part in level] for level in batch["sub"]]
    if "idx" in batch:
        shape.samples = batch["idx"].squeeze().to(device, torch.long)
    if "vert_full" in batch:
        shape.vert_full = batch["vert_full"].squeeze().to(device)
    return shape