in accuracy.py [0:0]
# Evaluated categories: (task prompt token, dataset column / display label, result key).
_CATEGORIES = (
    ("<COLOR>", "Color", "color"),
    ("<LIGHTING>", "Lighting", "lighting"),
    ("<LIGHTING_TYPE>", "Lighting Type", "lighting_type"),
    ("<COMPOSITION>", "Composition", "composition"),
)

_METRIC_NAMES = ("precision", "recall", "f1_score", "accuracy")


def _as_label_set(labels):
    """Return *labels* as a set of label strings.

    Model output is a ", "-joined string. Ground truth may be a list of labels or
    such a string (the dataset column type is not visible here -- TODO confirm).
    Calling ``set()`` directly on a string would produce a set of characters, so
    strings are split first.
    """
    if isinstance(labels, str):
        return set(labels.split(", "))
    return set(labels)


def _format_metrics(label, metrics):
    """Render one report line for *label* from a dict holding the four metric values."""
    return (
        f"{label} - Precision: {metrics['precision']:.2f}%, "
        f"Recall: {metrics['recall']:.2f}%, "
        f"F1-score: {metrics['f1_score']:.2f}%, "
        f"Accuracy-score: {metrics['accuracy']:.2f}%"
    )


def accuracy(args):
    """Evaluate a Florence checkpoint on the shot-categorization tags.

    For every evaluated sample, each category's predicted labels are compared
    with the ground-truth labels. Per-sample counts of matching, extra (predicted
    but not expected) and missing (expected but not predicted) labels are
    recorded. The counts are then reduced with ``compute_metrics``, printed,
    written to ``<run_name>.json`` in the working directory, and logged to a
    wandb run.

    Args:
        args: namespace providing ``ckpt_id_or_path``, ``dataset_id`` and
            ``dataset_split``. It may also provide ``max_samples``, the number of
            samples to evaluate. The default is 1000; ``None`` means the whole split.
    """
    checkpoint = args.ckpt_id_or_path
    # Strip a trailing slash so a local directory path still yields a usable name.
    run_name = checkpoint.rstrip("/").split("/")[-1]
    # Previously `count > 1000` let 1001 samples through; this is an exact cap.
    max_samples = getattr(args, "max_samples", 1000)

    wandb.init(project="shot-categorizer", name=f"{run_name}-eval")
    try:
        florence = FlorenceTool(
            checkpoint,
            device="cuda",
            dtype="float16",
            check_task_types=False,
        )
        florence.load_model()
        dataset = load_dataset(args.dataset_id, split=args.dataset_split)

        task_prompts = [task for task, _, _ in _CATEGORIES]
        # Per category: {"match"|"extra"|"missing": {sample_index: label_count}}.
        counts = {
            key: {"match": {}, "extra": {}, "missing": {}}
            for _, _, key in _CATEGORIES
        }

        for i, still in tqdm.tqdm(enumerate(dataset)):
            if max_samples is not None and i >= max_samples:
                break
            image = still["image"].convert("RGB")
            output = florence.run(image=image, task_prompt=task_prompts)
            for task, column, key in _CATEGORIES:
                # An empty ground-truth value falls back to the category's "none" label.
                raw_expected = still[column] if still[column] else NONE_KEY_MAP[column]
                expected = _as_label_set(raw_expected)
                predicted = _as_label_set(output[task])
                counts[key]["match"][i] = len(expected & predicted)
                counts[key]["extra"][i] = len(predicted - expected)
                counts[key]["missing"][i] = len(expected - predicted)

        metrics = {key: compute_metrics(counts[key]) for _, _, key in _CATEGORIES}
        # Unweighted mean over the categories.
        overall = {
            name: sum(m[name] for m in metrics.values()) / len(metrics)
            for name in _METRIC_NAMES
        }

        for _, column, key in _CATEGORIES:
            print(_format_metrics(column, metrics[key]))
        print(_format_metrics("Overall", overall))

        results = {key: counts[key] for _, _, key in _CATEGORIES}
        results.update({f"{key}_metrics": metrics[key] for _, _, key in _CATEGORIES})
        results.update(
            {
                "overall_precision": overall["precision"],
                "overall_recall": overall["recall"],
                "overall_f1": overall["f1_score"],
                "overall_accuracy": overall["accuracy"],
            }
        )

        # Note: json converts the int sample-index keys to strings.
        with open(f"{run_name}.json", "w") as f:
            json.dump(results, f)
        wandb.log(results)
    finally:
        # Always close the wandb run, even if evaluation or writing fails.
        wandb.finish()