# main() — excerpt from videocategorization/tgi_inference_client.py


def main():
    """Partition the JSON prompt files into 4 blocks and process one block.

    Usage: python script_name.py <server_address> <port> <block_number>

    Reads three positional CLI arguments, builds the TGI ``/generate``
    endpoint URL, loads the Llama 3.1 tokenizer, and runs
    ``process_file`` over the slice of ``prompts/*.json`` files that
    belongs to ``block_number`` (0-3). Exits with status 1 on bad
    arguments.
    """
    # Get server address, port, and block number from command-line arguments
    if len(sys.argv) != 4:
        print("Usage: python script_name.py <server_address> <port> <block_number>")
        sys.exit(1)

    server_address = sys.argv[1]
    port = sys.argv[2]

    # Fail with a clear message on a non-numeric block number instead of
    # an unhandled ValueError traceback.
    try:
        block_number = int(sys.argv[3])
    except ValueError:
        print("Error: block_number must be an integer between 0 and 3.")
        sys.exit(1)

    # Validate block number
    if block_number < 0 or block_number > 3:
        print("Error: block_number must be between 0 and 3.")
        sys.exit(1)

    # Construct endpoint URL
    endpoint_url = f"http://{server_address}:{port}/generate"

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-70B-Instruct")

    # Sorted listing so every worker sees the same file ordering and the
    # block partition is deterministic across runs.
    files = sorted(f for f in os.listdir("prompts") if f.endswith(".json"))

    # Divide files into 4 blocks; the last block may be short, so clamp
    # the end index to the total file count.
    total_files = len(files)
    block_size = ceil(total_files / 4)
    start_index = block_number * block_size
    end_index = min(start_index + block_size, total_files)

    # Process only the files in the current block
    for i in range(start_index, end_index):
        file_path = os.path.join("prompts", files[i])
        process_file(file_path, tokenizer, endpoint_url)