in rlpytorch/utils/utils.py [0:0]
def feed(self, batch_states, curr_batch, forwarded):
    """Accumulate per-key prediction-error statistics for one step.

    Args:
        batch_states: sequence of per-sample state dicts; only element 0 is
            inspected, and only its visualization keys (``is_viskey``).
        curr_batch: dict for the current batch; scalar (int/float/str) values
            are kept as-is, anything else is unwrapped via ``v[0]``.
        forwarded: dict of model outputs; values expose ``.data[0]``
            (presumably torch Variables/tensors — TODO confirm).

    If both ``curr_batch`` and ``forwarded`` are ``None``, only prints the
    visualization keys of ``batch_states[0]`` and returns.
    """
    # Dump all entries with _, and with fa_
    if curr_batch is None and forwarded is None:
        # Inspection-only mode: show the first sample's vis-keys and bail out.
        state_info = { k : v for k, v in batch_states[0].items() if is_viskey(k) }
        print_dict("[batch states]: ", state_info, tight=True)
        return
    # Keep scalars directly; for anything else take the first element.
    batch_info = { k : v if isinstance(v, (int, float, str)) else v[0] for k, v in curr_batch.items() if is_viskey(k) }
    fd_info = { k : v.data[0] for k, v in forwarded.items() if is_viskey(k) }
    # t0: timestep of this batch, read from the "_seq" entry.
    t0 = batch_info["_seq"]
    additional_info = { }
    # Per-key list of the forwarded deltas actually used, indexed by delay.
    used_fd_info = defaultdict(lambda : [0] * self.max_delay)
    # NOTE(review): this loop also iterates "_seq" itself and any str-valued
    # vis-keys; (p - v) ** 2 below would raise for non-numeric v — presumably
    # such keys never have pending predictions. Verify against is_viskey.
    for k, v in batch_info.items():
        # self.prediction[k] maps target timestep -> {"pred": [...], "baseline": [...]}
        # (indexing pred[t0 + t] without a guard suggests a defaultdict — TODO confirm).
        pred = self.prediction[k]
        # If there is prediction of the current value, also show them.
        if t0 in pred:
            cp = pred[t0]
            # Also compute th error.
            # sum_sqr_err layout is interleaved pairs: slot 2*delay holds the
            # accumulated squared error, slot 2*delay + 1 the sample count.
            for delay, p in enumerate(cp["pred"]):
                self.sum_sqr_err[k][2*delay] += (p - v) ** 2
                self.sum_sqr_err[k][2*delay + 1] += 1
            # Same accumulation for the baseline predictor.
            for delay, p in enumerate(cp["baseline"]):
                self.sum_sqr_err_bl[k][2*delay] += (p - v) ** 2
                self.sum_sqr_err_bl[k][2*delay + 1] += 1
            # Human-readable "[delay] value" summaries; delay 0 is skipped.
            additional_info[k + "_pred"] = ", ".join(["[%d] %.2f" % (delay, p) for delay, p in enumerate(cp["pred"]) if delay != 0])
            additional_info[k + "_bl"] = ", ".join(["[%d] %.2f" % (delay, p) for delay, p in enumerate(cp["baseline"]) if delay != 0])
            # Consumed: drop the pending prediction for this timestep.
            del pred[t0]
        # Register new predictions for future timesteps t0+1 .. t0+max_delay-1.
        for t in range(1, self.max_delay):
            # Forwarded outputs are named "<key>_T<delay>".
            k_f = k + "_T" + str(t)
            if not (k_f in fd_info): continue
            predictions = pred[t0 + t]
            # Model output is a delta relative to the current value v;
            # the baseline simply predicts "no change".
            predictions["pred"][t] = fd_info[k_f] + v
            predictions["baseline"][t] = v
            used_fd_info[k][t] = fd_info[k_f]
    # Merge the "<key>_pred"/"<key>_bl" summary strings for display.
    batch_info.update(additional_info)
    # Rebind used_fd_info to display strings ("[delay] value", delay 0 skipped).
    used_fd_info = { k : ", ".join(["[%d] %.2f" % (i, vv) for i, vv in enumerate(v) if i != 0]) for k, v in used_fd_info.items() }