in kie-dmn/kie-dmn-ruleset2dmn-parent/kie-dmn-ruleset2dmn/src/main/java/org/kie/dmn/ruleset2dmn/Converter.java [94:239]
/**
 * Converts a PMML document holding exactly one {@code RuleSetModel} into a DMN model
 * (a decision table plus its input data) and returns the marshalled DMN XML.
 *
 * <p>The rule selection criterion drives the shape of the result:
 * <ul>
 *   <li>{@code WEIGHTED_MAX}: rules are sorted with the reversed {@code WeightComparator}
 *       before being written as decision table rows;</li>
 *   <li>{@code WEIGHTED_SUM}: the table is named {@code "dt"}, uses hit policy COLLECT, every row
 *       outputs a {@code {score, weight}} context, and additional decisions are appended so that
 *       the final decision evaluates {@code aggregated[total=max][1].score};</li>
 *   <li>any other criterion: every row outputs the rule score directly.</li>
 * </ul>
 *
 * <p>Predictors tested against quoted string literals additionally get an {@code ItemDefinition}
 * of type {@code string} whose allowed values are the literals seen in the rules, the values
 * declared by the matching PMML {@code DataField}, and {@code "<unknown>"}.
 *
 * @param dmnModelName name given to the generated DMN definitions (also used to build its id)
 * @param is stream from which the PMML document is unmarshalled
 * @return the generated DMN model serialized as XML
 * @throws UnsupportedOperationException if the PMML does not contain exactly one model, that model
 *         is not a {@code RuleSetModel}, it does not declare exactly one rule selection method,
 *         or it contains a rule that is not a {@code SimpleRule}
 * @throws IllegalStateException if a string-typed predictor has no matching {@code InputData}
 *         or decision table {@code InputClause} in the generated model
 * @throws Exception any failure raised while unmarshalling the PMML or building the DMN model
 */
