def data_wrap()

in tbsm_pytorch.py [0:0]


def data_wrap(X, lS_o, lS_i, use_gpu, device):
    """Optionally transfer a batch's inputs to ``device``.

    Args:
        X: flat iterable of objects exposing ``.to(device)``
            (presumably torch tensors -- confirm against callers).
        lS_o: iterable of iterables of such objects; each inner element is
            moved individually.
        lS_i: same nesting as ``lS_o``; each inner element is moved
            individually.
        use_gpu: when falsy, nothing is moved.
        device: target passed verbatim to each element's ``.to``.

    Returns:
        When ``use_gpu`` is truthy, a tuple of freshly built lists
        ``(moved_X, moved_lS_o, moved_lS_i)`` with the same nesting as the
        inputs. Otherwise the original ``(X, lS_o, lS_i)`` objects unchanged.
    """
    if not use_gpu:
        # Nothing to transfer: hand back the caller's own containers.
        return X, lS_o, lS_i

    def _move(item):
        return item.to(device)

    # Order matters only for side effects of .to(); keep X, then lS_o, then lS_i.
    dense = list(map(_move, X))
    offsets = [list(map(_move, group)) for group in lS_o]
    indices = [list(map(_move, group)) for group in lS_i]
    return dense, offsets, indices