video_processing/modules/nsfw.py (16 lines of code) (raw):
from transformers import AutoModelForImageClassification, AutoImageProcessor
import torch
# Checkpoint identifier handed to `from_pretrained` for both the classifier
# and its image processor.
MODEL_ID = "Falconsai/nsfw_image_detection"

# Module-level singletons; both remain None until load_nsfw() fills them in.
MODEL = None
PROCESSOR = None
def load_nsfw(device):
    """Load the NSFW classifier and its image processor into module globals.

    The classifier is switched to eval mode and moved to ``device`` before
    being stored in ``MODEL``; the matching image processor is stored in
    ``PROCESSOR``. Each call reloads both from ``MODEL_ID``.

    Args:
        device: Target device, passed as-is to the model's ``.to()``.
    """
    global MODEL, PROCESSOR

    classifier = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    classifier = classifier.eval()
    # Only publish the model once it is in eval mode and on the target device.
    MODEL = classifier.to(device)

    PROCESSOR = AutoImageProcessor.from_pretrained(MODEL_ID)
@torch.no_grad()
def run_nsfw(image):
    """Classify one image or a list of images with the loaded NSFW model.

    ``load_nsfw`` must have been called beforehand so that the module-level
    ``MODEL`` and ``PROCESSOR`` are populated.

    Args:
        image: A single image, or a list of images, in whatever form the
            image processor accepts. A non-list value is treated as a
            one-image batch.

    Returns:
        A list holding one label per input image, each looked up in
        ``MODEL.config.id2label`` by the arg-max class index. An empty
        input list yields ``[]`` without invoking the model.
    """
    images = image if isinstance(image, list) else [image]
    if not images:
        # Nothing to classify; avoid handing an empty batch to the
        # processor/model.
        return []

    inputs = PROCESSOR(images=images, return_tensors="pt").to(MODEL.device)
    logits = MODEL(**inputs).logits

    # Bring every predicted class index to the host in a single transfer
    # rather than one `.cpu().item()` round-trip per image.
    class_ids = logits.argmax(-1).tolist()

    id2label = MODEL.config.id2label
    return [id2label[class_id] for class_id in class_ids]