in video_processing/modules/watermark_laion.py [0:0]
def load_watermark_laion(device, model_path=None):
    """Build the LAION watermark classifier and its preprocessing pipeline.

    Populates the module-level globals ``MODEL`` and ``TRANSFORMS``. They are
    only assigned after the checkpoint has loaded successfully.

    Args:
        device: Target device passed to ``Module.to`` (e.g. ``"cuda"``,
            ``"cpu"`` or a ``torch.device``).
        model_path: Local path to the checkpoint's state dict. When ``None``,
            ``watermark_model_v1.pt`` is downloaded from the
            ``finetrainers/laion-watermark-detection`` Hugging Face repo.
    """
    global MODEL, TRANSFORMS

    # Fixed 256x256 resize followed by normalization with the ImageNet
    # channel mean/std constants.
    transforms = T.Compose(
        [
            T.Resize((256, 256)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    # EfficientNet-B3 backbone with a custom MLP head. This architecture must
    # match the checkpoint exactly, otherwise load_state_dict raises.
    model = timm.create_model("efficientnet_b3", pretrained=False, num_classes=2)
    model.classifier = nn.Sequential(
        # 1536 must equal the pooled feature width the backbone feeds into
        # `classifier` (assumed from the checkpoint layout; confirm if the
        # backbone changes).
        nn.Linear(in_features=1536, out_features=625),
        nn.ReLU(),
        nn.Dropout(p=0.3),
        nn.Linear(in_features=625, out_features=256),
        nn.ReLU(),
        # Two output logits; presumably watermark vs. no-watermark (TODO confirm
        # class order against the training code).
        nn.Linear(in_features=256, out_features=2),
    )

    if model_path is None:
        model_path = hf_hub_download("finetrainers/laion-watermark-detection", "watermark_model_v1.pt")

    # Deserialize onto CPU first so a checkpoint saved from a GPU still loads on
    # CPU-only hosts. The whole model is moved to `device` below.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval().to(device)

    # Publish only once everything succeeded, so a failed download or load
    # never leaves a randomly initialized model in the module globals.
    TRANSFORMS = transforms
    MODEL = model