lerobot/common/transport/utils.py (86 lines of code) (raw):

#!/usr/bin/env python # Copyright 2025 The HuggingFace Inc. team. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import io import logging import pickle # nosec B403: Safe usage for internal serialization only from multiprocessing import Event, Queue from typing import Any import torch from lerobot.common.transport import services_pb2 from lerobot.common.utils.transition import Transition CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB def bytes_buffer_size(buffer: io.BytesIO) -> int: buffer.seek(0, io.SEEK_END) result = buffer.tell() buffer.seek(0) return result def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True): buffer = io.BytesIO(buffer) size_in_bytes = bytes_buffer_size(buffer) sent_bytes = 0 logging_method = logging.info if not silent else logging.debug logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") while sent_bytes < size_in_bytes: transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE if sent_bytes + CHUNK_SIZE >= size_in_bytes: transfer_state = services_pb2.TransferState.TRANSFER_END elif sent_bytes == 0: transfer_state = services_pb2.TransferState.TRANSFER_BEGIN size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes) chunk = buffer.read(size_to_read) yield message_class(transfer_state=transfer_state, data=chunk) sent_bytes += size_to_read logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}") logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore bytes_buffer = io.BytesIO() step = 0 logging.info(f"{log_prefix} Starting receiver") for item in iterator: logging.debug(f"{log_prefix} Received item") if shutdown_event.is_set(): logging.info(f"{log_prefix} Shutting down receiver") return if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN: bytes_buffer.seek(0) bytes_buffer.truncate(0) bytes_buffer.write(item.data) logging.debug(f"{log_prefix} Received data at step 0") step = 0 elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE: bytes_buffer.write(item.data) step += 1 logging.debug(f"{log_prefix} Received data at step {step}") elif item.transfer_state == services_pb2.TransferState.TRANSFER_END: bytes_buffer.write(item.data) logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") queue.put(bytes_buffer.getvalue()) bytes_buffer.seek(0) bytes_buffer.truncate(0) step = 0 logging.debug(f"{log_prefix} Queue updated") else: logging.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}") raise ValueError(f"Received unknown transfer state {item.transfer_state}") def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: """Convert model state dict to flat array for transmission""" buffer = io.BytesIO() torch.save(state_dict, buffer) return buffer.getvalue() def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: buffer = io.BytesIO(buffer) buffer.seek(0) return torch.load(buffer, weights_only=True) def python_object_to_bytes(python_object: Any) -> bytes: return pickle.dumps(python_object) def bytes_to_python_object(buffer: bytes) -> Any: buffer = io.BytesIO(buffer) buffer.seek(0) obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load # Add validation checks here return obj def bytes_to_transitions(buffer: bytes) -> list[Transition]: buffer = io.BytesIO(buffer) buffer.seek(0) transitions = torch.load(buffer, weights_only=True) return transitions def transitions_to_bytes(transitions: list[Transition]) -> bytes: buffer = io.BytesIO() torch.save(transitions, buffer) return buffer.getvalue()