# git clone https://github.com/bigdata-pw/florence-tool.git; cd florence-tool; pip install -r requirements.txt; pip install -e .
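# Example launch (the IDs below are placeholders; point them at your own checkpoint and dataset):
# python accuracy.py --ckpt_id_or_path <hub-id-or-local-path> --dataset_id <dataset-id> --dataset_split train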
from florence_tool import FlorenceTool
import argparse
from datasets import load_dataset
import wandb
import tqdm
import json

from data import NONE_KEY_MAP


def compute_metrics(category):
    """Compute precision, recall, F1, and accuracy (as percentages) from per-sample match/extra/missing tallies."""
    total_match = sum(category["match"].values())
    total_missing = sum(category["missing"].values())
    total_extra = sum(category["extra"].values())
    total = total_match + total_missing + total_extra

    precision = total_match / (total_match + total_extra) if (total_match + total_extra) > 0 else 0
    recall = total_match / (total_match + total_missing) if (total_match + total_missing) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    accuracy = (total_match / total) if total > 0 else 0
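    # For example, with total_match=8, total_extra=2, total_missing=4:
    # precision = 8/10 = 0.80, recall = 8/12 ≈ 0.667, f1_score ≈ 0.727, accuracy = 8/14 ≈ 0.571.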
    return {
        "precision": precision * 100,
        "recall": recall * 100,
        "f1_score": f1_score * 100,
        "accuracy": accuracy * 100,
    }


def accuracy(args):
    checkpoint = args.ckpt_id_or_path
    run_name = checkpoint.split("/")[-1]
    wandb.init(project="shot-categorizer", name=f"{run_name}-eval")

    # Load the Florence checkpoint in fp16 on the GPU via florence-tool.
    florence = FlorenceTool(
        checkpoint,
        device="cuda",
        dtype="float16",
        check_task_types=False,
    )
    florence.load_model()

    dataset = load_dataset(args.dataset_id, split=args.dataset_split)
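    # Per-category tallies, keyed by sample index:
    #   "match"   - labels present in both the ground truth and the prediction
    #   "extra"   - labels predicted but absent from the ground truth
    #   "missing" - ground-truth labels the model failed to predict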
    color = {"match": {}, "extra": {}, "missing": {}}
    lighting = {"match": {}, "extra": {}, "missing": {}}
    lighting_type = {"match": {}, "extra": {}, "missing": {}}
    composition = {"match": {}, "extra": {}, "missing": {}}

    count = 0
    for i, still in tqdm.tqdm(enumerate(dataset)):
        image = still["image"].convert("RGB")
        # Empty ground-truth fields fall back to their "none" labels via NONE_KEY_MAP.
        original = {
            "<COLOR>": still["Color"] if still["Color"] else NONE_KEY_MAP["Color"],
            "<LIGHTING>": still["Lighting"] if still["Lighting"] else NONE_KEY_MAP["Lighting"],
            "<LIGHTING_TYPE>": still["Lighting Type"] if still["Lighting Type"] else NONE_KEY_MAP["Lighting Type"],
            "<COMPOSITION>": still["Composition"] if still["Composition"] else NONE_KEY_MAP["Composition"],
        }
        output = florence.run(
            image=image,
            task_prompt=["<COLOR>", "<LIGHTING>", "<LIGHTING_TYPE>", "<COMPOSITION>"],
        )
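        # Predictions come back as one comma-separated string per task; split them into label
        # lists so they can be compared set-wise against the ground truth, e.g. (illustrative
        # labels) "warm colors, saturated colors" -> ["warm colors", "saturated colors"].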
        output["<COLOR>"] = output["<COLOR>"].split(", ")
        output["<LIGHTING>"] = output["<LIGHTING>"].split(", ")
        output["<LIGHTING_TYPE>"] = output["<LIGHTING_TYPE>"].split(", ")
        output["<COMPOSITION>"] = output["<COMPOSITION>"].split(", ")
        matching_color = set(original["<COLOR>"]).intersection(set(output["<COLOR>"]))
        extra_color = set(output["<COLOR>"]) - set(original["<COLOR>"])
        missing_color = set(original["<COLOR>"]) - set(output["<COLOR>"])

        matching_lighting = set(original["<LIGHTING>"]).intersection(set(output["<LIGHTING>"]))
        extra_lighting = set(output["<LIGHTING>"]) - set(original["<LIGHTING>"])
        missing_lighting = set(original["<LIGHTING>"]) - set(output["<LIGHTING>"])

        matching_lighting_type = set(original["<LIGHTING_TYPE>"]).intersection(set(output["<LIGHTING_TYPE>"]))
        extra_lighting_type = set(output["<LIGHTING_TYPE>"]) - set(original["<LIGHTING_TYPE>"])
        missing_lighting_type = set(original["<LIGHTING_TYPE>"]) - set(output["<LIGHTING_TYPE>"])

        matching_composition = set(original["<COMPOSITION>"]).intersection(set(output["<COMPOSITION>"]))
        extra_composition = set(output["<COMPOSITION>"]) - set(original["<COMPOSITION>"])
        missing_composition = set(original["<COMPOSITION>"]) - set(output["<COMPOSITION>"])
        color["match"][i] = len(matching_color)
        color["extra"][i] = len(extra_color)
        color["missing"][i] = len(missing_color)
        lighting["match"][i] = len(matching_lighting)
        lighting["extra"][i] = len(extra_lighting)
        lighting["missing"][i] = len(missing_lighting)
        lighting_type["match"][i] = len(matching_lighting_type)
        lighting_type["extra"][i] = len(extra_lighting_type)
        lighting_type["missing"][i] = len(missing_lighting_type)
        composition["match"][i] = len(matching_composition)
        composition["extra"][i] = len(extra_composition)
        composition["missing"][i] = len(missing_composition)

        # Cap the evaluation at roughly the first 1,000 samples.
        count += 1
        if count > 1000:
            break
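    # Aggregate the per-sample tallies into per-category metrics.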
    color_metrics = compute_metrics(color)
    lighting_metrics = compute_metrics(lighting)
    lighting_type_metrics = compute_metrics(lighting_type)
    composition_metrics = compute_metrics(composition)
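    # The overall numbers are unweighted (macro) averages over the four categories.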
    overall_precision = (
        color_metrics["precision"]
        + lighting_metrics["precision"]
        + lighting_type_metrics["precision"]
        + composition_metrics["precision"]
    ) / 4
    overall_recall = (
        color_metrics["recall"]
        + lighting_metrics["recall"]
        + lighting_type_metrics["recall"]
        + composition_metrics["recall"]
    ) / 4
    overall_f1 = (
        color_metrics["f1_score"]
        + lighting_metrics["f1_score"]
        + lighting_type_metrics["f1_score"]
        + composition_metrics["f1_score"]
    ) / 4
    overall_accuracy = (
        color_metrics["accuracy"]
        + lighting_metrics["accuracy"]
        + lighting_type_metrics["accuracy"]
        + composition_metrics["accuracy"]
    ) / 4
    print(
        f"Color - Precision: {color_metrics['precision']:.2f}%, Recall: {color_metrics['recall']:.2f}%, F1-score: {color_metrics['f1_score']:.2f}%, Accuracy: {color_metrics['accuracy']:.2f}%"
    )
    print(
        f"Lighting - Precision: {lighting_metrics['precision']:.2f}%, Recall: {lighting_metrics['recall']:.2f}%, F1-score: {lighting_metrics['f1_score']:.2f}%, Accuracy: {lighting_metrics['accuracy']:.2f}%"
    )
    print(
        f"Lighting Type - Precision: {lighting_type_metrics['precision']:.2f}%, Recall: {lighting_type_metrics['recall']:.2f}%, F1-score: {lighting_type_metrics['f1_score']:.2f}%, Accuracy: {lighting_type_metrics['accuracy']:.2f}%"
    )
    print(
        f"Composition - Precision: {composition_metrics['precision']:.2f}%, Recall: {composition_metrics['recall']:.2f}%, F1-score: {composition_metrics['f1_score']:.2f}%, Accuracy: {composition_metrics['accuracy']:.2f}%"
    )
    print(
        f"Overall - Precision: {overall_precision:.2f}%, Recall: {overall_recall:.2f}%, F1-score: {overall_f1:.2f}%, Accuracy: {overall_accuracy:.2f}%"
    )
    # Persist the raw tallies and metrics to disk and log them to W&B.
    json_name = checkpoint.split("/")[-1]
    with open(f"{json_name}.json", "w") as f:
        results = {
            "color": color,
            "lighting": lighting,
            "lighting_type": lighting_type,
            "composition": composition,
            "color_metrics": color_metrics,
            "lighting_metrics": lighting_metrics,
            "lighting_type_metrics": lighting_type_metrics,
            "composition_metrics": composition_metrics,
            "overall_precision": overall_precision,
            "overall_recall": overall_recall,
            "overall_f1": overall_f1,
            "overall_accuracy": overall_accuracy,
        }
        json.dump(results, f)

    wandb.log(results)
    wandb.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt_id_or_path", type=str, required=True, help="Checkpoint ID from the HF Hub or local file path.")
    parser.add_argument("--dataset_id", type=str, required=True, help="Dataset ID to use.")
    parser.add_argument("--dataset_split", type=str, default="train", help="Dataset split to use.")
    args = parser.parse_args()
    accuracy(args)