shap_e/util/data_util.py
import torch

from shap_e.util.collections import AttrDict


def normalize_input_batch(batch: AttrDict, *, pc_scale: float, color_scale: float) -> AttrDict:
    res = batch.copy()
    # Scale the xyz channels by pc_scale and the rgb channels by color_scale; points
    # are channels-first, so the [6, 1] scale vector broadcasts over the channel dim.
    scale_vec = torch.tensor([*([pc_scale] * 3), *([color_scale] * 3)], device=batch.points.device)
    res.points = res.points * scale_vec[:, None]
    # Keep per-view cameras and depth maps consistent with the rescaled coordinates.
    if "cameras" in res:
        res.cameras = [[cam.scale_scene(pc_scale) for cam in cams] for cams in res.cameras]
    if "depths" in res:
        res.depths = [[depth * pc_scale for depth in depths] for depths in res.depths]
    return res
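

# Hypothetical usage sketch, not taken from the repository: it assumes points are
# stored channels-first as [batch, 6, n_points] (xyz + rgb) so the broadcast above
# applies per channel, and the scale values below are purely illustrative.
if __name__ == "__main__":
    batch = AttrDict()
    batch.points = torch.randn(4, 6, 1024)  # assumed layout: [batch, channels, points]
    normed = normalize_input_batch(batch, pc_scale=2.0, color_scale=1.0 / 255.0)
    print(normed.points.shape)  # torch.Size([4, 6, 1024])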