fn main()

in benchmark/src/main.rs [108:208]


/// Entry point of the benchmark binary.
///
/// Steps, in order:
/// 1. initialise logging and parse the command-line arguments;
/// 2. load the tokenizer, either from a local directory that contains a
///    `tokenizer.json` file or through `Tokenizer::from_pretrained`;
/// 3. start a multi-threaded Tokio runtime, connect to the model server through the
///    master shard unix socket and clear its cache;
/// 4. hand everything over to `text_generation_benchmark::run`.
///
/// # Errors
///
/// Returns an error if the Tokio runtime cannot be built.
///
/// # Panics
///
/// Panics if the tokenizer cannot be loaded, if connecting to the model server or
/// clearing its cache fails, or if the benchmark run itself returns an error.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    init_logging();

    // Get args
    let args = Args::parse();
    // Pattern match configuration
    let Args {
        tokenizer_name,
        revision,
        batch_size,
        sequence_length,
        decode_length,
        runs,
        warmups,
        temperature,
        top_k,
        top_p,
        typical_p,
        repetition_penalty,
        frequency_penalty,
        watermark,
        do_sample,
        master_shard_uds_path,
        top_n_tokens,
    } = args;

    // Only build the default batch sizes when none were provided on the command line
    let batch_size = batch_size.unwrap_or_else(|| vec![1, 2, 4, 8, 16, 32]);

    // Tokenizer instance
    // This will only be used to validate payloads
    tracing::info!("Loading tokenizer");
    let local_path = Path::new(&tokenizer_name);
    let local_tokenizer_file = local_path.join("tokenizer.json");
    // `is_dir` is already false for a missing path, so no separate `exists` check is needed
    let tokenizer = if local_path.is_dir() && local_tokenizer_file.exists() {
        // Load local tokenizer
        tracing::info!("Found local tokenizer");
        Tokenizer::from_file(local_tokenizer_file).expect("Unable to load local tokenizer.json")
    } else {
        tracing::info!("Downloading tokenizer");

        // Parse Huggingface hub token
        let token = std::env::var("HF_TOKEN")
            .or_else(|_| std::env::var("HUGGING_FACE_HUB_TOKEN"))
            .ok();

        // Download and instantiate tokenizer
        // We need to download it outside of the Tokio runtime
        let params = FromPretrainedParameters {
            revision,
            token,
            ..Default::default()
        };
        Tokenizer::from_pretrained(tokenizer_name.clone(), Some(params))
            .expect("Unable to download tokenizer")
    };
    tracing::info!("Tokenizer loaded");

    // Launch Tokio runtime; a build failure is reported through the returned `Result`
    let runtime = tokio::runtime::Builder::new_multi_thread()
        .enable_all()
        .build()?;

    runtime.block_on(async {
        // Instantiate sharded client from the master unix socket
        tracing::info!("Connect to model server");
        let mut sharded_client = ShardedClient::connect_uds(master_shard_uds_path)
            .await
            .expect("Could not connect to server");
        // Clear the cache; useful if the webserver rebooted
        sharded_client
            .clear_cache(None)
            .await
            .expect("Unable to clear cache");

        tracing::info!("Connected");

        // Run app
        text_generation_benchmark::run(
            tokenizer_name,
            tokenizer,
            batch_size,
            sequence_length,
            decode_length,
            top_n_tokens,
            runs,
            warmups,
            temperature,
            top_k,
            top_p,
            typical_p,
            repetition_penalty,
            frequency_penalty,
            watermark,
            do_sample,
            sharded_client,
        )
        .await
        .expect("Benchmark run failed");
    });
    Ok(())
}