in bench/generation/evaluate_configurations.py [0:0]
# Accepted --dtype spellings and the torch dtype each one selects.
_DTYPE_CHOICES = {
    "fp16": torch.float16,
    "float16": torch.float16,
    "bf16": torch.bfloat16,
    "bfloat16": torch.bfloat16,
}

# Per-metric (title suffix, y-axis label) used when rendering the PNG chart.
# The keys double as the allowed values of --metric so the two cannot drift apart.
_METRIC_PLOT_TEXT = {
    "latency": ("Mean latency per token", "Latency (ms)"),
    "prediction": ("Prediction accuracy on Lambada dataset", "Accuracy"),
    "perplexity": ("Perplexity evaluated on WikiText dataset", "Perplexity"),
}


def _select_device(device_name):
    """Return ``torch.device(device_name)``, or auto-detect a device when it is None.

    Auto-detection tries CUDA, then MPS, then XPU, and falls back to CPU.
    """
    if device_name is not None:
        return torch.device(device_name)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    # torch.xpu is absent from older PyTorch builds; look it up defensively so
    # those installs fall back to CPU instead of raising AttributeError.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


def _resolve_dtype(model_id, dtype_name):
    """Return the torch dtype used to load ``model_id``.

    If ``dtype_name`` is given it must be a key of ``_DTYPE_CHOICES`` (argparse
    enforces this). Otherwise the ``torch_dtype`` recorded in the model config
    is used, falling back to float16 when it is missing or None.
    """
    if dtype_name is not None:
        return _DTYPE_CHOICES[dtype_name]
    config = AutoConfig.from_pretrained(model_id)
    # The attribute can exist yet be None; getattr's default alone would not
    # cover that case, so test for None explicitly.
    config_dtype = getattr(config, "torch_dtype", None)
    return torch.float16 if config_dtype is None else config_dtype


def main():
    """Command-line entry point: evaluate model configurations and report results.

    Parses CLI arguments, seeds torch, picks a device and dtype, runs
    ``evaluate_model_configurations`` for the chosen metric, and optionally
    dumps the results to ``<model>-<metric>.json`` and/or renders a bar chart
    via ``gen_barchart``.
    """
    parser = argparse.ArgumentParser(description="Evaluate quantized model predictions on Lambada Dataset")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/opt-350m",
        help="The name of the trained Model.",
    )
    parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
    parser.add_argument("--metric", type=str, default="prediction", choices=list(_METRIC_PLOT_TEXT))
    parser.add_argument("--batch_size", type=int, default=32, help="The batch size during evaluation.")
    parser.add_argument(
        "--dtype",
        type=str,
        choices=list(_DTYPE_CHOICES),
        help="Use the following dtype to load the model (defaults to the dtype in the model config).",
    )
    parser.add_argument("--json", action="store_true", help="Dump the results to a json file.")
    parser.add_argument("--png", action="store_true", help="Generate a PNG.")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    device = _select_device(args.device)
    dtype = _resolve_dtype(args.model, args.dtype)

    results = evaluate_model_configurations(args.model, args.metric, device, batch_size=args.batch_size, dtype=dtype)

    if args.json:
        # Keep only the last path component of a hub id such as "facebook/opt-350m".
        model_name = args.model.split("/")[-1]
        json_path = f"{model_name}-{args.metric}.json"
        with open(json_path, "w", encoding="utf-8") as fp:
            json.dump({model_name: results}, fp, indent=4)

    if args.png:
        title_suffix, label = _METRIC_PLOT_TEXT[args.metric]
        title = f"{args.model}: {title_suffix}"
        gen_barchart(args.model, title, label, results, dtype)