def main()

in bench/generation/evaluate_configurations.py [0:0]


def _default_device():
    """Return the first available accelerator (CUDA, then MPS, then XPU), else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    # torch.xpu is missing on some PyTorch builds; look it up defensively so those
    # builds fall back to CPU instead of raising AttributeError.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu")
    return torch.device("cpu")


def _resolve_dtype(model_id, dtype_arg):
    """Return the torch dtype to load ``model_id`` with.

    Args:
        model_id: Model name/path passed to ``AutoConfig.from_pretrained``.
        dtype_arg: ``"fp16"``, ``"bf16"`` or ``None`` (use the model config's dtype).

    Returns:
        A torch dtype; ``torch.float16`` when the config does not specify one.
    """
    if dtype_arg is not None:
        return torch.float16 if dtype_arg == "fp16" else torch.bfloat16
    config = AutoConfig.from_pretrained(model_id)
    # The attribute may exist but be None, so getattr's default alone is not a
    # sufficient fallback -- check for None explicitly.
    config_dtype = getattr(config, "torch_dtype", None)
    return torch.float16 if config_dtype is None else config_dtype


def main():
    """Evaluate quantized configurations of a model from the command line.

    Parses CLI arguments, picks a device and dtype, runs
    ``evaluate_model_configurations`` for the chosen metric, and optionally
    writes the results to ``<model>-<metric>.json`` and/or renders a bar chart
    via ``gen_barchart``.
    """
    parser = argparse.ArgumentParser(description="Evaluate quantized model predictions on Lambada Dataset")
    parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)")
    parser.add_argument(
        "--model",
        type=str,
        default="facebook/opt-350m",
        help="The name of the trained Model.",
    )
    parser.add_argument("--device", type=str, default=None, help="The device to use for generation.")
    parser.add_argument("--metric", type=str, default="prediction", choices=["latency", "prediction", "perplexity"])
    parser.add_argument("--batch_size", type=int, default=32, help="The batch size during evaluation.")
    # Restrict to known values: previously any string other than "fp16" silently
    # selected bfloat16.
    parser.add_argument("--dtype", type=str, choices=["fp16", "bf16"], help="Use the following dtype to load the model.")
    parser.add_argument("--json", action="store_true", help="Dump the results to a json file.")
    parser.add_argument("--png", action="store_true", help="Generate a PNG.")
    args = parser.parse_args()

    torch.manual_seed(args.seed)

    device = _default_device() if args.device is None else torch.device(args.device)
    dtype = _resolve_dtype(args.model, args.dtype)

    results = evaluate_model_configurations(args.model, args.metric, device, batch_size=args.batch_size, dtype=dtype)

    if args.json:
        # Strip any organisation prefix, e.g. "facebook/opt-350m" -> "opt-350m".
        model_name = args.model.split("/")[-1]
        json_path = f"{model_name}-{args.metric}.json"
        with open(json_path, "w", encoding="utf-8") as fp:
            json.dump({model_name: results}, fp, indent=4)

    if args.png:
        # One (title, y-axis label) pair per metric; keys mirror --metric choices.
        chart_text = {
            "latency": (f"{args.model}: Mean latency per token", "Latency (ms)"),
            "prediction": (f"{args.model}: Prediction accuracy on Lambada dataset", "Accuracy"),
            "perplexity": (f"{args.model}: Perplexity evaluated on WikiText dataset", "Perplexity"),
        }
        title, label = chart_text[args.metric]
        gen_barchart(args.model, title, label, results, dtype)