def eval_model_runs_batch()

in research/lfads/lfads.py [0:0]


  def eval_model_runs_batch(self, data_name, data_bxtxd, ext_input_bxtxi=None,
                            do_eval_cost=False, do_average_batch=False):
    """Returns all the goodies for the entire model, per batch.

    If data_bxtxd and ext_input_bxtxi can have fewer than batch_size along dim 1
    in which case this handles the padding and truncating automatically

    Args:
      data_name: The name of the data dict, to select which in/out matrices
        to use.
      data_bxtxd:  Numpy array training data with shape:
        batch_size x # time steps x # dimensions
      ext_input_bxtxi: Numpy array training external input with shape:
        batch_size x # time steps x # external input dims
      do_eval_cost (optional): If true, the IWAE (Importance Weighted
         Autoencoder) log likeihood bound, instead of the VAE version.
      do_average_batch (optional): average over the batch, useful for getting
      good IWAE costs, and model outputs for a single data point.

    Returns:
      A dictionary with the outputs of the model decoder, namely:
        prior g0 mean, prior g0 variance, approx. posterior mean, approx
        posterior mean, the generator initial conditions, the control inputs (if
        enabled), the state of the generator, the factors, and the rates.
    """
    session = tf.get_default_session()

    # if fewer than batch_size provided, pad to batch_size
    hps = self.hps
    batch_size = hps.batch_size
    E, _, _ = data_bxtxd.shape
    if E < hps.batch_size:
      data_bxtxd = np.pad(data_bxtxd, ((0, hps.batch_size-E), (0, 0), (0, 0)),
                          mode='constant', constant_values=0)
      if ext_input_bxtxi is not None:
        ext_input_bxtxi = np.pad(ext_input_bxtxi,
                                 ((0, hps.batch_size-E), (0, 0), (0, 0)),
                                 mode='constant', constant_values=0)

    feed_dict = self.build_feed_dict(data_name, data_bxtxd,
                                     ext_input_bxtxi, keep_prob=1.0)

    # Non-temporal signals will be batch x dim.
    # Temporal signals are list length T with elements batch x dim.
    tf_vals = [self.gen_ics, self.gen_states, self.factors,
               self.output_dist_params]
    tf_vals.append(self.cost)
    tf_vals.append(self.nll_bound_vae)
    tf_vals.append(self.nll_bound_iwae)
    tf_vals.append(self.train_step) # not train_op!
    if self.hps.ic_dim > 0:
      tf_vals += [self.prior_zs_g0.mean, self.prior_zs_g0.logvar,
                  self.posterior_zs_g0.mean, self.posterior_zs_g0.logvar]
    if self.hps.co_dim > 0:
      tf_vals.append(self.controller_outputs)
    tf_vals_flat, fidxs = flatten(tf_vals)

    np_vals_flat = session.run(tf_vals_flat, feed_dict=feed_dict)

    ff = 0
    gen_ics = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    gen_states = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    factors = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    out_dist_params = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    costs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    nll_bound_vaes = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    nll_bound_iwaes = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
    train_steps = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
    if self.hps.ic_dim > 0:
      prior_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff +=1
      prior_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
      post_g0_mean = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
      post_g0_logvar = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1
    if self.hps.co_dim > 0:
      controller_outputs = [np_vals_flat[f] for f in fidxs[ff]]; ff += 1

    # [0] are to take out the non-temporal items from lists
    gen_ics = gen_ics[0]
    costs = costs[0]
    nll_bound_vaes = nll_bound_vaes[0]
    nll_bound_iwaes = nll_bound_iwaes[0]
    train_steps = train_steps[0]

    # Convert to full tensors, not lists of tensors in time dim.
    gen_states = list_t_bxn_to_tensor_bxtxn(gen_states)
    factors = list_t_bxn_to_tensor_bxtxn(factors)
    out_dist_params = list_t_bxn_to_tensor_bxtxn(out_dist_params)
    if self.hps.ic_dim > 0:
      # select first time point
      prior_g0_mean = prior_g0_mean[0]
      prior_g0_logvar = prior_g0_logvar[0]
      post_g0_mean = post_g0_mean[0]
      post_g0_logvar = post_g0_logvar[0]
    if self.hps.co_dim > 0:
      controller_outputs = list_t_bxn_to_tensor_bxtxn(controller_outputs)

    # slice out the trials in case < batch_size provided
    if E < hps.batch_size:
      idx = np.arange(E)
      gen_ics = gen_ics[idx, :]
      gen_states = gen_states[idx, :]
      factors = factors[idx, :, :]
      out_dist_params = out_dist_params[idx, :, :]
      if self.hps.ic_dim > 0:
        prior_g0_mean = prior_g0_mean[idx, :]
        prior_g0_logvar = prior_g0_logvar[idx, :]
        post_g0_mean = post_g0_mean[idx, :]
        post_g0_logvar = post_g0_logvar[idx, :]
      if self.hps.co_dim > 0:
        controller_outputs = controller_outputs[idx, :, :]

    if do_average_batch:
      gen_ics = np.mean(gen_ics, axis=0)
      gen_states = np.mean(gen_states, axis=0)
      factors = np.mean(factors, axis=0)
      out_dist_params = np.mean(out_dist_params, axis=0)
      if self.hps.ic_dim > 0:
        prior_g0_mean = np.mean(prior_g0_mean, axis=0)
        prior_g0_logvar = np.mean(prior_g0_logvar, axis=0)
        post_g0_mean = np.mean(post_g0_mean, axis=0)
        post_g0_logvar = np.mean(post_g0_logvar, axis=0)
      if self.hps.co_dim > 0:
        controller_outputs = np.mean(controller_outputs, axis=0)

    model_vals = {}
    model_vals['gen_ics'] = gen_ics
    model_vals['gen_states'] = gen_states
    model_vals['factors'] = factors
    model_vals['output_dist_params'] = out_dist_params
    model_vals['costs'] = costs
    model_vals['nll_bound_vaes'] = nll_bound_vaes
    model_vals['nll_bound_iwaes'] = nll_bound_iwaes
    model_vals['train_steps'] = train_steps
    if self.hps.ic_dim > 0:
      model_vals['prior_g0_mean'] = prior_g0_mean
      model_vals['prior_g0_logvar'] = prior_g0_logvar
      model_vals['post_g0_mean'] = post_g0_mean
      model_vals['post_g0_logvar'] = post_g0_logvar
    if self.hps.co_dim > 0:
      model_vals['controller_outputs'] = controller_outputs

    return model_vals