def process_file()

in videocategorization/tgi_inference_client.py [0:0]


# Metadata fields stripped from every stored prompt before it is sent to the
# model, applied in this order. Each pattern also removes one trailing newline.
# NOTE(review): `.*?` stops at the first `]` on the same line, so a field whose
# value itself contains `]` is only partially removed.
_METADATA_FIELD_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r"Categories: \[.*?\]\n?",
        r"Tags: \[.*?\]\n?",
        r"Description: \[.*?\]\n?",
    )
)

# Completion recorded for a task whose request failed. Kept verbatim so any
# consumer that matches on this marker keeps working.
_ERROR_COMPLETION = "Error: Unable to get response"

_REQUEST_HEADERS = {"Content-Type": "application/json"}


def _build_user_prompt(raw_prompt):
    """Rewrite a stored prompt into the taxonomy-classification prompt for the model."""
    text = raw_prompt.replace("Given those categories:", "Given this taxonomy:")
    for pattern in _METADATA_FIELD_PATTERNS:
        text = pattern.sub("", text)
    return text + "RETURN A CATEGORY FROM THE TAXONOMY PROVIDED: "


def _extract_generated_text(payload):
    """Return the ``generated_text`` field from a decoded endpoint response.

    Depending on the route, TGI answers with either a single
    ``{"generated_text": ...}`` object or a list of such objects. Both shapes
    are accepted here. Anything else yields an empty string.
    """
    if isinstance(payload, list):
        payload = payload[0] if payload else {}
    if isinstance(payload, dict):
        return payload.get("generated_text", "")
    return ""


def _query_endpoint(session, endpoint_url, prompt, max_new_tokens, timeout):
    """POST one prompt and return its completion, or ``_ERROR_COMPLETION`` on failure.

    Failures (network errors, timeouts, non-200 status, non-JSON body) are
    logged and turned into the error marker. One bad task therefore does not
    abort the file and discard the completions already collected.
    """
    import logging

    logger = logging.getLogger(__name__)
    payload = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens},
    }
    try:
        response = session.post(
            endpoint_url, headers=_REQUEST_HEADERS, json=payload, timeout=timeout
        )
    except requests.RequestException as exc:
        logger.warning("Request to %s failed: %s", endpoint_url, exc)
        return _ERROR_COMPLETION

    if response.status_code != 200:
        logger.warning(
            "Endpoint %s returned HTTP %s: %.200s",
            endpoint_url, response.status_code, response.text,
        )
        return _ERROR_COMPLETION

    try:
        body = response.json()
    except ValueError as exc:
        logger.warning("Endpoint %s returned a non-JSON body: %s", endpoint_url, exc)
        return _ERROR_COMPLETION
    return _extract_generated_text(body)


def process_file(file_path, tokenizer, endpoint_url, *, max_new_tokens=20,
                 timeout=120, output_dir="processed"):
    """Classify every task in *file_path* through a text-generation endpoint and save the results.

    Each task's prompt is rewritten by ``_build_user_prompt`` and wrapped in
    the tokenizer's chat template. It is then POSTed to *endpoint_url* as
    ``{"inputs": ..., "parameters": {"max_new_tokens": ...}}``.

    The results are written as a JSON list of ``{"video_id", "completion"}``
    objects to ``<output_dir>/<file stem>_results.json``. Tasks whose request
    failed carry ``"Error: Unable to get response"`` as their completion.

    Args:
        file_path: Prompt file, read with ``load_prompts_from_file``. Each task
            must provide ``'video_id'`` and ``'prompt'`` keys.
        tokenizer: Object exposing ``apply_chat_template`` (presumably a
            Hugging Face tokenizer).
        endpoint_url: URL the generation request is POSTed to.
        max_new_tokens: Generation budget per task.
        timeout: Seconds to wait for each HTTP request. ``None`` waits forever.
        output_dir: Directory for the results file; created if missing.

    Returns:
        The path of the written results file.
    """
    tasks = load_prompts_from_file(file_path)
    results = []

    # Reusing one session keeps the HTTP connection alive across tasks.
    with requests.Session() as session:
        for task in tqdm(tasks, desc="Processing tasks"):
            # tokenize=False makes apply_chat_template return the rendered
            # prompt string, which is what the endpoint expects as "inputs".
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": _build_user_prompt(task['prompt'])}],
                tokenize=False,
                add_generation_prompt=True,
            )
            completion = _query_endpoint(
                session, endpoint_url, prompt, max_new_tokens, timeout
            )
            results.append({"video_id": task['video_id'], "completion": completion})

    # Create the directory up front of the write, so a missing directory cannot
    # discard a whole file's worth of completed requests.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    stem = os.path.splitext(os.path.basename(file_path))[0]
    output_path = os.path.join(output_dir, f"{stem}_results.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=4)
    return output_path