# videocategorization/tgi_inference_client.py
import json
import os
import re

import requests
from tqdm import tqdm

# NOTE (assumption): load_prompts_from_file is defined elsewhere in this module.

def process_file(file_path, tokenizer, endpoint_url):
# Load tasks from the current file
tasks = load_prompts_from_file(file_path)
results = []
# Headers for the HTTP request
headers = {
"Content-Type": "application/json",
}
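    # NOTE (assumption): a locally served TGI endpoint needs no auth header;
    # a hosted endpoint (e.g. Hugging Face Inference Endpoints) would also
    # need an "Authorization": "Bearer <token>" entry above.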
# Process each task
for task in tqdm(tasks, desc="Processing tasks"):
        video_id = task['video_id']
        input_text = task['prompt']
        # Rephrase the instruction and strip the bracketed metadata fields
        # (Categories/Tags/Description) from the prompt before querying the model.
        input_text = input_text.replace("Given those categories:", "Given this taxonomy:")
        for field in ("Categories", "Tags", "Description"):
            input_text = re.sub(rf"{field}: \[.*?\]\n?", "", input_text)
        input_text += "RETURN A CATEGORY FROM THE TAXONOMY PROVIDED: "
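        # apply_chat_template (from transformers) renders the message with the
        # model's chat template; tokenize=False returns the prompt as a string
        # (TGI tokenizes server-side) and add_generation_prompt=True appends
        # the tokens that cue the model to start its reply.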
prompt_tokens = tokenizer.apply_chat_template(
[
{"role": "user", "content": input_text},
],
tokenize=False,
add_generation_prompt=True
)
# Prepare the data for the request
data = {
"inputs": prompt_tokens,
"parameters": {
"max_new_tokens": 20, # Adjust as needed
},
}
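        # This payload shape matches TGI's /generate API; sampling controls
        # such as "temperature" or "stop" could be added under "parameters"
        # (assumption: server defaults are acceptable here).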
# Make a synchronous request to the model endpoint
response = requests.post(endpoint_url, headers=headers, json=data)
if response.status_code == 200:
response_data = response.json()
completion = response_data.get('generated_text', '')
else:
completion = "Error: Unable to get response"
# Append the result
results.append({"video_id": video_id, "completion": completion})
    # Save results after processing all tasks in the file, creating the
    # output directory if it does not already exist.
    os.makedirs("processed", exist_ok=True)
    output_filename = os.path.splitext(os.path.basename(file_path))[0]
    with open(f"processed/{output_filename}_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
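

# --- Hedged usage sketch (not part of the original file) ---
# One way to drive process_file over a directory of prompt files against a
# locally served TGI endpoint. The model name, prompts directory, and
# endpoint URL below are illustrative assumptions, not repository values.
if __name__ == "__main__":
    import glob

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(
        "mistralai/Mistral-7B-Instruct-v0.2"  # assumed chat model
    )
    endpoint_url = "http://localhost:8080/generate"  # assumed TGI /generate URL
    for path in sorted(glob.glob("prompts/*.json")):  # assumed input layout
        process_file(path, tokenizer, endpoint_url)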