def load_watermark_laion()

in video_processing/modules/watermark_laion.py [0:0]


def load_watermark_laion(device, model_path=None):
    """Build the LAION watermark classifier and publish it via module globals.

    Sets the module-level ``MODEL`` and ``TRANSFORMS``:

    * ``TRANSFORMS`` resizes to 256x256, converts to a tensor and normalizes
      with the standard ImageNet mean/std values.
    * ``MODEL`` is a timm ``efficientnet_b3`` (no pretrained weights) whose
      classifier is replaced by a 1536 -> 625 -> 256 -> 2 MLP. Its weights
      come from ``model_path``. It is put in eval mode and moved to
      ``device``.

    Args:
        device: Target passed to ``nn.Module.to`` (e.g. ``"cuda"``, ``"cpu"``
            or a ``torch.device``).
        model_path: Path to a saved ``state_dict``. When ``None``, the file
            ``watermark_model_v1.pt`` is downloaded from the Hugging Face repo
            ``finetrainers/laion-watermark-detection``.

    Raises:
        Any exception from ``hf_hub_download``, ``torch.load`` or
        ``load_state_dict``. In that case ``MODEL`` and ``TRANSFORMS`` are
        left unchanged rather than half-initialized.
    """
    global MODEL, TRANSFORMS

    transforms = T.Compose(
        [
            T.Resize((256, 256)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=2)
    # Custom head matching the layout of the saved checkpoint. 1536 is the
    # feature width the efficientnet_b3 classifier receives.
    model.classifier = nn.Sequential(
        nn.Linear(in_features=1536, out_features=625),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(in_features=625, out_features=256),
        nn.ReLU(),
        nn.Linear(in_features=256, out_features=2),
    )

    if model_path is None:
        model_path = hf_hub_download("finetrainers/laion-watermark-detection", "watermark_model_v1.pt")

    # Deserialize onto CPU regardless of the device the checkpoint was saved
    # from. This lets a CUDA-saved checkpoint load on CPU-only hosts. The model
    # is then moved to `device` in a single step below.
    # weights_only=True restricts unpickling to tensors/primitive containers.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval().to(device)

    # Publish only after everything succeeded, so a failed load never leaves a
    # randomly initialized MODEL visible to callers.
    MODEL, TRANSFORMS = model, transforms