in ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvaluation.java [278:430]
/**
 * Command-line entry point: trains and/or evaluates the assertion classifiers.
 *
 * <p>After {@code args} are parsed into the static {@code options}, exactly one mode runs, checked
 * in this order:
 * <ol>
 *   <li>{@code preprocessDir} set: only {@code preprocess(...)} is called;</li>
 *   <li>{@code crossValidationFolds} set: cross-validation over the training files, with the
 *       per-fold statistics summed per annotation type and printed;</li>
 *   <li>{@code portionOfDataToUse} not 1.0: five train/test runs, printing each run's scores and
 *       then the per-type average of {@code f1("-1")};</li>
 *   <li>otherwise: train (unless {@code testOnly} or {@code evalOnly}), then test (unless
 *       {@code trainOnly}).</li>
 * </ol>
 *
 * <p>With {@code useTmp}, models go to a fresh directory under {@code <modelsDirectory>/temp},
 * which is deleted when the run ends, whether it succeeds or fails.
 *
 * @param args command-line arguments
 * @throws UnsupportedOperationException if a tree-kernel feature configuration (STK/PTK) is chosen
 * @throws IllegalArgumentException if testing is requested but the test directory has no files
 * @throws Exception if argument parsing, file I/O, training, or evaluation fails
 */
public static void main(String[] args) throws Exception {
  System.out.println("Started assertion module at " + new Date());
  resetOptions();
  CmdLineParser parser = new CmdLineParser(options);
  parser.parseArgument(args);
  // Open the evaluation log only if it has not been opened already.
  if (useEvaluationLogFile && evaluationLogFileOut == null) {
    evaluationLogFile = new File(evaluationLogFilePath);
    evaluationLogFileOut = new BufferedWriter(new FileWriter(evaluationLogFile), 32768);
  }
  printOptionsForDebugging(options);

  List<File> trainFiles = collectTrainingFiles(options.trainDirectory);

  File modelsDir = options.useTmp
      ? createTemporaryModelsDirectory(options.modelsDirectory)
      : options.modelsDirectory;
  try {
    File evaluationOutputDirectory = options.evaluationOutputDirectory;

    ArrayList<String> annotationTypes = new ArrayList<>();
    if (!options.ignorePolarity) { annotationTypes.add("polarity"); }
    if (!options.ignoreConditional) { annotationTypes.add("conditional"); }
    if (!options.ignoreUncertainty) { annotationTypes.add("uncertainty"); }
    if (!options.ignoreSubject) { annotationTypes.add("subject"); }
    if (!options.ignoreGeneric) { annotationTypes.add("generic"); }
    if (!options.ignoreHistory) { annotationTypes.add("historyOf"); }

    // Whitespace-separated learner parameters; defaults to "-c 1.0".
    String[] kernelParams = options.kernelParams != null
        ? options.kernelParams.split("\\s+")
        : new String[] {"-c", "1.0"};

    if (options.featConfig == FEATURE_CONFIG.STK || options.featConfig == FEATURE_CONFIG.PTK) {
      // The tree-kernel writer (TKLibSvmStringOutcomeDataWriter) would need cleartk-2.0.
      throw new UnsupportedOperationException("Feature configuration " + options.featConfig
          + " requires the tree-kernel data writer from cleartk-2.0, which is not supported here");
    }
    Class<? extends DataWriter<String>> dw = LibLinearStringOutcomeDataWriter.class;

    AssertionEvaluation evaluation = new AssertionEvaluation(
        modelsDir,
        evaluationOutputDirectory,
        annotationTypes,
        dw,
        kernelParams);

    if (options.preprocessDir != null) {
      // Preprocessing is a standalone mode: nothing else runs.
      preprocess(options.preprocessDir);
    } else if (options.crossValidationFolds != null) {
      List<Map<String, AnnotationStatisticsCompact<String>>> foldStats =
          evaluation.crossValidation(trainFiles, options.crossValidationFolds);
      Map<String, AnnotationStatisticsCompact<String>> overallStats =
          sumFoldStatistics(annotationTypes, foldStats);
      AssertionEvaluation.printScore(overallStats, "CROSS FOLD OVERALL");
    } else if (Math.abs(options.portionOfDataToUse - 1.0) > 0.001) {
      // Repeated train/test runs; the f1("-1") score is averaged per annotation type.
      int numIters = 5;
      List<File> testFiles = requireTestFiles(options.testDirectory);
      Map<String, Double> f1Totals = new TreeMap<>();
      for (String annotationType : annotationTypes) {
        f1Totals.put(annotationType, 0.0);
      }
      for (int iter = 0; iter < numIters; iter++) {
        Map<String, AnnotationStatisticsCompact<String>> stats =
            evaluation.trainAndTest(trainFiles, testFiles);
        AssertionEvaluation.printScore(stats, "Sample " + iter + " score:");
        for (Map.Entry<String, AnnotationStatisticsCompact<String>> entry : stats.entrySet()) {
          double f1 = entry.getValue().f1("-1");
          // getOrDefault: tolerate result keys that were not pre-seeded from annotationTypes.
          f1Totals.put(entry.getKey(), f1Totals.getOrDefault(entry.getKey(), 0.0) + f1);
        }
      }
      for (String annotationType : annotationTypes) {
        System.out.println("Macro-average F-score for " + annotationType + " is: "
            + (f1Totals.get(annotationType) / numIters));
      }
    } else {
      // Single train and/or test run. In evalOnly mode the files under the evaluation output
      // directory are what gets scored.
      File testDirectory = options.evalOnly ? evaluationOutputDirectory : options.testDirectory;
      if (options.evalOnly) {
        LOGGER.debug("evalOnly using files in directory " + evaluationOutputDirectory.getName() + " aka " + evaluationOutputDirectory.getCanonicalPath());
      }
      // Validate the test set before training so a bad directory fails fast.
      List<File> testFiles = options.trainOnly
          ? new ArrayList<File>()
          : requireTestFiles(testDirectory);

      if (!options.testOnly && !options.evalOnly) {
        CollectionReader trainCollectionReader = evaluation.getCollectionReader(trainFiles);
        evaluation.train(trainCollectionReader, modelsDir);
      }
      if (!options.trainOnly) {
        LOGGER.debug("testFiles.size() = " + testFiles.size());
        CollectionReader testCollectionReader = evaluation.getCollectionReader(testFiles);
        Map<String, AnnotationStatisticsCompact<String>> stats =
            evaluation.test(testCollectionReader, modelsDir);
        AssertionEvaluation.printScore(stats, modelsDir != null ? modelsDir.getAbsolutePath() : "no_model");
      }
    }
  } finally {
    // Remove the per-run scratch model directory even if the run failed.
    if (options.useTmp && modelsDir != null) {
      FileUtils.deleteRecursive(modelsDir);
    }
  }
  System.out.println("Finished assertion module at " + new Date());
}

