in salina/agents/gyma.py [0:0]
def _format_frame(frame):
if isinstance(frame, dict):
r = {}
for k in frame:
r[k] = _format_frame(frame[k])
return r
elif isinstance(frame, list):
t = torch.tensor(frame).unsqueeze(0)
if t.dtype == torch.float64 or t.dtype == torch.float32:
t = t.float()
else:
t = t.long()
return t
elif isinstance(frame, np.ndarray):
t = torch.from_numpy(frame).unsqueeze(0)
if t.dtype == torch.float64 or t.dtype == torch.float32:
t = t.float()
else:
t = t.long()
return t
elif isinstance(frame, torch.Tensor):
return frame.unsqueeze(0) # .float()
elif isinstance(frame, bool):
return torch.tensor([frame]).bool()
elif isinstance(frame, int):
return torch.tensor([frame]).long()
elif isinstance(frame, float):
return torch.tensor([frame]).float()
else:
try:
# Check if its a LazyFrame from OpenAI Baselines
o = torch.from_numpy(frame.__array__()).unsqueeze(0).float()
return o
except:
assert False