in ruta-ep-textruler/src/main/java/org/apache/uima/ruta/textruler/learner/trabal/TrabalLearner.java [500:592]
/**
 * Learns rules for every group of annotation errors and returns the selection made by
 * {@code getBest} over all rules that were kept.
 *
 * <p>
 * For each group the method
 * <ol>
 * <li>sorts the group's errors and creates, de-duplicates, caps (at
 * {@code maxNumberOfBasicRules}) and tests the basic rules,</li>
 * <li>keeps basic rules with covered positives and no covered negatives as learnt rules,</li>
 * <li>passes basic rules with both covered positives and covered negatives to
 * {@code enhanceRules}, tests the outcome and keeps every rule whose error rate does not exceed
 * {@code maxErrorRate},</li>
 * <li>publishes the currently best rule of the group through {@code bestRulesForStatus} and
 * {@code result}, and adds at most {@code maxNumberOfRules} learnt rules to the overall list.</li>
 * </ol>
 * Processing stops early as soon as {@code shouldAbort()} reports true. Empty (or null) groups
 * are skipped.
 *
 * <p>
 * Side effects visible here: the error lists inside {@code errorGrps} are sorted in place, and
 * {@code inducedRules}, {@code bestRulesForStatus} and {@code result} are overwritten.
 *
 * @param errorGrps
 *          annotation errors grouped under a string key; only the values are read
 * @return the result of {@code getBest} applied to the rules learnt for all processed groups
 */
private List<TrabalRule> runAlgorithm(Map<String, List<AnnotationError>> errorGrps) {
  removeBasics();
  inducedRules.clear();
  List<TrabalRule> rules = new ArrayList<TrabalRule>();
  bestRulesForStatus.clear();
  int groupCount = errorGrps.size();
  int i = 0;
  for (List<AnnotationError> each : errorGrps.values()) {
    if (shouldAbort()) {
      break;
    }
    // 1-based index of the current group, used in the progress messages below.
    i++;
    if (each == null || each.isEmpty()) {
      // Nothing to learn from, and each.get(0) below would throw on an empty group.
      continue;
    }
    Collections.sort(each);
    String status = each.get(0).toString();
    String progress = " (" + i + " of " + groupCount + ")";

    // --- basic rules -------------------------------------------------------------------------
    sendStatusUpdateToDelegate("Creating basic rules: " + status,
            TextRulerLearnerState.ML_RUNNING, false);
    List<TrabalRule> basicRules = removeDuplicateRules(createBasicRules(each));
    if (basicRules.size() > maxNumberOfBasicRules) {
      basicRules = basicRules.subList(0, maxNumberOfBasicRules);
    }
    sendStatusUpdateToDelegate("Testing basic rules: " + status,
            TextRulerLearnerState.ML_RUNNING, false);
    basicRules = testTrabalRulesOnDocumentSet(basicRules, exampleDocuments, additionalDocuments,
            "basic rules" + progress);

    // Remember whether this group put a provisional entry into bestRulesForStatus, so that
    // only that entry is replaced later on.
    boolean provisionalBestAdded = false;
    if (!basicRules.isEmpty()) {
      Collections.sort(basicRules, basicComparator);
      bestRulesForStatus.add(basicRules.get(0));
      provisionalBestAdded = true;
    }
    result = actualResult + getRuleStrings(bestRulesForStatus);
    sendStatusUpdateToDelegate("Testing basic rules: " + status,
            TextRulerLearnerState.ML_RUNNING, true);

    // Basic rules with covered positives and no covered negatives are accepted as they are.
    List<TrabalRule> learntRules = new ArrayList<TrabalRule>();
    for (TrabalRule rule : basicRules) {
      if (rule.getCoveringStatistics().getCoveredPositivesCount() > 0
              && rule.getCoveringStatistics().getCoveredNegativesCount() == 0) {
        learntRules.add(rule);
      }
    }

    // --- enhancement candidates --------------------------------------------------------------
    // Rules with covered positives and covered negatives are candidates for enhancement. If an
    // accepted rule exists, a candidate must cover more positives than the first accepted rule.
    // learntRules is not modified inside the loop, so the reference rule is fetched once.
    TrabalRule referenceRule = learntRules.isEmpty() ? null : learntRules.get(0);
    List<TrabalRule> enhancedRules = new ArrayList<TrabalRule>();
    int rank = 1;
    for (TrabalRule rule : basicRules) {
      if (rule.getCoveringStatistics().getCoveredPositivesCount() > 0
              && rule.getCoveringStatistics().getCoveredNegativesCount() > 0) {
        if (referenceRule == null
                || rule.getCoveringStatistics().getCoveredPositivesCount() > referenceRule
                        .getCoveringStatistics().getCoveredPositivesCount()) {
          rule.setRating(rank);
          enhancedRules.add(rule);
          rank++;
        }
      }
    }
    basicRules.clear();

    try {
      enhancedRules = enhanceRules(enhancedRules, maxNumberOfIterations, new RankedList(idf));
      Collections.sort(enhancedRules);
      if (enhancedRules.size() > maxNumberOfRules) {
        enhancedRules = enhancedRules.subList(0, maxNumberOfRules);
      }
    } catch (Exception e) {
      // Best effort: a failure is reported but does not stop learning. Whatever enhancedRules
      // holds at this point is tested below.
      e.printStackTrace();
    }

    // --- test enhanced rules -----------------------------------------------------------------
    sendStatusUpdateToDelegate("Testing enhanced rules: " + status,
            TextRulerLearnerState.ML_RUNNING, false);
    enhancedRules = testTrabalRulesOnDocumentSet(enhancedRules, exampleDocuments,
            additionalDocuments, "enhanced rules" + progress);
    for (TrabalRule rule : enhancedRules) {
      if (rule.getErrorRate() <= maxErrorRate) {
        learntRules.add(rule);
      }
    }
    enhancedRules.clear();

    learntRules = removeDuplicateRules(learntRules);
    if (!learntRules.isEmpty()) {
      Collections.sort(learntRules, enhancedComparator);
      // Replace this group's provisional best rule. Without the guard, the entry of a previous
      // group would be dropped (or remove(-1) would throw) when no provisional rule was added.
      if (provisionalBestAdded) {
        bestRulesForStatus.remove(bestRulesForStatus.size() - 1);
      }
      bestRulesForStatus.add(learntRules.get(0));
      result = actualResult + getRuleStrings(bestRulesForStatus);
    }
    sendStatusUpdateToDelegate("Testing optimized rules: " + status,
            TextRulerLearnerState.ML_RUNNING, true);

    if (learntRules.size() > maxNumberOfRules) {
      learntRules = learntRules.subList(0, maxNumberOfRules);
    }
    rules.addAll(learntRules);
  }
  return getBest(rules);
}