def get_img_encoder()

in dib/transformers/img.py [0:0]


def get_img_encoder(name):
    """Return the image-encoder class (or partially-applied factory) for `name`.

    Matching is case-insensitive and based on substring tests evaluated in
    order; the first matching rule wins.

    Parameters
    ----------
    name : str
        Encoder identifier, e.g. "mlp", "cnn", "resnet18", "wideresnet101".

    Returns
    -------
    MLPEncoder or CNNEncoder (the class itself), or a `functools.partial` of
    `TorchvisionEncoder` with `TVM` bound to the selected backbone.

    Raises
    ------
    ValueError
        If `name` matches none of the known encoders.
    """
    name = name.lower()

    def _resnet_encoder(tvm):
        # All torchvision ResNet-style backbones share the same wrapper config.
        return partial(TorchvisionEncoder, TVM=tvm, is_resnet_converter=True)

    if "mlp" in name or "width" in name:
        return MLPEncoder
    elif "cnn" in name:
        return CNNEncoder
    elif "resnet18" in name:
        return _resnet_encoder(torchvision.models.resnet18)
    elif "resnet34" in name:
        return _resnet_encoder(torchvision.models.resnet34)
    elif "resnet50" in name:
        return _resnet_encoder(torchvision.models.resnet50)
    # "wideresnet101" contains "resnet101" as a substring, so it must be tested
    # first; otherwise this branch is unreachable and a plain ResNet-101 is
    # silently returned instead.
    elif "wideresnet101" in name:
        return _resnet_encoder(torchvision.models.wide_resnet101_2)
    elif "resnet101" in name:
        return _resnet_encoder(torchvision.models.resnet101)
    # NOTE(review): other "wideresnet*" names that also contain "resnet18",
    # "resnet34" or "resnet50" are matched by the plain-ResNet branches above
    # and never reach this one — confirm that is intended.
    elif "wideresnet" in name:
        return partial(TorchvisionEncoder, TVM=WideResNet)
    else:
        raise ValueError(f"Unknown name={name}")