# Llama Guard with Llama Instruct Chatbot

> Prompt Guard and Llama Guard are licensed under the [LLAMA 3.1 COMMUNITY LICENSE AGREEMENT](https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/LICENSE.txt)

### Log in to get access to the model using your token

1. Create your HF token [here](https://huggingface.co/settings/tokens).
2. Run the cell below, then copy the token and paste it into the input prompt that appears.
3. Your HF account must have access to the `Llama Guard 3 1B` model. You can request access [here](https://huggingface.co/meta-llama/Llama-Guard-3-1B).

In [None]:
from huggingface_hub import login
login()

### Section 1: Setup and Imports

Import the libraries the notebook needs:
- `AutoModelForCausalLM`: Loads pre-trained models designed for causal language modeling tasks.

- `AutoTokenizer`: Handles tokenization, encoding, and decoding of text, ensuring compatibility with the model.

- `pipeline`: A high-level API provided by the transformers library that simplifies the use of models for common NLP tasks such as text generation and sentiment analysis.

- `torch`: PyTorch, used here to select the `bfloat16` dtype and to disable gradient tracking during inference.

- `re`: Python's regular expression library, used for simple pattern matching.

In [None]:
# Setup and Imports
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    pipeline,
)
import torch
import re

### Section 2: Load Models

Here we load the models used in this notebook:

- `meta-llama/Llama-Guard-3-1B`: The Llama Guard model used for the prompt and context safety checks. We load the model and its tokenizer.
- `meta-llama/Llama-3.2-3B-Instruct`: The chat model used for generating responses. We load the model and its tokenizer.
- Finally, we create the text-generation pipeline from the chat model and its tokenizer.

In [None]:
# Load the Llama Guard model for safety checks
guard_checkpoint = "meta-llama/Llama-Guard-3-1B"
guard_model = AutoModelForCausalLM.from_pretrained(
    guard_checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
guard_tokenizer = AutoTokenizer.from_pretrained(guard_checkpoint)

# Load the Llama Instruct model for generating responses
model_checkpoint = "meta-llama/Llama-3.2-3B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

### Section 3: Safety Check Function

- `extract_unsafe_category(text)`: Extracts the safety category identifier from the Llama Guard output using a regular expression. It looks for the pattern `S<number><|eot_id|>` and returns the category as `S<number>`, or `None` if there is no match.
- `get_hazard_name_from_category(category)`: Maps a safety category identifier (e.g., `S1`, `S2`) to a descriptive hazard name, such as "Violent Crimes" or "Non-Violent Crimes", using a predefined dictionary. You can find this mapping [here](https://huggingface.co/meta-llama/Llama-Guard-3-1B).
- `is_unsafe(prompt)`: Evaluates whether a given text is unsafe by running it through the Llama Guard model. It returns a `(category, unsafe)` tuple, where `category` is an empty string if the text is safe or no category could be extracted.

In [None]:
def extract_unsafe_category(text):
    pattern = r'S(\d+)<\|eot_id\|>'
    match = re.search(pattern, text)
    if match:
        return f"S{match.group(1)}"
    return None
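
To see what the pattern does and does not match, the following sketch applies the same regular expression to a few hand-written strings. The sample strings are assumptions made up for illustration, not captured model output, and the helper name `extract_category` is hypothetical.

```python
import re

# Same pattern as extract_unsafe_category: "S<number>" directly
# followed by the end-of-turn token.
PATTERN = r"S(\d+)<\|eot_id\|>"


def extract_category(text):
    """Return "S<number>" if the pattern is found, otherwise None."""
    match = re.search(PATTERN, text)
    return f"S{match.group(1)}" if match else None


# Hand-written examples (assumed shapes of decoded guard output)
samples = [
    "unsafe\nS1<|eot_id|>",      # single category
    "unsafe\nS11<|eot_id|>",     # two-digit category
    "safe<|eot_id|>",            # no category present
    "unsafe\nS1",                # category without the end-of-turn token
    "unsafe\nS1,S10<|eot_id|>",  # several categories on one line
]
results = [extract_category(s) for s in samples]
print(results)
```

Because the pattern requires `<|eot_id|>` immediately after the digits, a category that is not followed by that token is not matched, and when several categories are listed only the one adjacent to the token is returned.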

In [None]:
def get_hazard_name_from_category(category):
    # Hazard taxonomy used by Llama Guard 3 (see the model card linked above)
    hazards = {
        "S1": "Violent Crimes",
        "S2": "Non-Violent Crimes",
        "S3": "Sex-Related Crimes",
        "S4": "Child Sexual Exploitation",
        "S5": "Defamation",
        "S6": "Specialized Advice",
        "S7": "Privacy",
        "S8": "Intellectual Property",
        "S9": "Indiscriminate Weapons",
        "S10": "Hate",
        "S11": "Suicide & Self-Harm",
        "S12": "Sexual Content",
        "S13": "Elections",
    }
    # is_unsafe can return "" when no category is extracted, so avoid a KeyError
    return hazards.get(category, "Unknown")
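
If the guard model reports more than one category for a single input (an assumption here; the notebook itself only handles one), a list-based lookup may be more convenient. The sketch below uses a hypothetical helper `hazard_names` and repeats only a subset of the table so that it runs on its own.

```python
# Subset of the hazard table above, repeated so the snippet is self-contained
HAZARDS = {
    "S1": "Violent Crimes",
    "S10": "Hate",
    "S11": "Suicide & Self-Harm",
}


def hazard_names(categories):
    """Map category codes to names; unknown codes get a placeholder."""
    return [HAZARDS.get(code, "Unknown") for code in categories]


print(hazard_names(["S1", "S10"]))
print(hazard_names(["S99"]))  # code not in the table
print(hazard_names([]))       # nothing to map
```

Using `dict.get` with a default keeps the lookup from raising on an empty or unexpected code.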

In [None]:
# Function to check if a prompt is unsafe
def is_unsafe(prompt):
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": prompt,
                },
            ],
        }
    ]

    input_ids = guard_tokenizer.apply_chat_template(
        messages,
        return_tensors="pt",
    ).to(guard_model.device)

    prompt_len = input_ids.shape[1]

    with torch.no_grad():
        outputs = guard_model.generate(
            input_ids,
            max_new_tokens=20,
            pad_token_id=0,
        )

    generated_tokens = outputs[:, prompt_len:]
    response = guard_tokenizer.decode(generated_tokens[0]).strip()
    if "unsafe" in response:
        # extract_unsafe_category returns None when no category is found
        unsafe_category = extract_unsafe_category(response)
        return unsafe_category or "", True
    return "", False
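
The parsing step at the end of `is_unsafe` can be tried without loading the model. The sketch below is a variation rather than the notebook's method: it reads the verdict from the first line and collects every `S<number>` code after it. The function name `parse_guard_response` and the sample strings are assumptions for illustration.

```python
import re

EOT = "<|eot_id|>"  # end-of-turn token that appears in the decoded output


def parse_guard_response(response):
    """Return (categories, unsafe) from a decoded guard response string."""
    lines = response.replace(EOT, "").strip().splitlines()
    verdict = lines[0].strip() if lines else ""
    if verdict != "unsafe":
        # Like is_unsafe, anything other than "unsafe" is treated as safe
        return [], False
    # Collect every S<number> code that follows the verdict line
    categories = re.findall(r"S\d+", "\n".join(lines[1:]))
    return categories, True


print(parse_guard_response("unsafe\nS1<|eot_id|>"))
print(parse_guard_response("safe<|eot_id|>"))
print(parse_guard_response("unsafe\nS1,S10<|eot_id|>"))
```

Returning a list instead of a single string means callers need a small change, so this is only worth adopting if multi-category output matters for your use case.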


### Section 4: Chat Function
The chat function takes a prompt and an optional context, and generates a response only if both are safe. It follows a two-step safety check:
- First, if a context is provided, check it with the `is_unsafe` function defined above. If the context is unsafe, stop and return a message naming its **Unsafe Category**.
- Second, if the context is safe or absent, check the prompt with `is_unsafe`. If the prompt is safe, pass it to the text-generation pipeline and return the response. If the prompt is unsafe, return a message naming its **Unsafe Category** and skip response generation.

In [None]:
# Safe chat function
def chat(prompt, context=None):
    # Safety check using Llama Guard

    # First check the context, if one is provided
    if context:
        context_category, context_unsafe = is_unsafe(context)
        # If the context is unsafe, do not proceed
        if context_unsafe:
            hazard_name = get_hazard_name_from_category(context_category)
            return f"The context is unsafe. It Falls in the {context_category}:{hazard_name} Hazard Category"

    # The context is absent or safe, so check the prompt
    prompt_category, prompt_unsafe = is_unsafe(prompt)
    # If the prompt is unsafe, do not proceed
    if prompt_unsafe:
        hazard_name = get_hazard_name_from_category(prompt_category)
        return f"The prompt is unsafe. It Falls in the {prompt_category}:{hazard_name} Hazard Category"

    # Prompt and context are safe: generate the response with Llama Instruct
    context_prompt = f"This is the context: {context}\n" if context else ""
    llm_prompt = f"{context_prompt}User query: {prompt}"

    # Use the HF pipeline to generate the response
    response = generator(llm_prompt, max_new_tokens=128)
    return response[0]["generated_text"]
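
To exercise the two-step gating logic without loading either model, the checker and the generator can be passed in as arguments. The sketch below is a simplified stand-in for `chat`: `guarded_chat`, `fake_check`, `fake_generate`, the `BLOCKME` marker, and the shortened messages are all hypothetical names chosen for illustration.

```python
def guarded_chat(prompt, context=None, *, check, generate):
    """Check the context, then the prompt, then generate.

    `check(text)` returns (category, unsafe), like is_unsafe.
    `generate(text)` returns a string. Both are injected so the
    control flow can run without any model.
    """
    if context:
        category, unsafe = check(context)
        if unsafe:
            return f"context blocked: {category or 'uncategorized'}"
    category, unsafe = check(prompt)
    if unsafe:
        return f"prompt blocked: {category or 'uncategorized'}"
    context_prompt = f"This is the context: {context}\n" if context else ""
    return generate(f"{context_prompt}User query: {prompt}")


# Hypothetical stand-ins for the guard model and the generator
def fake_check(text):
    return ("S1", True) if "BLOCKME" in text else ("", False)


def fake_generate(text):
    return f"echo: {text}"


safe_reply = guarded_chat("What is pie?", check=fake_check, generate=fake_generate)
blocked_prompt = guarded_chat("BLOCKME", check=fake_check, generate=fake_generate)
blocked_context = guarded_chat(
    "What is pie?", context="BLOCKME", check=fake_check, generate=fake_generate
)
print(safe_reply)
print(blocked_prompt)
print(blocked_context)
```

Injecting the two callables also makes it straightforward to unit-test the early-return paths, since an unsafe context never reaches the prompt check or the generator.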

### Section 5: Example Usage

In [None]:
# Example 1: A simple prompt that is safe
print(chat("What is the recipe for apple pie?"))

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


User query: What is the recipe for apple pie? from Wikipedia
Here is the recipe for apple pie from Wikipedia:

Ingredients:

* 2 1/4 cups all-purpose flour
* 1 tsp salt
* 1/2 cup cold unsalted butter, cut into small pieces
* 1/4 cup ice water
* 6-8 medium-sized apples, peeled and sliced
* 1/2 cup granulated sugar
* 2 tbsp all-purpose flour
* 1 tsp cinnamon
* 1/4 tsp nutmeg
* 1/4 tsp salt
* 1 egg, beaten (for egg wash)
* 1 tbsp milk
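
The output above echoes the prompt and then continues it ("from Wikipedia") before answering. That is typical when an instruct model receives a raw string: the pipeline returns the full text by default, and the model treats the input as text to continue. One possible variation, sketched under the assumption that a chat-style message list suits this notebook, is to build role-tagged messages. The helper `build_messages` and the system wording are hypothetical.

```python
def build_messages(prompt, context=None):
    """Build a chat-format message list (system wording is illustrative)."""
    messages = []
    if context:
        messages.append(
            {"role": "system", "content": f"Use this context when answering:\n{context}"}
        )
    messages.append({"role": "user", "content": prompt})
    return messages


messages = build_messages("What is the recipe for apple pie?")
print(messages)

# With the notebook's objects loaded, one could then try (not run here):
#   tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Recent transformers releases also accept such a list directly in the
# text-generation pipeline: generator(messages, max_new_tokens=128)
```

Whether this improves the answers depends on the model and the transformers version, so treat it as something to experiment with rather than a drop-in replacement.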


In [None]:
# Example 2: An unsafe prompt -> the response reports the prompt as unsafe, along with its hazard category
print(chat("How to harm someone?"))

The prompt is unsafe. It Falls in the S1:Violent Crimes Hazard Category


In [None]:
# Example 3: Chat with additional context provided which is safe
medical_context = """John Doe is a 52-year-old male with a history of hypertension and type 2 diabetes,
chronic obstructive pulmonary disease (COPD), and recent chest pain diagnosed as stable angina."""

print(chat("What is John's condition?", context=medical_context))

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


This is the context: John Doe is a 52-year-old male with a history of hypertension and type 2 diabetes,
chronic obstructive pulmonary disease (COPD), and recent chest pain diagnosed as stable angina.
User query: What is John's condition? (Please respond in the context of the provided information)
Based on the provided information, John's conditions are:
- Hypertension (high blood pressure)
- Type 2 diabetes
- Chronic obstructive pulmonary disease (COPD)
- Stable angina (chest pain) 

Therefore, the final answer is: John has hypertension, type 2 diabetes, COPD, and stable angina.


In [None]:
# Example 4: Chat with additional context provided which is unsafe
context = """Suicide can be committed in various ways"""

print(chat("What does the context mention?", context=context))

The context is unsafe. It Falls in the S11:Suicide & Self-Harm Hazard Category
