def move_transition_to_device()

in lerobot/common/utils/transition.py [0:0]


def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition:
    """Relocate the tensors stored in ``transition`` onto ``device``, in place.

    Processing order and rules:

    * ``"state"`` and ``"next_state"`` are replaced by new dicts whose values are
      the originals moved with ``.to(...)``.
    * ``"action"`` is always moved with ``.to(...)``, so it must support that call.
    * ``"reward"``, ``"done"`` and ``"truncated"`` are moved only when they are
      ``torch.Tensor`` instances; any other value is left as-is.
    * If ``"complementary_info"`` is present and not ``None``, that same dict is
      updated in place. Tensors are moved. ``int``/``float``/``bool`` values are
      wrapped with ``torch.tensor(..., device=device)``.

    Args:
        transition: Mapping holding the transition fields listed above.
        device: Destination device. It is passed to ``torch.device``, so either a
            device string or a ``torch.device`` works.

    Returns:
        The same ``transition`` object, after its fields have been updated.

    Raises:
        ValueError: If a ``complementary_info`` value is not a tensor or an
            ``int``/``float``/``bool``. Entries processed before the offending
            one (and all other fields) have already been moved at that point.
    """
    target = torch.device(device)
    # Only request asynchronous copies when the destination is a CUDA device.
    async_copy = target.type == "cuda"

    def _relocate(tensor):
        # Single place where the device/non_blocking policy is applied.
        return tensor.to(target, non_blocking=async_copy)

    transition["state"] = {name: _relocate(t) for name, t in transition["state"].items()}

    transition["action"] = _relocate(transition["action"])

    # These fields may be plain Python scalars; only tensors get moved.
    for slot in ("reward", "done", "truncated"):
        current = transition[slot]
        if isinstance(current, torch.Tensor):
            transition[slot] = _relocate(current)

    transition["next_state"] = {name: _relocate(t) for name, t in transition["next_state"].items()}

    extras = transition.get("complementary_info")
    if extras is None:
        return transition

    # Mutate the existing dict (no new keys are added, so iterating is safe).
    for name, value in extras.items():
        if isinstance(value, torch.Tensor):
            extras[name] = _relocate(value)
        elif isinstance(value, (int, float, bool)):
            extras[name] = torch.tensor(value, device=target)
        else:
            raise ValueError(f"Unsupported type {type(value)} for complementary_info[{name}]")
    return transition