in pyrela/utils.py [0:0]
def to_device(batch, device):
    """Recursively move a batch to ``device`` and detach it from autograd.

    Supported inputs:
      * ``torch.Tensor``: returns ``batch.to(device).detach()``.
      * ``dict``: returns a NEW dict with every value converted recursively
        (the input dict itself is not modified).
      * ``rela.FFTransition`` / ``rela.RNNTransition``: each known field is
        converted recursively and reassigned on the object, which is
        modified IN PLACE and then returned.

    Args:
        batch: A tensor, a dict, or one of the ``rela`` transition objects.
            Transition fields are assumed to hold tensors or dicts of
            tensors (anything else triggers the error below) -- TODO
            confirm against the ``rela`` bindings.
        device: Target device, passed straight through to ``Tensor.to``.

    Returns:
        The converted batch (a new tensor/dict, or the same transition
        object after in-place mutation).

    Raises:
        AssertionError: If ``batch`` (or any nested value) has an
            unsupported type.
    """
    if isinstance(batch, torch.Tensor):
        return batch.to(device).detach()
    if isinstance(batch, dict):
        return {key: to_device(value, device) for key, value in batch.items()}

    # Pick the field list for the transition type; fields are converted in
    # this order and written back onto the same object.
    if isinstance(batch, rela.FFTransition):
        fields = ("obs", "action", "reward", "terminal", "bootstrap", "next_obs")
    elif isinstance(batch, rela.RNNTransition):
        fields = ("obs", "h0", "action", "reward", "terminal", "bootstrap", "seq_len")
    else:
        # Explicit raise instead of `assert False`: a bare assert is stripped
        # under `python -O`, which would make this silently return None.
        # AssertionError is kept so existing callers see the same type.
        raise AssertionError("unsupported type: %s" % type(batch))

    for name in fields:
        setattr(batch, name, to_device(getattr(batch, name), device))
    return batch