in libraries/python/classification_compare.py [0:0]
def compare(self):
    """Score benchmark classification output against golden labels.

    Loads the benchmark output via ``self.getData(self.args.benchmark_output)``,
    which returns the score data and a list of dimensions. Also reads the
    golden labels file ``self.args.labels``, whose non-blank lines have the
    form ``index,label,path``. For each golden entry, a ``"predict"`` key is
    set to 1 when the golden class index is among the ``self.args.top``
    highest-scoring class indices of the corresponding benchmark row, and to
    0 otherwise. The annotated golden records are then passed to
    ``self.writeResult``.

    Raises:
        ValueError: if the number of benchmark rows differs from the number
            of golden entries and ``dims_list`` offers no dimension (followed
            by another one) equal to the golden entry count, or if the data
            cannot be reshaped to that shape.
    """
    benchmark_data, dims_list = self.getData(self.args.benchmark_output)
    with open(self.args.labels, "r") as f:
        content = f.read()

    golden_data = []
    for line in content.split("\n"):
        line = line.strip()
        # Tolerate blank lines anywhere in the file instead of failing on
        # int("") for an empty record.
        if not line:
            continue
        fields = line.split(",")
        golden_data.append(
            {"index": int(fields[0]), "label": fields[1], "path": fields[2]}
        )

    num_entries = len(golden_data)
    if len(benchmark_data) != num_entries:
        # The scores do not come one row per golden entry (e.g. they are
        # flattened). Locate the dimension equal to the golden entry count
        # and use the dimension right after it as the per-entry score width.
        # NOTE(review): assumes dims_list lists the output dims in order with
        # the class dimension directly after the batch dimension - confirm.
        dims = list(dims_list)
        if num_entries not in dims[:-1]:
            raise ValueError(
                "Benchmark data has {} entries, but golden data has {} "
                "entries and no usable dimension in {} matches it".format(
                    len(benchmark_data), num_entries, dims
                )
            )
        idx = dims.index(num_entries)
        # After this reshape the first dimension equals num_entries, so the
        # row counts are guaranteed to match below.
        benchmark_data = np.reshape(benchmark_data, (num_entries, dims[idx + 1]))

    top_k = self.args.top
    for scores, golden in zip(benchmark_data, golden_data):
        # Stable descending sort of class indices by score: equal scores keep
        # ascending index order, exactly like sorting (index, value) records.
        ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
        golden["predict"] = 1 if golden["index"] in ranked[:top_k] else 0
    self.writeResult(golden_data)