in main/src/search-service/search/src/main/java/com/bioimage/search/App.java [331:475]
/**
 * Handles a search-by-image request: filters the images of a training set by tag, sorts the
 * survivors using the query image as the comparator, and publishes the first {@code maxHits}
 * of them as search results.
 *
 * <p>{@code messageArr} is read positionally starting at index 1 (index 0 is not read here):
 * {@code searchId, trainId, imageId, metric, maxHits, requireMoa}, then an inclusion-tag count
 * followed by that many integer tags, then an exclusion-tag count followed by that many integer
 * tags. {@code requireMoa} is enabled only by the exact string {@code "true"}.
 *
 * <p>Filtering is applied in two passes:
 * <ol>
 *   <li>Inclusion: with no inclusion tags every image is kept; otherwise only images carrying at
 *       least one inclusion tag are kept.</li>
 *   <li>Exclusion: images carrying any exclusion tag are dropped; when {@code requireMoa} is set,
 *       images without a tag whose label starts with {@code "moa:"} (including images with no
 *       tags at all) are dropped as well.</li>
 * </ol>
 *
 * <p>Side effects: sets the {@code distanceType} field from {@code metric} ("Euclidean",
 * "Cosine", otherwise {@code DEFAULT_METRIC}). If {@code imageId} is not in the training set,
 * calls {@code updateSearchStatus(searchId, "error")} and returns. Otherwise replaces the
 * {@code hitArray} field and calls {@code createSearchResults}; fewer than {@code maxHits} hits
 * are published when fewer images survive filtering.
 *
 * @param messageArr tokenized request message (layout described above)
 */
private void searchByImageId(String[] messageArr) {
    int i = 1;
    String searchId = messageArr[i++];
    String trainId = messageArr[i++];
    String imageId = messageArr[i++];
    String metric = messageArr[i++];
    int maxHits = Integer.parseInt(messageArr[i++]);
    boolean requireMoa = messageArr[i++].equals("true");
    int inclusionTagCount = Integer.parseInt(messageArr[i++]);
    Set<Integer> inclusionTags = new HashSet<>();
    for (int j = 0; j < inclusionTagCount; j++) {
        inclusionTags.add(Integer.parseInt(messageArr[i++]));
    }
    int exclusionTagCount = Integer.parseInt(messageArr[i++]);
    Set<Integer> exclusionTags = new HashSet<>();
    for (int j = 0; j < exclusionTagCount; j++) {
        exclusionTags.add(Integer.parseInt(messageArr[i++]));
    }
    System.out.println("trainId=" + trainId + ", imageId=" + imageId
            + ", metric=" + metric + ", maxHits=" + maxHits);
    distanceType = DEFAULT_METRIC;
    if (metric.equals("Euclidean")) {
        distanceType = EUCLIDEAN_TYPE;
    } else if (metric.equals("Cosine")) {
        distanceType = COSINE_TYPE;
    }
    logTagSet("inclusionTags", inclusionTags);
    logTagSet("exclusionTags", exclusionTags);

    long timestamp1 = System.currentTimeMillis();
    Map<String, ImageEmbedding> trainImageMap = getImageMap(trainId);
    System.out.println("trainImageMap size=" + trainImageMap.size());
    // Resolve the query before filtering so an unknown imageId fails fast instead of after a
    // full pass over the training set.
    ImageEmbedding queryImage = trainImageMap.get(imageId);
    if (queryImage == null) {
        System.out.println("queryImage is null");
        updateSearchStatus(searchId, "error");
        return;
    }

    // Inclusion pass first, then exclusion pass; each pass is a no-op when it has nothing to do.
    // The query image itself is not removed from the candidates, so it can appear among the
    // hits if it passes the filters.
    List<ImageEmbedding> filteredImages = trainImageMap.values().stream()
            .filter(e -> inclusionTags.isEmpty() || matchesAnyTag(e, inclusionTags))
            .collect(Collectors.toList());
    System.out.println("Post inclusion pass, filteredImages size=" + filteredImages.size());
    if (!exclusionTags.isEmpty() || requireMoa) {
        filteredImages = filteredImages.stream()
                .filter(e -> passesExclusionFilter(e, exclusionTags, requireMoa))
                .collect(Collectors.toList());
    }
    System.out.println("filteredImages size=" + filteredImages.size());

    ImageEmbedding[] arr = filteredImages.toArray(new ImageEmbedding[0]);
    System.out.println("pre arr length=" + arr.length);
    long timestamp2 = System.currentTimeMillis();
    // The query ImageEmbedding is passed as the Comparator; its ordering presumably depends on
    // the distanceType field set above (confirm in ImageEmbedding.compare). arr[0] is treated
    // as the closest match.
    Arrays.parallelSort(arr, queryImage);
    long timestamp3 = System.currentTimeMillis();
    System.out.println("post arr length=" + arr.length);
    if (arr.length >= 2) {
        System.out.println("Closest two matches are " + arr[0].imageId + ", " + arr[1].imageId);
    }
    System.out.println("Create array ms=" + (timestamp2 - timestamp1));
    System.out.println("Sort ms=" + (timestamp3 - timestamp2));

    // Never copy past the end of arr (fewer survivors than maxHits) and tolerate a negative
    // maxHits. The hitArray field is reused when its length already matches.
    int hitCount = Math.max(0, Math.min(maxHits, arr.length));
    if (hitArray == null || hitArray.length != hitCount) {
        hitArray = new ImageEmbedding[hitCount];
    }
    System.arraycopy(arr, 0, hitArray, 0, hitCount);
    createSearchResults(searchId, queryImage, hitArray);
    long timestamp4 = System.currentTimeMillis();
    System.out.println("Result message ms=" + (timestamp4 - timestamp3));
}

