def normalize_input_batch()

in shap_e/util/data_util.py [0:0]


def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict:
    """
    Return a copy of ``batch`` with its points, cameras and depths rescaled.

    :param batch: a batch holding a ``points`` tensor and, optionally,
        ``cameras`` (nested lists of objects exposing ``scale_scene``) and
        ``depths`` (nested lists of values supporting ``* float``).
    :param pc_scale: factor applied to the first three point channels, passed
        to every camera's ``scale_scene`` and multiplied into every depth.
    :param color_scale: factor applied to the last three point channels.
    :return: the result of ``batch.copy()`` with ``points`` (and ``cameras`` /
        ``depths`` when present) rebound to scaled values. Scaling is done
        out-of-place, so the input batch's fields are not reassigned.
    """
    res = batch.copy()
    points = res.points

    # Build the per-channel factors in the points' own floating dtype.
    # A default (float32) tensor would silently upcast half/bfloat16 points
    # to float32, and would round the scales to float32 precision for float64
    # points. Non-floating points keep the previous (inferred) dtype.
    dtype = points.dtype if points.is_floating_point() else None
    scale_vec = torch.tensor(
        [*([pc_scale] * 3), *([color_scale] * 3)], dtype=dtype, device=points.device
    )
    # scale_vec[:, None] has shape (6, 1) and broadcasts over the trailing two
    # dims, so the 6 channels must sit on dim -2 (presumably xyz followed by
    # rgb -- confirm with callers).
    res.points = points * scale_vec[:, None]

    if "cameras" in res:
        # Geometry scale must stay consistent with the rescaled coordinates.
        res.cameras = [[cam.scale_scene(pc_scale) for cam in cams] for cams in res.cameras]

    if "depths" in res:
        # Depths are distances in scene units, so they follow pc_scale as well.
        res.depths = [[depth * pc_scale for depth in depths] for depths in res.depths]

    return res