def to_device()

in pyrela/utils.py [0:0]


def to_device(batch, device):
    """Recursively move ``batch`` onto ``device``.

    Supported inputs:

    * ``torch.Tensor`` -- returns ``batch.to(device).detach()``.
    * ``dict`` -- returns a new dict whose values are converted recursively
      (keys are kept as-is).
    * ``rela.FFTransition`` / ``rela.RNNTransition`` -- each listed field is
      converted and reassigned on the *same* object, which is then returned,
      i.e. the transition is mutated in place.

    Args:
        batch: A tensor, a dict of supported values, or a rela transition.
        device: Target passed straight through to ``torch.Tensor.to``.

    Returns:
        The converted batch: a new object for tensors and dicts, the same
        (mutated) object for transitions.

    Raises:
        AssertionError: If ``batch`` (or any nested value) has an
            unsupported type.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device).detach()
    if isinstance(batch, dict):
        return {key: to_device(value, device) for key, value in batch.items()}

    # Transitions are converted field by field, in this order, in place.
    if isinstance(batch, rela.FFTransition):
        fields = (
            "obs", "action", "reward", "terminal", "bootstrap", "next_obs",
        )
    elif isinstance(batch, rela.RNNTransition):
        fields = (
            "obs", "h0", "action", "reward", "terminal", "bootstrap", "seq_len",
        )
    else:
        # Raise explicitly instead of ``assert False`` so the check is not
        # stripped under ``python -O`` (which would silently return None).
        # AssertionError is kept so existing callers see the same type.
        raise AssertionError("unsupported type: %s" % type(batch))

    for name in fields:
        setattr(batch, name, to_device(getattr(batch, name), device))
    return batch