def feed()

in rlpytorch/utils/utils.py [0:0]


    def feed(self, batch_states, curr_batch, forwarded):
        """Dump selected entries and track multi-step predictions against later values.

        Only keys accepted by ``is_viskey`` are considered anywhere below.

        When both ``curr_batch`` and ``forwarded`` are None, the selected
        entries of ``batch_states[0]`` are printed with ``print_dict`` under
        the label ``"[batch states]: "`` and the method returns immediately.

        Otherwise, with ``t0 = batch_info["_seq"]`` as the current step, for
        each selected key ``k`` of ``curr_batch`` (current value ``v``):

          * If a prediction for step ``t0`` was stored on an earlier call,
            the squared error of every stored entry against ``v`` is added to
            ``self.sum_sqr_err[k]`` (model predictions) and
            ``self.sum_sqr_err_bl[k]`` (baselines). Slot ``2*d`` holds the
            running sum of squared errors for delay ``d`` and slot ``2*d + 1``
            holds its sample count. The stored predictions and baselines
            (excluding delay 0) are formatted into ``k + "_pred"`` and
            ``k + "_bl"`` entries, and the stored entry for ``t0`` is deleted.
          * For each delay ``t`` in ``1 .. self.max_delay - 1`` where the
            forward output has a key ``k + "_T" + str(t)``, a prediction for
            step ``t0 + t`` is stored as ``forwarded value + v``, with a
            baseline of ``v``.

        Args:
            batch_states: indexable whose element 0 is dict-like; read only on
                the dump-only path.
            curr_batch: dict-like batch. ``int``/``float``/``str`` values are
                kept as-is and any other value is reduced to ``v[0]``. Must
                contain a selected key ``"_seq"`` (indexed directly, so a
                missing key raises ``KeyError``).
            forwarded: dict-like of forward outputs. ``.data[0]`` of each
                value is used (presumably torch tensors/Variables; confirm
                against the caller).
        """
        # Dump all entries with _, and with fa_
        if curr_batch is None and forwarded is None:
            state_info = { k : v for k, v in batch_states[0].items() if is_viskey(k) }
            print_dict("[batch states]: ", state_info, tight=True)
            return

        # Scalars/strings are kept as-is; anything else is reduced to its first element.
        batch_info = { k : v if isinstance(v, (int, float, str)) else v[0] for k, v in curr_batch.items() if is_viskey(k) }
        fd_info = { k : v.data[0] for k, v in forwarded.items() if is_viskey(k) }

        # Current step index; predictions below are keyed by absolute step.
        t0 = batch_info["_seq"]
        additional_info = { }
        # Per key: the raw forwarded value for each delay t (index 0 is never written here).
        used_fd_info = defaultdict(lambda : [0] * self.max_delay)

        for k, v in batch_info.items():
            # NOTE(review): self.prediction[k] and pred[t0 + t] below are indexed without
            # existence checks, so self.prediction presumably auto-creates entries
            # (e.g. a nested defaultdict) -- confirm in __init__.
            pred = self.prediction[k]
            # If a prediction for the current step was stored earlier, also show it.
            if t0 in pred:
                cp = pred[t0]
                # Also compute the error.
                # Layout: sum_sqr_err[k][2*delay] = sum of squared errors,
                # sum_sqr_err[k][2*delay + 1] = number of samples for that delay.
                # NOTE(review): index 0 of cp["pred"] / cp["baseline"] is never written by
                # the storing loop below (it starts at t=1), yet it is still included in
                # these sums -- confirm that is intended.
                for delay, p in enumerate(cp["pred"]):
                    self.sum_sqr_err[k][2*delay] += (p - v) ** 2
                    self.sum_sqr_err[k][2*delay + 1] += 1

                for delay, p in enumerate(cp["baseline"]):
                    self.sum_sqr_err_bl[k][2*delay] += (p - v) ** 2
                    self.sum_sqr_err_bl[k][2*delay + 1] += 1

                # Human-readable "[delay] value" lists, skipping the unused delay 0.
                additional_info[k + "_pred"] = ", ".join(["[%d] %.2f" % (delay, p) for delay, p in enumerate(cp["pred"]) if delay != 0])
                additional_info[k + "_bl"] = ", ".join(["[%d] %.2f" % (delay, p) for delay, p in enumerate(cp["baseline"]) if delay != 0])
                # Each stored prediction is evaluated once, then discarded.
                del pred[t0]

            # Store predictions for future steps t0 + t from outputs named "<k>_T<t>".
            for t in range(1, self.max_delay):
                k_f = k + "_T" + str(t)
                if not (k_f in fd_info): continue
                predictions = pred[t0 + t]
                # The forward output is added to the current value (presumably a predicted
                # change relative to now -- confirm against the model head).
                predictions["pred"][t] = fd_info[k_f] + v
                # Baseline: the current value carried forward unchanged.
                predictions["baseline"][t] = v
                used_fd_info[k][t] = fd_info[k_f]

        # Attach the formatted predictions/baselines to the batch entries.
        batch_info.update(additional_info)
        # Format each key's forwarded values as "[t] x.xx", skipping t = 0.
        used_fd_info = { k : ", ".join(["[%d] %.2f" % (i, vv) for i, vv in enumerate(v) if i != 0]) for k, v in used_fd_info.items() }