in ctakes-temporal/src/main/java/org/apache/ctakes/temporal/nn/eval/EvaluationOfNeuralJointRelations.java [306:478]
/**
 * Runs the trained joint relation model over a test collection and scores the system
 * {@link BinaryTextRelation}s against the gold view.
 *
 * <p>The pipeline starts from {@link #getPreprocessorAggregateBuilder()} and then adds these stages:
 * <ul>
 *   <li>{@code CopyFromGold} (event/time mentions).</li>
 *   <li>{@code CopyFromSystem} (sentences).</li>
 *   <li>{@code RemoveNonContainsRelations}, mapped onto the gold view.</li>
 *   <li>{@code RemoveRelations}.</li>
 *   <li>A {@code WindowBasedAnnotator} loaded from {@code <directory>/joint/model.jar}.</li>
 *   <li>An optional Anafora XML writer.</li>
 * </ul>
 * What each stage does to the CAS is inferred from its class name. TODO confirm against those
 * classes.
 *
 * <p>A report file under {@code target/} is always created (truncated). When {@code printErrors}
 * is set, it receives a per-document list of added, dropped, mislabeled and correct relations.
 * Each entry is split by whether the relation's arguments fall in the same sentence. Summary
 * counts for those categories are printed to stdout at the end.
 *
 * @param collectionReader reader supplying the test documents
 * @param directory directory containing the {@code joint/model.jar} classifier
 * @return statistics comparing gold and system relations by their {@code category} feature
 * @throws Exception if the pipeline fails or the report file cannot be written
 */
