in community-content/vertex_model_garden/model_oss/imagebind/handler.py [0:0]
def preprocess(self, data: Any) -> List[Dict[str, Any]]:
  """Converts raw request instances into per-modality model inputs.

  Each instance is a mapping keyed by modality. Text is transformed directly;
  vision, depth, thermal, audio and video entries are treated as lists of
  paths that are first copied under `constants.LOCAL_DATA_DIR` by
  `fileutils.download_gcs_file_list_to_local` and then loaded and
  transformed; IMU entries are turned into a float32 tensor on `self.device`.

  Args:
    data: Iterable of instances, each mapping a modality key to its input.

  Returns:
    One dictionary per instance, mapping modality (key) to the transformed
    input (value). Instances in which none of the known modality keys are
    present are left out of the result.
  """
  logging.info("Preprocessing: %d instances received.", len(data))

  def _fetch_locally(remote_paths):
    # Every path-based modality shares the same download destination.
    return fileutils.download_gcs_file_list_to_local(
        remote_paths, constants.LOCAL_DATA_DIR
    )

  samples = []
  for instance in data:
    sample = {}

    # Text needs no download step.
    if ModalityType.TEXT in instance:
      text_input = instance[ModalityType.TEXT]
      sample[ModalityType.TEXT] = data_util.load_and_transform_text(
          text_input, self.device
      )

    # Image-like modalities go through the same loader; depth and thermal
    # inputs are flagged so the loader can treat them differently from
    # vision inputs.
    for modality in (
        ModalityType.VISION,
        ModalityType.DEPTH,
        ModalityType.THERMAL,
    ):
      if modality not in instance:
        continue
      local_images = _fetch_locally(instance[modality])
      sample[modality] = self._load_and_transform_image_data(
          local_images,
          self.device,
          is_depth_or_thermal=modality
          in (ModalityType.DEPTH, ModalityType.THERMAL),
      )

    # Audio and video follow the same download-then-transform recipe and
    # differ only in the request key and the loader used. Video is read
    # from its own dedicated key rather than a ModalityType key.
    path_based_loaders = (
        (ModalityType.AUDIO, data_util.load_and_transform_audio_data),
        (
            _VIDEO_KEY_TO_AVOID_CONFLICT_WITH_IMAGE,
            data_util.load_and_transform_video_data,
        ),
    )
    for key, loader in path_based_loaders:
      if key in instance:
        local_files = _fetch_locally(instance[key])
        sample[key] = loader(local_files, self.device)

    if ModalityType.IMU in instance:
      # Input data in the IMU modality are expected in shape [B, 6, 2000].
      imu_values = instance[ModalityType.IMU]
      sample[ModalityType.IMU] = torch.tensor(
          imu_values, dtype=torch.float32, device=self.device
      )

    # Keep only instances that produced at least one modality.
    if sample:
      samples.append(sample)

  return samples