void Graph::gibbs()

in src/beanmachine/graph/gibbs.cpp [20:176]


void Graph::gibbs(uint num_samples, uint seed, InferConfig infer_config) {
  std::mt19937 gen(seed);
  std::set<uint> supp = compute_support();
  // eval each node so that we have a starting value and verify that these
  // values are all scalar
  // also compute the pool of variables that we will infer over and
  // compute their descendants -- i.e. all stochastic non-observed nodes
  // that are in the support of the graph
  // pool : nodes that we will infer over -> det_desc, sto_desc
  std::map<uint, std::tuple<std::vector<uint>, std::vector<uint>>> pool;
  // cache_logodds : nodes that we will infer over -> log odds of not changing
  std::vector<double> cache_logodds = std::vector<double>(nodes.size());
  assert(cache_logodds.size() > 0); // keep linter happy
  // inv_sto : stochastic node -> parent nodes in pool
  // x in sto_desc[y] => y in inv_sto[x]
  // this is a temp object which is needed to construct markov_blanket (below)
  std::map<uint, std::set<uint>> inv_sto;
  std::vector<Node*> ordered_supp;
  for (uint node_id : supp) {
    Node* node = nodes[node_id].get();
    bool node_is_not_observed = observed.find(node_id) == observed.end();
    if (node_is_not_observed) {
      node->eval(gen); // evaluate the value of non-observed operator nodes
    }
    if (node->is_stochastic() and node_is_not_observed) {
      std::vector<uint> det_nodes;
      std::vector<uint> sto_nodes;
      std::tie(det_nodes, sto_nodes) = compute_affected_nodes(node_id, supp);
      pool[node_id] = std::make_tuple(det_nodes, sto_nodes);
      cache_logodds[node_id] = NAN; // nan => needs to be re-computed
      for (auto sto : sto_nodes) {
        if (inv_sto.find(sto) == inv_sto.end()) {
          inv_sto[sto] = std::set<uint>();
        }
        inv_sto[sto].insert(node_id);
      }
    }
    if (infer_config.keep_log_prob) {
      ordered_supp.push_back(node);
    }
  }
  // markov_blanket of a node is the set of other nodes whose conditional
  // probability changes when the value of this node changes. This is a
  // symmetric relation and we only track it for the subset of nodes that
  // are in the pool of variables to be inferred over.
  // Formally, x in markov_blanket[y]
  //                <==> exists z s.t. z in sto_desc[x] and z in sto_desc[y]
  // Note: x is in markov_blanket[x]
  std::map<uint, std::set<uint>> markov_blanket;
  for (auto it = inv_sto.begin(); it != inv_sto.end(); ++it) {
    for (auto it1 = it->second.begin(); it1 != it->second.end(); ++it1) {
      if (markov_blanket.find(*it1) == markov_blanket.end()) {
        markov_blanket[*it1] = std::set<uint>();
      }
      for (auto it2 = it1; it2 != it->second.end(); ++it2) {
        markov_blanket[*it1].insert(*it2);
        if (markov_blanket.find(*it2) == markov_blanket.end()) {
          markov_blanket[*it2] = std::set<uint>();
        }
        markov_blanket[*it2].insert(*it1);
      }
    }
  }
  std::vector<NodeValue> old_values = std::vector<NodeValue>(nodes.size());
  assert(old_values.size() > 0); // keep linter happy
  // convert the smart pointers in nodes to dumb pointers in node_ptrs
  // for faster access
  std::vector<Node*> node_ptrs;
  for (uint node_id = 0; node_id < static_cast<uint>(nodes.size()); node_id++) {
    node_ptrs.push_back(nodes[node_id].get());
  }
  assert(node_ptrs.size() > 0); // keep linter happy
  // sampling outer loop
  for (uint snum = 0; snum < num_samples + infer_config.num_warmup; snum++) {
    for (auto it = pool.begin(); it != pool.end(); ++it) {
      bool must_change = false; // must_change => must change current value
      // if we have a cached value of the transition odds then use that instead
      if (not std::isnan(cache_logodds[it->first])) {
        // do we keep the current value?
        if (util::sample_logodds(gen, cache_logodds[it->first])) {
          continue;
        } else {
          must_change = true;
        }
      }
      // for the target sampled node grab its deterministic and stochastic
      // children
      // the following dance of getting into a tuple is needed because this
      // version of C++ doesn't have structured bindings
      std::tuple<const std::vector<uint>&, const std::vector<uint>&> tmp_tuple =
          it->second;
      const std::vector<uint>& det_nodes = std::get<0>(tmp_tuple);
      const std::vector<uint>& sto_nodes = std::get<1>(tmp_tuple);
      assert(it->first == sto_nodes.front());
      // now, compute the probability of all the stochastic nodes that are
      // going to be affected when we change the value of the target node
      double old_logweight = 0;
      for (uint node_id : sto_nodes) {
        const Node* node = node_ptrs[node_id];
        old_logweight += node->log_prob();
      }
      // save the values of the deterministic descendants of the target node
      // as well the target node itself
      for (uint node_id : det_nodes) {
        const Node* node = node_ptrs[node_id];
        old_values[node_id] = node->value;
      }
      Node* tgt_node = node_ptrs[it->first];
      old_values[it->first] = tgt_node->value;
      // propose a new value for the target node and update all the
      // deterministic children note: assuming only boolean values
      if (tgt_node->value.type != AtomicType::BOOLEAN) {
        throw std::runtime_error(
            "all stochastic random variables should be boolean");
      }
      tgt_node->value._bool = not tgt_node->value._bool; // flip
      for (uint node_id : det_nodes) {
        Node* node = node_ptrs[node_id];
        node->eval(gen);
      }
      // compute the probability of the stochastic nodes with the new value
      // of the target node
      double new_logweight = 0;
      for (uint node_id : sto_nodes) {
        const Node* node = node_ptrs[node_id];
        new_logweight += node->log_prob();
      }
      // compute logodds of keeping the current value
      double logodds = old_logweight - new_logweight;
      // Time to make a decision! Do we keep the old value or pick a new value.
      if ((not must_change) and util::sample_logodds(gen, logodds)) {
        // if the move to the new value is rejected then we need to restore
        // all the deterministic decendants and the target node to original
        // values
        for (uint node_id : det_nodes) {
          Node* node = node_ptrs[node_id];
          node->value = old_values[node_id];
        }
        tgt_node->value = old_values[it->first];
        cache_logodds[it->first] = logodds;
      } else {
        // if we change the value of this node then all the other nodes in the
        // pool that depend on this need to be recomputed
        for (uint node_id : markov_blanket[it->first]) {
          cache_logodds[node_id] = NAN;
        }
        cache_logodds[it->first] = -logodds;
      }
    }
    if (infer_config.keep_log_prob) {
      collect_log_prob(_full_log_prob(ordered_supp));
    }
    if (infer_config.keep_warmup or snum >= infer_config.num_warmup) {
      collect_sample();
    }
  }
}