protected AnnotationStatistics<String> test(CollectionReader collectionReader, File directory)
    throws Exception {
  // NOTE(review): this overrides the configured field, so both closure branches below are
  // currently unreachable and the last report path is always used. Kept to preserve existing
  // behavior; delete this line to honor the configured value.
  this.useClosure = false;

  AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
  aggregateBuilder.add(CopyFromGold.getDescription(EventMention.class, TimeMention.class));
  aggregateBuilder.add(CopyFromSystem.getDescription(Sentence.class));
  // Component default sofa -> aggregate gold view: this stage operates on the gold annotations.
  aggregateBuilder.add(
      AnalysisEngineFactory.createEngineDescription(RemoveNonContainsRelations.class),
      CAS.NAME_DEFAULT_SOFA,
      GOLD_VIEW_NAME);
  aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(RemoveRelations.class));
  aggregateBuilder.add(
      AnalysisEngineFactory.createEngineDescription(
          WindowBasedAnnotator.class,
          CleartkAnnotator.PARAM_IS_TRAINING,
          false,
          GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
          new File(new File(directory, "joint"), "model.jar").getPath()));
  if (this.anaforaOutput != null) {
    aggregateBuilder.add(
        AnalysisEngineFactory.createEngineDescription(
            WriteAnaforaXML.class, WriteAnaforaXML.PARAM_OUTPUT_DIR, this.anaforaOutput),
        "TimexView",
        CAS.NAME_DEFAULT_SOFA);
  }

  File outf;
  if (recallModeEvaluation && this.useClosure) {
    // Original intent per its comment: add closure for the system output.
    aggregateBuilder.add(
        AnalysisEngineFactory.createEngineDescription(AddClosure.class),
        GOLD_VIEW_NAME,
        CAS.NAME_DEFAULT_SOFA);
    outf = new File("target/brain_biLstm_recall_dev.txt");
  } else if (!recallModeEvaluation && this.useClosure) {
    outf = new File("target/brain_biLstm_precision_dev.txt");
  } else {
    outf = new File("target/colon_ernie2filtered_contains_colon_test_closure.txt");
  }
  // FileWriter fails if the parent directory is missing (e.g. when run outside a Maven build).
  File outParent = outf.getAbsoluteFile().getParentFile();
  if (outParent != null) {
    outParent.mkdirs();
  }

  // Relations are matched by their (ordered) argument pair; the outcome compared is the category.
  Function<BinaryTextRelation, ?> getSpan = new Function<BinaryTextRelation, HashableArguments>() {
    @Override
    public HashableArguments apply(BinaryTextRelation relation) {
      return new HashableArguments(relation);
    }
  };
  Function<BinaryTextRelation, String> getOutcome =
      AnnotationStatistics.annotationToFeatureValue("category");

  // Counts of system predictions (and correct ones) by argument placement, from the system view.
  int withinSentRelations = 0;
  int crossSentRelations = 0;
  int withinSentCorrect = 0;
  int crossSentCorrect = 0;
  // Counts of gold relations by argument placement, from the gold view.
  int withinSentGolds = 0;
  int crossSentGolds = 0;
  AnnotationStatistics<String> stats = new AnnotationStatistics<>();

  // try-with-resources: the report writer is closed (and flushed) even if the pipeline throws.
  try (PrintWriter outDrop = new PrintWriter(new BufferedWriter(new FileWriter(outf, false)))) {
    JCasIterator jcasIter = new JCasIterator(collectionReader, aggregateBuilder.createAggregate());
    while (jcasIter.hasNext()) {
      JCas jCas = jcasIter.next();
      JCas goldView = jCas.getView(GOLD_VIEW_NAME);
      JCas systemView = jCas.getView(CAS.NAME_DEFAULT_SOFA);
      Collection<BinaryTextRelation> goldRelations =
          JCasUtil.select(goldView, BinaryTextRelation.class);
      Collection<BinaryTextRelation> systemRelations =
          JCasUtil.select(systemView, BinaryTextRelation.class);
      stats.add(goldRelations, systemRelations, getSpan, getOutcome);

      if (this.printRelations) {
        URI uri = ViewUriUtil.getURI(jCas);
        String[] path = uri.getPath().split("/");
        printRelationAnnotations(path[path.length - 1], systemRelations);
      }
      if (!this.printErrors) {
        continue;
      }

      // Sentence lookups are only needed for the error report, so build them here.
      Map<Annotation, List<Sentence>> sentCoveringMap =
          JCasUtil.indexCovering(systemView, Annotation.class, Sentence.class);
      Map<Annotation, List<Sentence>> goldSentCoveringMap =
          JCasUtil.indexCovering(goldView, Annotation.class, Sentence.class);

      Map<HashableArguments, BinaryTextRelation> goldMap = Maps.newHashMap();
      for (BinaryTextRelation relation : goldRelations) {
        goldMap.put(new HashableArguments(relation), relation);
      }
      Map<HashableArguments, BinaryTextRelation> systemMap = Maps.newHashMap();
      for (BinaryTextRelation relation : systemRelations) {
        systemMap.put(new HashableArguments(relation), relation);
      }
      List<HashableArguments> sorted =
          Lists.newArrayList(Sets.union(goldMap.keySet(), systemMap.keySet()));
      Collections.sort(sorted);

      outDrop.println("Doc id: " + ViewUriUtil.getURI(jCas).toString());
      for (HashableArguments key : sorted) {
        BinaryTextRelation goldRelation = goldMap.get(key);
        BinaryTextRelation systemRelation = systemMap.get(key);
        if (goldRelation == null) {
          // Spurious system relation: counts toward predictions only.
          if (checkArgumentsInTheSameSent(systemRelation, sentCoveringMap)) {
            withinSentRelations += 1;
            outDrop.println("System added within-sent: " + formatRelation(systemRelation));
          } else {
            crossSentRelations += 1;
            outDrop.println("System added cross-sent: " + formatRelation(systemRelation));
          }
        } else if (systemRelation == null) {
          // Missed gold relation: counts toward gold totals only.
          if (checkArgumentsInTheSameSent(goldRelation, goldSentCoveringMap)) {
            withinSentGolds += 1;
            outDrop.println("System dropped within-sent: " + formatRelation(goldRelation));
          } else {
            crossSentGolds += 1;
            outDrop.println("System dropped cross-sent: " + formatRelation(goldRelation));
          }
        } else if (!systemRelation.getCategory().equals(goldRelation.getCategory())) {
          // Same arguments, wrong label: counts as a prediction and as a gold relation.
          String label = systemRelation.getCategory();
          if (checkArgumentsInTheSameSent(systemRelation, sentCoveringMap)) {
            withinSentRelations += 1;
            outDrop.printf("System labeled within-sent %s for %s\n", label, formatRelation(goldRelation));
          } else {
            crossSentRelations += 1;
            outDrop.printf("System labeled cross-sent %s for %s\n", label, formatRelation(goldRelation));
          }
          if (checkArgumentsInTheSameSent(goldRelation, goldSentCoveringMap)) {
            withinSentGolds += 1;
          } else {
            crossSentGolds += 1;
          }
        } else {
          // Exact match: a correct prediction and a gold relation.
          if (checkArgumentsInTheSameSent(systemRelation, sentCoveringMap)) {
            withinSentRelations += 1;
            withinSentCorrect += 1;
            outDrop.println("Nailed it within-sent! " + formatRelation(systemRelation));
          } else {
            crossSentRelations += 1;
            crossSentCorrect += 1;
            outDrop.println("Nailed it cross-sent! " + formatRelation(systemRelation));
          }
          if (checkArgumentsInTheSameSent(goldRelation, goldSentCoveringMap)) {
            withinSentGolds += 1;
          } else {
            crossSentGolds += 1;
          }
        }
      }
    }
  }

  System.out.print("There are "+ withinSentRelations + " within Sentence Predictions; " + withinSentCorrect+ " are correct predictions\n");
  System.out.print("There are "+ crossSentRelations + " cross Sentence Predictions; " + crossSentCorrect+ " are correct predictions\n");
  System.out.print("There are "+ crossSentGolds + " cross Sentence Gold Relations; " + withinSentGolds+ " are within-sent gold relations\n");
  return stats;
}