in dib/transformers/img.py [0:0]
def get_img_encoder(name):
    """Return the image-encoder constructor selected by ``name``.

    Matching is case-insensitive and done by substring; the first rule that
    matches wins. Because of that, a more specific key must be tested before
    any key it contains (e.g. ``"wideresnet101"`` contains ``"resnet101"``).

    Parameters
    ----------
    name : str
        Encoder identifier, e.g. ``"mlp"``, ``"width..."``, ``"cnn"``,
        ``"resnet18"``, ``"resnet34"``, ``"resnet50"``, ``"resnet101"``,
        ``"wideresnet101"`` or ``"wideresnet"``.

    Returns
    -------
    callable
        ``MLPEncoder``, ``CNNEncoder``, or a ``functools.partial`` of
        ``TorchvisionEncoder`` with its ``TVM`` backbone (and, for the
        torchvision ResNets, ``is_resnet_converter=True``) pre-bound.

    Raises
    ------
    ValueError
        If ``name`` matches none of the known encoders.
    """
    name = name.lower()

    if "mlp" in name or "width" in name:
        return MLPEncoder
    if "cnn" in name:
        return CNNEncoder

    # (substring key, attribute of torchvision.models). Order matters:
    # "wideresnet101" must be tested before "resnet101", which it contains;
    # otherwise the wide variant is unreachable and silently maps to ResNet-101.
    torchvision_resnets = (
        ("resnet18", "resnet18"),
        ("resnet34", "resnet34"),
        ("resnet50", "resnet50"),
        ("wideresnet101", "wide_resnet101_2"),
        ("resnet101", "resnet101"),
    )
    for key, model_attr in torchvision_resnets:
        if key in name:
            return partial(
                TorchvisionEncoder,
                TVM=getattr(torchvision.models, model_attr),
                is_resnet_converter=True,
            )

    # Generic wide ResNet (project-defined WideResNet, no resnet converter).
    if "wideresnet" in name:
        return partial(TorchvisionEncoder, TVM=WideResNet)

    raise ValueError(f"Unknown name={name}")