# videocategorization/tgi_inference_client.py
def main():
    """Send one quarter of the prompt files to a TGI /generate endpoint.

    Command line: ``python script_name.py <server_address> <port> <block_number>``

    The JSON files under ``prompts/`` are sorted and split into 4 contiguous
    blocks; only the block selected by ``block_number`` (0-3) is processed,
    so four independent processes can cover the whole directory.

    Exits with status 1 on bad arguments.
    """
    # Require exactly three user arguments: host, port, block index.
    if len(sys.argv) != 4:
        print("Usage: python script_name.py <server_address> <port> <block_number>")
        sys.exit(1)

    server_address = sys.argv[1]
    port = sys.argv[2]

    # Guard the int conversion: a non-numeric argument previously crashed
    # with a raw ValueError traceback instead of a usage-style error.
    try:
        block_number = int(sys.argv[3])
    except ValueError:
        print("Error: block_number must be between 0 and 3.")
        sys.exit(1)

    # Validate block number (4 blocks total, indexed 0-3).
    if not 0 <= block_number <= 3:
        print("Error: block_number must be between 0 and 3.")
        sys.exit(1)

    # TGI's generate endpoint on the given host/port.
    endpoint_url = f"http://{server_address}:{port}/generate"

    # Tokenizer matching the model served by TGI; presumably used by
    # process_file to build/trim prompts — defined elsewhere in this file.
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-70B-Instruct")

    # Sort for a deterministic order so every worker computes the same
    # partition regardless of filesystem listing order.
    files = sorted(f for f in os.listdir("prompts") if f.endswith(".json"))

    # Divide into 4 blocks of ceil(n/4) files; the last block may be short
    # (or empty), which the min() clamp and slice below handle safely.
    block_size = ceil(len(files) / 4)
    start_index = block_number * block_size
    end_index = min(start_index + block_size, len(files))

    # Process only the files in the current block.
    for file_name in files[start_index:end_index]:
        process_file(os.path.join("prompts", file_name), tokenizer, endpoint_url)