videocategorization/tgi_inference_client.py:

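# Input format (a sketch; only video_id and prompt are actually read by this
# script, any other fields in the task objects are ignored):
# each file in prompts/ is expected to contain a JSON list of tasks, e.g.
#
#     [
#         {"video_id": "abc123", "prompt": "Given those categories: ..."},
#         ...
#     ]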
import json
import os
import re
import sys
from math import ceil

import requests
from tqdm import tqdm
from transformers import AutoTokenizer

#
# This script runs the prepared prompts against one or more TGI services.
# The prompts are stored in chunks in a folder called prompts/.
# The script is called with 3 parameters:
#     python tgi_inference_client.py <server_address> <port> <block_number>
# block_number is a number between 0 and 3 (both included). Those blocks are 4
# subdivisions of the prompt files in prompts/; by specifying the block number
# we run inference on each block separately, which allows us to parallelize
# inference (see the example invocation at the end of this file).
#

# Ensure the output directory exists
os.makedirs("processed", exist_ok=True)


# Load the list of prompt tasks from a single JSON file
def load_prompts_from_file(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        tasks = json.load(file)
    return tasks


# Process a single file's tasks against the TGI endpoint and save the results
def process_file(file_path, tokenizer, endpoint_url):
    # Load tasks from the current file
    tasks = load_prompts_from_file(file_path)
    results = []

    # Headers for the HTTP request
    headers = {
        "Content-Type": "application/json",
    }

    # Process each task
    for task in tqdm(tasks, desc="Processing tasks"):
        video_id = task['video_id']
        input_text = task['prompt']

        # Rephrase the instruction and strip the metadata lines we do not
        # want the model to see
        input_text = input_text.replace("Given those categories:", "Given this taxonomy:")
        pattern = r"Categories: \[.*?\]\n?"
        input_text = re.sub(pattern, '', input_text)
        pattern = r"Tags: \[.*?\]\n?"
        input_text = re.sub(pattern, '', input_text)
        pattern = r"Description: \[.*?\]\n?"
        input_text = re.sub(pattern, '', input_text)
        input_text = input_text + "RETURN A CATEGORY FROM THE TAXONOMY PROVIDED: "

        # Wrap the prompt in the model's chat template (returned as a string,
        # not token IDs, since tokenize=False)
        prompt_tokens = tokenizer.apply_chat_template(
            [
                {"role": "user", "content": input_text},
            ],
            tokenize=False,
            add_generation_prompt=True
        )

        # Prepare the data for the request
        data = {
            "inputs": prompt_tokens,
            "parameters": {
                "max_new_tokens": 20,  # Adjust as needed
            },
        }

        # Make a synchronous request to the model endpoint
        response = requests.post(endpoint_url, headers=headers, json=data)

        if response.status_code == 200:
            response_data = response.json()
            completion = response_data.get('generated_text', '')
        else:
            completion = "Error: Unable to get response"

        # Append the result
        results.append({"video_id": video_id, "completion": completion})

    # Save results to file after processing all tasks in the file
    output_filename = os.path.splitext(os.path.basename(file_path))[0]
    with open(f"processed/{output_filename}_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=4)


# Main function to process a subset of files
def main():
    # Get server address, port, and block number from command-line arguments
    if len(sys.argv) != 4:
        print("Usage: python tgi_inference_client.py <server_address> <port> <block_number>")
        sys.exit(1)

    server_address = sys.argv[1]
    port = sys.argv[2]
    block_number = int(sys.argv[3])

    # Validate block number
    if block_number < 0 or block_number > 3:
        print("Error: block_number must be between 0 and 3.")
        sys.exit(1)

    # Construct endpoint URL
    endpoint_url = f"http://{server_address}:{port}/generate"

    # Initialize tokenizer
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-70B-Instruct")

    # List all JSON files in the prompts directory
    files = [f for f in os.listdir("prompts") if f.endswith(".json")]

    # Sort files to ensure consistent partitioning
    files.sort()

    # Divide files into 4 blocks
    total_files = len(files)
    block_size = ceil(total_files / 4)

    # Determine start and end indices for the current block
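    # E.g. with 10 prompt files, block_size = ceil(10 / 4) = 3, so block 0
    # covers files[0:3], block 1 files[3:6], block 2 files[6:9], and block 3
    # files[9:10]; min() below clamps the last block to the file count.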
    start_index = block_number * block_size
    end_index = min(start_index + block_size, total_files)

    # Process only the files in the current block
    for i in range(start_index, end_index):
        file_path = os.path.join("prompts", files[i])
        process_file(file_path, tokenizer, endpoint_url)


# Run the main function
if __name__ == "__main__":
    main()
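# Example invocation (a sketch; localhost and 8080 are placeholder values for
# wherever the TGI service is listening):
#
#     python tgi_inference_client.py localhost 8080 0
#
# Launching four such processes, one per block number 0-3 (each optionally
# pointed at a different TGI server), covers the whole prompts/ folder in
# parallel. Results land in processed/<prompt_file>_results.json, one output
# file per input chunk.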