in src/beanmachine/graph/distribution/distribution.cpp [119:196]
// Draws a sample from this distribution and stores it in `sample_value`.
//
// Which sampler is used depends on the declared type of `sample_value` and on
// this distribution's `sample_type`:
//   * `sample_value` is a SCALAR: one draw from the scalar sampler matching
//     its atomic type (BOOLEAN -> _bool_sampler; REAL / POS_REAL /
//     PROBABILITY -> _double_sampler; NATURAL -> _natural_sampler).
//   * `sample_type` is SCALAR and `sample_value` is a BROADCAST_MATRIX: each of
//     the rows * cols entries gets its own scalar draw (the "iid" case). The
//     draws are written in the order of the underlying data() buffer.
//   * `sample_type` is a COL_SIMPLEX_MATRIX of PROBABILITY, or a
//     BROADCAST_MATRIX of REAL / POS_REAL / PROBABILITY: the whole matrix is
//     produced by _matrix_sampler and assigned to `sample_value._matrix`.
// Any other combination throws std::runtime_error("Unsupported sample type.").
//
// NOTE(review): the elementwise case writes through data() without resizing.
// It presumably relies on `sample_value` already holding rows x cols storage
// (confirm against the NodeValue constructor).
void Distribution::sample(std::mt19937& gen, graph::NodeValue& sample_value)
    const {
  const auto& value_type = sample_value.type;

  // Scalar destination: a single draw chosen by the destination's atomic type.
  if (value_type.variable_type == graph::VariableType::SCALAR) {
    switch (value_type.atomic_type) {
      case graph::AtomicType::BOOLEAN:
        sample_value._bool = _bool_sampler(gen);
        return;
      case graph::AtomicType::REAL:
      case graph::AtomicType::POS_REAL:
      case graph::AtomicType::PROBABILITY:
        sample_value._double = _double_sampler(gen);
        return;
      case graph::AtomicType::NATURAL:
        sample_value._natural = _natural_sampler(gen);
        return;
      default:
        throw std::runtime_error("Unsupported sample type.");
    }
  }

  // Scalar distribution feeding a broadcast matrix: one draw per element.
  if (sample_type.variable_type == graph::VariableType::SCALAR &&
      value_type.variable_type == graph::VariableType::BROADCAST_MATRIX) {
    const uint count = value_type.cols * value_type.rows;
    assert(count > 1);
    switch (value_type.atomic_type) {
      case graph::AtomicType::BOOLEAN:
        for (uint k = 0; k < count; ++k) {
          sample_value._bmatrix.data()[k] = _bool_sampler(gen);
        }
        return;
      case graph::AtomicType::REAL:
      case graph::AtomicType::POS_REAL:
      case graph::AtomicType::PROBABILITY:
        for (uint k = 0; k < count; ++k) {
          sample_value._matrix.data()[k] = _double_sampler(gen);
        }
        return;
      case graph::AtomicType::NATURAL:
        for (uint k = 0; k < count; ++k) {
          sample_value._nmatrix.data()[k] = _natural_sampler(gen);
        }
        return;
      default:
        throw std::runtime_error("Unsupported sample type.");
    }
  }

  // Matrix-valued distributions: only these (variable, atomic) type pairs
  // have a whole-matrix sampler; everything else is rejected.
  const bool is_probability =
      sample_type.atomic_type == graph::AtomicType::PROBABILITY;
  const bool is_real_like = is_probability ||
      sample_type.atomic_type == graph::AtomicType::REAL ||
      sample_type.atomic_type == graph::AtomicType::POS_REAL;
  const bool matrix_sampler_applies =
      (sample_type.variable_type ==
           graph::VariableType::COL_SIMPLEX_MATRIX &&
       is_probability) ||
      (sample_type.variable_type == graph::VariableType::BROADCAST_MATRIX &&
       is_real_like);
  if (!matrix_sampler_applies) {
    throw std::runtime_error("Unsupported sample type.");
  }
  sample_value._matrix = _matrix_sampler(gen);
}