in src/beanmachine/graph/distribution/distribution.cpp [28:111]
std::unique_ptr<Distribution> Distribution::new_distribution(
graph::DistributionType dist_type,
graph::ValueType sample_type,
const std::vector<graph::Node*>& in_nodes) {
// call the appropriate distribution constructor
if (sample_type.variable_type == graph::VariableType::SCALAR) {
auto atype = sample_type.atomic_type;
switch (dist_type) {
case graph::DistributionType::TABULAR: {
return std::make_unique<Tabular>(atype, in_nodes);
}
case graph::DistributionType::BERNOULLI: {
return std::make_unique<Bernoulli>(atype, in_nodes);
}
case graph::DistributionType::BERNOULLI_NOISY_OR: {
return std::make_unique<BernoulliNoisyOr>(atype, in_nodes);
}
case graph::DistributionType::BETA: {
return std::make_unique<Beta>(atype, in_nodes);
}
case graph::DistributionType::BINOMIAL: {
return std::make_unique<Binomial>(atype, in_nodes);
}
case graph::DistributionType::FLAT: {
return std::make_unique<Flat>(atype, in_nodes);
}
case graph::DistributionType::NORMAL: {
return std::make_unique<Normal>(atype, in_nodes);
}
case graph::DistributionType::HALF_NORMAL: {
return std::make_unique<Half_Normal>(atype, in_nodes);
}
case graph::DistributionType::HALF_CAUCHY: {
return std::make_unique<HalfCauchy>(atype, in_nodes);
}
case graph::DistributionType::STUDENT_T: {
return std::make_unique<StudentT>(atype, in_nodes);
}
case graph::DistributionType::BERNOULLI_LOGIT: {
return std::make_unique<BernoulliLogit>(atype, in_nodes);
}
case graph::DistributionType::GAMMA: {
return std::make_unique<Gamma>(atype, in_nodes);
}
case graph::DistributionType::BIMIXTURE: {
return std::make_unique<Bimixture>(atype, in_nodes);
}
case graph::DistributionType::CATEGORICAL: {
return std::make_unique<Categorical>(atype, in_nodes);
}
default: {
throw std::invalid_argument(
"Unknown distribution " +
std::to_string(static_cast<int>(dist_type)) +
" for univariate sample type.");
}
}
} else if (
sample_type.variable_type == graph::VariableType::COL_SIMPLEX_MATRIX) {
switch (dist_type) {
case graph::DistributionType::DIRICHLET: {
return std::make_unique<Dirichlet>(sample_type, in_nodes);
}
default: {
throw std::invalid_argument(
"Unknown distribution " +
std::to_string(static_cast<int>(dist_type)) +
" for multivariate sample type.");
}
}
} else {
switch (dist_type) {
case graph::DistributionType::FLAT: {
return std::make_unique<Flat>(sample_type, in_nodes);
}
default: {
throw std::invalid_argument(
"Unknown distribution " +
std::to_string(static_cast<int>(dist_type)) +
" for multivariate sample type.");
}
}
}
}