def compute_gradients()

in baselines/common/mpi_adam_optimizer.py
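
Overrides tf.train.AdamOptimizer.compute_gradients to average gradients across MPI ranks: each rank flattens its per-variable gradients into a single vector scaled by its mpi_rank_weight, the vectors are summed with an MPI Allreduce and divided by the total rank weight, and the averaged vector is split and reshaped back into per-variable (gradient, variable) pairs. Every 100 steps a cheap statistic of the first variable is compared across ranks to confirm the workers are still synchronized.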


    def compute_gradients(self, loss, var_list, **kwargs):
        # Let the underlying AdamOptimizer compute per-variable gradients,
        # then drop variables that received no gradient.
        grads_and_vars = tf.train.AdamOptimizer.compute_gradients(self, loss, var_list, **kwargs)
        grads_and_vars = [(g, v) for g, v in grads_and_vars if g is not None]
        # Flatten all gradients into a single vector and pre-scale it by this
        # rank's weight, so the Allreduce below produces a weighted sum.
        flat_grad = tf.concat([tf.reshape(g, (-1,)) for g, v in grads_and_vars], axis=0) * self.mpi_rank_weight
        # Remember each variable's shape and size so the averaged flat vector
        # can be split and reshaped back at the end.
        shapes = [v.shape.as_list() for g, v in grads_and_vars]
        sizes = [int(np.prod(s)) for s in shapes]

        # Sum every rank's mpi_rank_weight; this becomes the denominator used
        # to turn the summed gradients into a weighted average.
        total_weight = np.zeros(1, np.float32)
        self.comm.Allreduce(np.array([self.mpi_rank_weight], dtype=np.float32), total_weight, op=MPI.SUM)
        total_weight = total_weight[0]

        buf = np.zeros(sum(sizes), np.float32)  # receive buffer for the gradient Allreduce
        countholder = [0]  # counts how many times _collect_grads has been called
        stat = tf.reduce_sum(grads_and_vars[0][1])  # sum of the first variable, used as a sync check
        def _collect_grads(flat_grad, np_stat):
            if self.grad_clip is not None:
                # Rescale to unit norm when the norm exceeds 1 and log the mean
                # norm and clip fraction. Note: grad_clip only enables clipping
                # here; the threshold itself is hard-coded to 1.
                gradnorm = np.linalg.norm(flat_grad)
                if gradnorm > 1:
                    flat_grad /= gradnorm
                logger.logkv_mean('gradnorm', gradnorm)
                logger.logkv_mean('gradclipfrac', float(gradnorm > 1))
            # Sum the (rank-weighted) gradients from all ranks, then divide by
            # the total weight to obtain the weighted average.
            self.comm.Allreduce(flat_grad, buf, op=MPI.SUM)
            np.divide(buf, float(total_weight), out=buf)
            # Every 100 calls, check that the first variable has the same value
            # on every rank, i.e. the workers have not drifted out of sync.
            if countholder[0] % 100 == 0:
                check_synced(np_stat, self.comm)
            countholder[0] += 1
            return buf

        # Run the numpy/MPI averaging as a graph op; tf.py_func loses static
        # shape information, so restore it explicitly.
        avg_flat_grad = tf.py_func(_collect_grads, [flat_grad, stat], tf.float32)
        avg_flat_grad.set_shape(flat_grad.shape)
        # Split the averaged flat vector back into per-variable pieces and
        # reshape each piece to its variable's shape.
        avg_grads = tf.split(avg_flat_grad, sizes, axis=0)
        avg_grads_and_vars = [(tf.reshape(g, v.shape), v)
                              for g, (_, v) in zip(avg_grads, grads_and_vars)]
        return avg_grads_and_vars
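
For context, a minimal usage sketch, not part of the source file. It assumes TensorFlow 1.x graph mode, mpi4py, and the surrounding MpiAdamOptimizer class from baselines, whose constructor takes an MPI communicator and forwards extra keyword arguments to tf.train.AdamOptimizer; the toy variable, loss, and hyperparameters are illustrative:

    from mpi4py import MPI
    import numpy as np
    import tensorflow as tf
    from baselines.common.mpi_adam_optimizer import MpiAdamOptimizer

    # Toy parameter and scalar loss (illustrative only).
    x = tf.Variable(np.ones(3, dtype=np.float32))
    loss = tf.reduce_sum(tf.square(x))

    # Extra keyword arguments are passed through to tf.train.AdamOptimizer.
    trainer = MpiAdamOptimizer(MPI.COMM_WORLD, learning_rate=1e-3, epsilon=1e-5)
    grads_and_vars = trainer.compute_gradients(loss, tf.trainable_variables())
    train_op = trainer.apply_gradients(grads_and_vars)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Gradients are averaged across ranks inside compute_gradients,
        # so every rank applies the same Adam update.
        sess.run(train_op)

Run under mpirun with one process per rank; because every rank applies the same averaged gradient, the parameters stay identical across workers, which is what the periodic check_synced call asserts.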