def batch_to_shape()

in data/data.py [0:0]


def batch_to_shape(batch):
    """Build a ``Shape`` from a batch mapping and move its data to ``device``.

    Required keys:
        ``"vert"``: squeezed and moved to ``device``; passed as the first
            ``Shape`` argument.
        ``"triv"``: squeezed, moved to ``device`` as ``torch.long`` and
            decremented by one; passed as the second ``Shape`` argument.

    Optional keys (each is copied onto the shape only when present):
        ``"D"``         -> ``shape.D`` (squeezed, on ``device``)
        ``"sub"``       -> ``shape.sub`` (two-level nested container whose
                           entries are each moved to ``device``)
        ``"idx"``       -> ``shape.samples`` (squeezed, ``torch.long``)
        ``"vert_full"`` -> ``shape.vert_full`` (squeezed, on ``device``)

    Returns:
        The populated ``Shape`` instance.

    NOTE(review): ``squeeze()`` is called without a ``dim`` argument, so any
    size-1 dimension is dropped, not just a leading batch axis -- confirm
    that inputs never have a single vertex/face/sample.
    """
    vertices = batch["vert"].squeeze().to(device)
    # The "- 1" presumably converts 1-based face indices to 0-based -- TODO confirm.
    triangles = batch["triv"].squeeze().to(device, torch.long) - 1
    shape = Shape(vertices, triangles)

    if "D" in batch:
        shape.D = batch["D"].squeeze().to(device)

    if "sub" in batch:
        # The batch's own container is attached and its entries are replaced
        # in place, so (assuming ``sub`` is a plain attribute) ``batch["sub"]``
        # ends up holding the moved entries as well.
        shape.sub = batch["sub"]
        for level in shape.sub:
            for slot, entry in enumerate(level):
                level[slot] = entry.to(device)

    if "idx" in batch:
        shape.samples = batch["idx"].squeeze().to(device, torch.long)

    if "vert_full" in batch:
        shape.vert_full = batch["vert_full"].squeeze().to(device)

    return shape