/**
 * Collects the entries of every directory listed in {@code directoryList}.
 *
 * @param directoryList directory paths separated by ';' or ':'; may be null
 * @return the entries of each readable directory, in the order the directories are listed, or an
 *     empty list when {@code directoryList} is null
 */
private static List<File> collectTrainingFiles(String directoryList) {
  List<File> files = new ArrayList<>();
  if (directoryList == null) {
    return files;
  }
  // NOTE(review): ':' is also a separator, so a Windows path such as C:\data would be split apart.
  for (String directory : directoryList.split("[;:]")) {
    files.addAll(listFilesOrEmpty(new File(directory)));
  }
  return files;
}

/**
 * Returns the entries of {@code directory}. The result is an empty list when the directory is null,
 * or when {@link File#listFiles()} returns null (not a directory, or unreadable).
 */
private static List<File> listFilesOrEmpty(File directory) {
  File[] entries = directory == null ? null : directory.listFiles();
  return entries == null ? new ArrayList<File>() : new ArrayList<File>(Arrays.asList(entries));
}

/**
 * Lists the test files in {@code directory}, failing fast when there are none.
 *
 * @throws IllegalArgumentException if the directory is null, unreadable, or empty
 */
private static List<File> requireTestFiles(File directory) {
  List<File> files = listFilesOrEmpty(directory);
  if (files.isEmpty()) {
    throw new IllegalArgumentException("No test files found in directory: " + directory);
  }
  return files;
}

/**
 * Creates a new, uniquely named, empty directory under {@code <modelsDirectory>/temp}.
 *
 * @throws java.io.IOException if the directory cannot be created
 */
private static File createTemporaryModelsDirectory(File modelsDirectory) throws java.io.IOException {
  java.nio.file.Path tempRoot = new File(modelsDirectory, "temp").toPath();
  java.nio.file.Files.createDirectories(tempRoot);
  // createTempDirectory creates the directory in one step, and throws on failure.
  return java.nio.file.Files.createTempDirectory(tempRoot, "assertion").toFile();
}

/**
 * Sums cross-validation statistics over all folds, keyed by annotation type.
 *
 * <p>Every entry of {@code annotationTypes} gets an entry in the result. A fold that has no
 * statistics for a type is skipped with a warning.
 */
private static Map<String, AnnotationStatisticsCompact<String>> sumFoldStatistics(
    List<String> annotationTypes,
    List<Map<String, AnnotationStatisticsCompact<String>>> foldStats) {
  Map<String, AnnotationStatisticsCompact<String>> overallStats = new TreeMap<>();
  for (String annotationType : annotationTypes) {
    AnnotationStatisticsCompact<String> total = new AnnotationStatisticsCompact<>();
    overallStats.put(annotationType, total);
  }
  int foldIndex = 0;
  for (Map<String, AnnotationStatisticsCompact<String>> singleFoldMap : foldStats) {
    for (String annotationType : annotationTypes) {
      AnnotationStatisticsCompact<String> foldStatistics = singleFoldMap.get(annotationType);
      if (foldStatistics == null) {
        LOGGER.warn("No statistics for " + annotationType + " in fold " + foldIndex);
        continue;
      }
      overallStats.get(annotationType).addAll(foldStatistics);
    }
    foldIndex++;
  }
  return overallStats;
}