public static String parse(String dmnModelName, InputStream is) throws Exception {
    // --- 1. Read the PMML and validate that it is a shape this converter supports.
    final PMML pmml = org.jpmml.model.PMMLUtil.unmarshal(is);
    if (pmml.getModels().size() != 1) {
        throw new UnsupportedOperationException("Only single model supported for Decision Table conversion");
    }
    Model model0 = pmml.getModels().get(0);
    if (!(model0 instanceof RuleSetModel)) {
        throw new UnsupportedOperationException("Only single RuleSetModel supported for Decision Table conversion");
    }
    RuleSetModel rsModel = (RuleSetModel) model0;
    RuleSet rs = rsModel.getRuleSet();
    if (rs.getRuleSelectionMethods().size() != 1) {
        throw new UnsupportedOperationException("Only single RuleSelectionMethods supported for Decision Table conversion");
    }
    RuleSelectionMethod rssMethod0 = rs.getRuleSelectionMethods().get(0);
    final boolean weightedSum = rssMethod0.getCriterion() == Criterion.WEIGHTED_SUM;
    // Fail with an explicit message instead of a ClassCastException raised from inside the stream below.
    if (!rs.getRules().stream().allMatch(SimpleRule.class::isInstance)) {
        throw new UnsupportedOperationException("Only SimpleRule supported for Decision Table conversion");
    }

    // --- 2. Turn the rules into rows and collect, in encounter order, the predictors they use.
    Stream<SimpleRule> s0 = rs.getRules().stream().map(SimpleRule.class::cast);
    if (rssMethod0.getCriterion() == Criterion.WEIGHTED_MAX) { // if WEIGHTED_MAX then sort by weight desc
        s0 = s0.sorted(new WeightComparator().reversed());
    }
    List<SimpleRuleRow> rsRules = s0.map(SimpleRuleRow::new).collect(Collectors.toList());
    Set<String> usedPredictors = new LinkedHashSet<>();
    for (SimpleRuleRow rr : rsRules) {
        usedPredictors.addAll(rr.map.keySet());
        LOG.debug("{}", rr);
    }
    LOG.debug("{}", usedPredictors);
    // predictor name -> quoted string literals it is compared against (its "list of values")
    Map<String, Set<String>> predictorsLoVs = new HashMap<>();

    // --- 3. Create the DMN definitions skeleton: namespace, input data and the decision table.
    Definitions definitions = new TDefinitions();
    setDefaultNSContext(definitions);
    definitions.setId("dmnid_" + dmnModelName);
    definitions.setName(dmnModelName);
    String namespace = "ri2dmn_" + UUID.randomUUID();
    definitions.setNamespace(namespace);
    definitions.getNsContext().put(XMLConstants.DEFAULT_NS_PREFIX, namespace);
    definitions.setExporter("kie-dmn-ri");
    appendInputData(definitions, pmml, usedPredictors);
    // Only WEIGHTED_SUM passes an explicit table name; otherwise null is handed to appendDecisionDT.
    final String dtName = weightedSum ? "dt" : null;
    DecisionTable dt = appendDecisionDT(definitions, dtName, pmml, usedPredictors);
    if (weightedSum) {
        dt.setHitPolicy(HitPolicy.COLLECT);
    }
    if (rs.getDefaultScore() != null) {
        LiteralExpression le = leFromNumberOrString(rs.getDefaultScore());
        dt.getOutput().get(0).setDefaultOutputEntry(le);
    }

    // --- 4. One decision rule per PMML rule, with one input entry per used predictor.
    final List<DataField> dataFields = pmml.getDataDictionary().getDataFields();
    for (SimpleRuleRow r : rsRules) {
        DecisionRule dr = new TDecisionRule();
        for (String input : usedPredictors) {
            List<SimplePredicate> predicatesForInput = r.map.get(input);
            UnaryTests ut;
            if (predicatesForInput != null && !predicatesForInput.isEmpty()) {
                Optional<DataField> df = dataFields.stream().filter(x -> x.getName().equals(input)).findFirst();
                ut = processSimplePredicateUnaryOrBinary(predicatesForInput, df);
                // A test that is a single quoted literal marks this predictor as string-valued.
                if (ut.getText().startsWith("\"") && ut.getText().endsWith("\"")) {
                    predictorsLoVs.computeIfAbsent(input, k -> new LinkedHashSet<String>()).add(ut.getText());
                }
            } else {
                // The rule puts no constraint on this predictor: "-" is the DMN "any value" test.
                ut = new TUnaryTests();
                ut.setText("-");
            }
            dr.getInputEntry().add(ut);
        }
        if (!weightedSum) {
            dr.getOutputEntry().add(leFromNumberOrString(r.r.getScore()));
        } else {
            // Each row emits a FEEL context with score and weight for the downstream decisions.
            String output = "{score: " + feelLiteralValue(r.r.getScore(), Optional.empty()) + " , weight: " + r.r.getWeight() + " }";
            LiteralExpression le = new TLiteralExpression();
            le.setText(output);
            dr.getOutputEntry().add(le);
        }
        // Keep the PMML rule statistics as a rule annotation, as key=value pairs.
        RuleAnnotation comment = new TRuleAnnotation();
        String commentText = "recordCount=" + r.r.getRecordCount()
                + " nbCorrect=" + r.r.getNbCorrect()
                + " confidence=" + r.r.getConfidence()
                + " weight=" + r.r.getWeight();
        comment.setText(commentText);
        dr.getAnnotationEntry().add(comment);
        dt.getRule().add(dr);
    }

    // --- 5. WEIGHTED_SUM only: add the helper decisions and the final decision named after the model.
    if (weightedSum) {
        // NOTE(review): these helpers presumably create the "aggregated" and "max" decisions required below.
        decisionAggregated(definitions, dtName);
        decisionMax(definitions);
        Decision decision = new TDecision();
        String decisionName = definitions.getName();
        decision.setName(decisionName);
        decision.setId("d_" + CodegenStringUtil.escapeIdentifier(decisionName));
        InformationItem variable = new TInformationItem();
        variable.setName(decisionName);
        variable.setId("dvar_" + CodegenStringUtil.escapeIdentifier(decisionName));
        variable.setTypeRef(new QName("Any"));
        decision.setVariable(variable);
        addRequiredDecisionByName(decision, "aggregated");
        addRequiredDecisionByName(decision, "max");
        LiteralExpression le = new TLiteralExpression();
        le.setText("aggregated[total=max][1].score");
        decision.setExpression(le);
        definitions.getDrgElement().add(decision);
    }

    // --- 6. Complete the lists of values with the values declared in the PMML data dictionary.
    for (DataField df : dataFields) {
        Set<String> lov = predictorsLoVs.get(df.getName());
        if (df.getDataType() == DataType.STRING && lov != null) {
            for (Value value : df.getValues()) {
                lov.add("\"" + value.getValue().toString() + "\"");
            }
        }
    }
    for (Set<String> v : predictorsLoVs.values()) {
        v.add("\"<unknown>\"");
    }

    // --- 7. For each string-valued predictor: declare an ItemDefinition and point both the
    //        InputData variable and the decision table input clause at it.
    for (Entry<String, Set<String>> kv : predictorsLoVs.entrySet()) {
        final String predictor = kv.getKey();
        final String lovText = kv.getValue().stream().collect(Collectors.joining(", "));

        ItemDefinition idd = new TItemDefinition();
        idd.setName(predictor);
        idd.setTypeRef(new QName("string"));
        UnaryTests lov = new TUnaryTests();
        lov.setText(lovText);
        idd.setAllowedValues(lov);
        definitions.getItemDefinition().add(idd);

        InputData inputData = definitions.getDrgElement().stream()
                .filter(InputData.class::isInstance)
                .map(InputData.class::cast)
                .filter(drg -> drg.getName().equals(predictor))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("No InputData found for predictor: " + predictor));
        inputData.getVariable().setTypeRef(new QName(predictor));

        InputClause ic = dt.getInput().stream()
                .filter(clause -> clause.getInputExpression().getText().equals(predictor))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("No decision table InputClause found for predictor: " + predictor));
        UnaryTests icLov = new TUnaryTests();
        icLov.setText(lovText);
        ic.setInputValues(icLov);
        ic.getInputExpression().setTypeRef(new QName(predictor));
    }

    // --- 8. Serialize the DMN model.
    DMNMarshaller dmnMarshaller = DMNMarshallerFactory.newDefaultMarshaller();
    String xml = dmnMarshaller.marshal(definitions);
    LOG.debug("{}", predictorsLoVs);
    return xml;
}