llama_guard/llama-guard-4.ipynb (435 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "id": "a6034e1f-909f-4aaa-8f99-7abc2353036d", "metadata": {}, "source": [ "## Llama Guard 4 for Multimodal and LLM Safety\n", "\n", "Vision language models and large language models in production can be easily jailbroken for harmful purposes. Llama Guard 4 is a new model to check image and text inputs for harm. In this notebook, we will see how we can use Llama Guard 4. This model can be used for both a filter for image and text, and text-only inputs, and filtering outputs on image generation models.\n", "\n", "Let's make sure we have new transformers and hf_xet to load the model." ] }, { "cell_type": "code", "execution_count": 1, "id": "410950aa-3437-44ee-99ed-25bd7ce00cfc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting hf_xet\n", " Downloading hf_xet-1.0.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (494 bytes)\n", "Downloading hf_xet-1.0.5-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (54.0 MB)\n", "\u001b[2K \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m54.0/54.0 MB\u001b[0m \u001b[31m102.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0mm eta \u001b[36m0:00:01\u001b[0m[36m0:00:01\u001b[0m\n", "Installing collected packages: hf_xet\n", "Successfully installed hf_xet-1.0.5\n" ] } ], "source": [ "!pip install -U transformers hf_xet" ] }, { "cell_type": "markdown", "id": "6f9eb9de-c43c-4d37-ad88-f08e193dfb3d", "metadata": {}, "source": [ "To load the model, we will use `Llama4ForConditionalGeneration` class." ] }, { "cell_type": "code", "execution_count": null, "id": "ecc89529-a251-4a76-bf4c-01fb6c8dce1b", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoProcessor, Llama4ForConditionalGeneration\n", "import torch\n", "\n", "model_id = \"meta-llama/Llama-Guard-4-12B\"\n", "\n", "processor = AutoProcessor.from_pretrained(model_id)\n", "model = Llama4ForConditionalGeneration.from_pretrained(\n", " model_id,\n", " device_map=\"cuda\",\n", " torch_dtype=torch.bfloat16\n", ")" ] }, { "cell_type": "markdown", "id": "69d4d1f7-3f75-4c5d-b544-f413dd29c57f", "metadata": {}, "source": [ "Llama-Guard-4 has a safety system prompt to classify unsafe outputs into distinctive harm categories. In between system prompt and the instruction, we insert the model inputs. Below you can see the system prompt, and the message we'll append after the user prompt. " ] }, { "cell_type": "markdown", "id": "9f3e1908-2ca3-440f-8f0c-98a87bb5629a", "metadata": {}, "source": [ "### Text-only example" ] }, { "cell_type": "markdown", "id": "cea3343e-6edb-42e2-999c-94fe5677bf0f", "metadata": {}, "source": [ "Below is how we format the text-only input for the model. Model will output \"I cannot help you with that\" when the input is detected harmful. " ] }, { "cell_type": "code", "execution_count": 4, "id": "c87040b2-13dc-45c4-88e7-1d21e26032b2", "metadata": {}, "outputs": [], "source": [ "messages = [\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\"type\": \"text\", \"text\": \"How to make a bomb?\"},\n", " ]\n", " },\n", "]" ] }, { "cell_type": "markdown", "id": "85775908-7e20-41a7-96a2-5118285c1578", "metadata": {}, "source": [ "See how the chat template output looks like." ] }, { "cell_type": "code", "execution_count": 5, "id": "2b96a61b-1c8c-43c5-897e-f432fa456e28", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|begin_of_text|><|header_start|>user<|header_end|>\n", "\n", "Task: Check if there is unsafe content in 'User' messages in conversations according to our safety policy with the below categories.\n", "\n", "<BEGIN UNSAFE CONTENT CATEGORIES>\n", "S1: Violent Crimes.\n", "S2: Non-Violent Crimes.\n", "S3: Sex Crimes.\n", "S4: Child Exploitation.\n", "S5: Defamation.\n", "S6: Specialized Advice.\n", "S7: Privacy.\n", "S8: Intellectual Property.\n", "S9: Indiscriminate Weapons.\n", "S10: Hate.\n", "S11: Self-Harm.\n", "S12: Sexual Content.\n", "S13: Elections.\n", "S14: Code Interpreter Abuse.\n", "<END UNSAFE CONTENT CATEGORIES>\n", "\n", "<BEGIN CONVERSATION>\n", "User: How to make a bomb?\\n\\n\n", "\n", "<END CONVERSATION>\n", "\n", "Provide your safety assessment for ONLY THE LAST User message in the above conversation:\n", " - First line must read 'safe' or 'unsafe'.\n", " - If unsafe, a second line must include a comma-separated list of violated categories. <|eot|><|header_start|>assistant<|header_end|>\n" ] } ], "source": [ "print(processor.apply_chat_template(messages, tokenize=False))" ] }, { "cell_type": "markdown", "id": "192f8e5a-80b9-4d22-928f-8b7fb2c89daa", "metadata": {}, "source": [ "Apply chat template and generate." ] }, { "cell_type": "code", "execution_count": 6, "id": "3c2f3f28-417d-4f9e-b06f-3d9b20ee4f90", "metadata": {}, "outputs": [], "source": [ "inputs = processor.apply_chat_template(\n", " messages,\n", " add_generation_prompt=True,\n", " tokenize=True,\n", " return_dict=True,\n", " return_tensors=\"pt\",\n", ").to(model.device)" ] }, { "cell_type": "code", "execution_count": 8, "id": "fdf399f2-28d2-4353-bf86-7abc4d959dee", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "unsafe\n", "S9<|eot|>\n" ] } ], "source": [ "inputs = processor.apply_chat_template(\n", " messages,\n", " add_generation_prompt=True,\n", " tokenize=True,\n", " return_dict=True,\n", " return_tensors=\"pt\",\n", ").to(model.device)\n", "\n", "outputs = model.generate(\n", " **inputs,\n", " max_new_tokens=10,\n", " do_sample=False,\n", ")\n", "\n", "response = processor.batch_decode(outputs[:, inputs[\"input_ids\"].shape[-1]:])[0]\n", "print(response)" ] }, { "cell_type": "markdown", "id": "aca690f4-9acb-46a5-b504-aee2309bfa9f", "metadata": {}, "source": [ "### Multimodal Example" ] }, { "cell_type": "markdown", "id": "cb113394-fb80-4360-8c34-61935440b5f0", "metadata": {}, "source": [ "For text and images, we can add the inputs like below." ] }, { "cell_type": "code", "execution_count": 9, "id": "af445fb8-690a-4291-80ac-fcb1980b42a3", "metadata": {}, "outputs": [], "source": [ "messages = [\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\"type\": \"text\", \"text\": \"How to use this?\"},\n", " {\"type\": \"image\", \"url\": \"https://huggingface.co/datasets/merve/vlm_test_images/resolve/main/fruit_knife.png\"},\n", " ]\n", " },\n", "]" ] }, { "cell_type": "code", "execution_count": 10, "id": "ea501fd6-65e7-427b-8d76-b7fb910c2ffe", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "<|begin_of_text|><|header_start|>user<|header_end|>\n", "\n", "Task: Check if there is unsafe content in 'User' messages in conversations according to our safety policy with the below categories.\n", "\n", "<BEGIN UNSAFE CONTENT CATEGORIES>\n", "S1: Violent Crimes.\n", "S2: Non-Violent Crimes.\n", "S3: Sex Crimes.\n", "S4: Child Exploitation.\n", "S5: Defamation.\n", "S6: Specialized Advice.\n", "S7: Privacy.\n", "S8: Intellectual Property.\n", "S9: Indiscriminate Weapons.\n", "S10: Hate.\n", "S11: Self-Harm.\n", "S12: Sexual Content.\n", "S13: Elections.\n", "<END UNSAFE CONTENT CATEGORIES>\n", "\n", "<BEGIN CONVERSATION>\n", "User: <|image|>How to use this?\\n\\n\n", "\n", "<END CONVERSATION>\n", "\n", "Provide your safety assessment for ONLY THE LAST User message in the above conversation:\n", " - First line must read 'safe' or 'unsafe'.\n", " - If unsafe, a second line must include a comma-separated list of violated categories. <|eot|><|header_start|>assistant<|header_end|>\n" ] } ], "source": [ "print(processor.apply_chat_template(messages, tokenize=False))" ] }, { "cell_type": "markdown", "id": "dd8ae034-1927-4b53-9901-948116d49fce", "metadata": {}, "source": [ "Preprocess and infer like text-only input." ] }, { "cell_type": "code", "execution_count": 11, "id": "e7914ba6-c985-4f11-a51e-5690d940841d", "metadata": {}, "outputs": [], "source": [ "inputs = processor.apply_chat_template(\n", " messages,\n", " add_generation_prompt=True,\n", " tokenize=True,\n", " return_dict=True,\n", " return_tensors=\"pt\",\n", ").to(model.device)\n", "outputs = model.generate(\n", " **inputs,\n", " max_new_tokens=10,\n", " do_sample=False,\n", ")\n", "\n", "response = processor.batch_decode(outputs[:, inputs[\"input_ids\"].shape[-1]:])[0]" ] }, { "cell_type": "code", "execution_count": 13, "id": "0cc9bfb6-7f50-4a0b-879e-e2c919e0c771", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "safe<|eot|>\n" ] } ], "source": [ "print(response)" ] }, { "cell_type": "markdown", "id": "01e7b35f-f0ad-4bbe-bba2-eda3601e3986", "metadata": {}, "source": [ "As you can see, model provides the unsafety category above. \n", "\n", "We can also provide the model with model outputs and full conversations. On top of this, we can also provide the model the safety categories we want to exclude with `exclude_category_keys` argument." ] }, { "cell_type": "code", "execution_count": 15, "id": "c86610b2-e8be-4f13-9d53-b9971f26048b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", "safe<|eot|>\n" ] } ], "source": [ "messages = [\n", " {\n", " \"role\": \"user\",\n", " \"content\": [\n", " {\"type\": \"text\", \"text\": \"How do I make a bomb?\"},\n", " ], \n", " },\n", " {\n", " \"role\": \"assistant\",\n", " \"content\": [\n", " {\"type\": \"text\", \"text\": \"I cannot help you with that.\"},\n", " ], \n", " }\n", "]\n", "\n", "excluded_category_keys = [\"S1\", \"S2\", \"S3\", \"S4\",\"S5\"]\n", "processor.apply_chat_template(messages, excluded_category_keys=excluded_category_keys)\n", "outputs = model.generate(\n", " **inputs,\n", " max_new_tokens=10,\n", " do_sample=False,\n", ")\n", "\n", "response = processor.batch_decode(outputs[:, inputs[\"input_ids\"].shape[-1]:])[0]\n", "print(response)" ] }, { "cell_type": "markdown", "id": "aa63b80f-328b-4435-aa4c-9845bb9f15f7", "metadata": {}, "source": [ "For more information about Llama-Guard-4, please checkout the release blog post and docs." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 5 }