void Distribution::sample()

in src/beanmachine/graph/distribution/distribution.cpp [119:196]


// Draws a sample from this distribution and stores it in `sample_value`.
//
// The branch taken depends on the requested output type:
//  * `sample_value` is SCALAR: a single draw from the sampler matching its
//    atomic type (_bool_sampler, _double_sampler or _natural_sampler).
//  * the distribution's `sample_type` is SCALAR while `sample_value` is a
//    BROADCAST_MATRIX: every element is filled with an independent scalar
//    draw, visiting elements in the matrix's linear (storage) order.
//  * `sample_type` is COL_SIMPLEX_MATRIX (PROBABILITY only) or
//    BROADCAST_MATRIX (REAL / POS_REAL / PROBABILITY): the whole matrix is
//    produced by a single _matrix_sampler call.
//
// @param gen           random engine forwarded to the samplers.
// @param sample_value  destination; its `type` selects the branch above and,
//                      for the iid case, the shape of the filled matrix.
// @throws std::runtime_error("Unsupported sample type.") for any other
//         combination of variable type and atomic type.
void Distribution::sample(std::mt19937& gen, graph::NodeValue& sample_value)
    const {
  // Case 1: a single scalar draw.
  if (sample_value.type.variable_type == graph::VariableType::SCALAR) {
    switch (sample_value.type.atomic_type) {
      case graph::AtomicType::BOOLEAN:
        sample_value._bool = _bool_sampler(gen);
        break;
      case graph::AtomicType::REAL:
      case graph::AtomicType::POS_REAL:
      case graph::AtomicType::PROBABILITY:
        sample_value._double = _double_sampler(gen);
        break;
      case graph::AtomicType::NATURAL:
        sample_value._natural = _natural_sampler(gen);
        break;
      default:
        throw std::runtime_error("Unsupported sample type.");
    }
    return;
  }

  // Case 2: a scalar distribution sampled iid to fill a matrix.
  if (sample_type.variable_type == graph::VariableType::SCALAR and
      sample_value.type.variable_type ==
          graph::VariableType::BROADCAST_MATRIX) {
    const uint rows = sample_value.type.rows;
    const uint cols = sample_value.type.cols;
    const uint size = rows * cols;
    assert(size > 1);
    // Each destination matrix is resized to the declared shape before it is
    // written: the fill loops index up to `size`, and writing past the end
    // of an undersized buffer would be undefined behavior. Eigen's resize()
    // performs no reallocation (and keeps the data) when the coefficient
    // count already matches, so correctly pre-sized values are unaffected.
    // The linear operator()(i) visits elements in storage order, exactly as
    // walking data() + i did, and bounds-checks in debug builds.
    switch (sample_value.type.atomic_type) {
      case graph::AtomicType::BOOLEAN:
        sample_value._bmatrix.resize(rows, cols);
        for (uint i = 0; i < size; i++) {
          sample_value._bmatrix(i) = _bool_sampler(gen);
        }
        break;
      case graph::AtomicType::REAL:
      case graph::AtomicType::POS_REAL:
      case graph::AtomicType::PROBABILITY:
        sample_value._matrix.resize(rows, cols);
        for (uint i = 0; i < size; i++) {
          sample_value._matrix(i) = _double_sampler(gen);
        }
        break;
      case graph::AtomicType::NATURAL:
        sample_value._nmatrix.resize(rows, cols);
        for (uint i = 0; i < size; i++) {
          sample_value._nmatrix(i) = _natural_sampler(gen);
        }
        break;
      default:
        throw std::runtime_error("Unsupported sample type.");
    }
    return;
  }

  // Case 3: a distribution over column-simplex matrices.
  if (sample_type.variable_type == graph::VariableType::COL_SIMPLEX_MATRIX) {
    switch (sample_type.atomic_type) {
      case graph::AtomicType::PROBABILITY:
        sample_value._matrix = _matrix_sampler(gen);
        break;
      default:
        throw std::runtime_error("Unsupported sample type.");
    }
    return;
  }

  // Case 4: a distribution over real-valued broadcast matrices.
  if (sample_type.variable_type == graph::VariableType::BROADCAST_MATRIX) {
    switch (sample_type.atomic_type) {
      case graph::AtomicType::REAL:
      case graph::AtomicType::POS_REAL:
      case graph::AtomicType::PROBABILITY:
        sample_value._matrix = _matrix_sampler(gen);
        break;
      default:
        throw std::runtime_error("Unsupported sample type.");
    }
    return;
  }

  // No branch above handles this (sample_type, sample_value.type) pair.
  throw std::runtime_error("Unsupported sample type.");
}