/**
 * Inclusion-pass predicate: returns true when {@code image} has at least one tag in
 * {@code tags} according to {@code tagMap}. Images with no {@code tagMap} entry never match.
 */
private boolean matchesAnyTag(ImageEmbedding image, Set<Integer> tags) {
    int[] imageTags = tagMap.get(image.imageId);
    if (imageTags == null) {
        return false;
    }
    for (Integer tag : imageTags) {
        if (tags.contains(tag)) {
            return true;
        }
    }
    return false;
}

/**
 * Exclusion-pass predicate: rejects {@code image} if it has any tag in {@code exclusionTags};
 * when {@code requireMoa} is set, also rejects it unless one of its tags has a label (looked up
 * in {@code tagLabelMap}) starting with {@code "moa:"}. Images with no {@code tagMap} entry pass
 * only when {@code requireMoa} is false.
 */
private boolean passesExclusionFilter(ImageEmbedding image, Set<Integer> exclusionTags,
        boolean requireMoa) {
    int[] imageTags = tagMap.get(image.imageId);
    if (imageTags == null) {
        // Untagged: nothing to exclude, but also no MoA label, so it must not be dereferenced.
        return !requireMoa;
    }
    boolean hasMoaLabel = false;
    for (Integer tag : imageTags) {
        if (exclusionTags.contains(tag)) {
            return false;
        }
        String label = tagLabelMap.get(tag);
        // A tag with no label entry cannot be an MoA tag; skip it rather than throwing.
        if (label != null && label.startsWith("moa:")) {
            hasMoaLabel = true;
        }
    }
    return hasMoaLabel || !requireMoa;
}

/**
 * Prints {@code <name>=none} for an empty tag set, otherwise {@code <name>=} followed by one tag
 * per line.
 */
private void logTagSet(String name, Set<Integer> tags) {
    if (tags.isEmpty()) {
        System.out.println(name + "=none");
        return;
    }
    System.out.println(name + "=");
    for (Integer tag : tags) {
        System.out.println(tag);
    }
}