in docker_images/timm/app/pipelines/image_classification.py [0:0]
def __init__(self, model_id: str):
    """Load a timm classifier from the Hugging Face Hub and prepare labels.

    Args:
        model_id: Hub repository id; it is prefixed with ``hf_hub:`` and
            passed to ``timm.create_model`` with ``pretrained=True``.

    Sets:
        self.model: the created timm model, switched to eval mode.
        self.transform: preprocessing built by ``create_transform`` from the
            model's resolved data config (resolved with
            ``use_test_size=True``).
        self.top_k: number of predictions to report, at most 5 and never
            more than the model's ``num_classes``.
        self.dataset_info: label lookup object. It is an ``ImageNetInfo``
            when the pretrained config has no ``label_names`` and an ImageNet
            subset can be inferred. Otherwise it is a ``CustomDatasetInfo``
            built from the config's ``label_names`` (or generated
            ``LABEL_{i}`` names) and ``label_descriptions``.
    """
    hub_model = timm.create_model(f"hf_hub:{model_id}", pretrained=True)
    self.model = hub_model

    data_config = resolve_model_data_config(hub_model, use_test_size=True)
    self.transform = create_transform(**data_config)

    self.top_k = min(hub_model.num_classes, 5)
    hub_model.eval()

    cfg = hub_model.pretrained_cfg
    names = cfg.get("label_names", None)
    descriptions = cfg.get("label_descriptions", None)

    info = None
    if names is None:
        # The config carries no explicit labels, so try timm's built-in
        # ImageNet labeller first.
        subset = infer_imagenet_subset(hub_model)
        if subset:
            info = ImageNetInfo(subset)
        else:
            # Last resort: generic placeholder names, one per output class.
            names = [f"LABEL_{idx}" for idx in range(hub_model.num_classes)]

    self.dataset_info = (
        info
        if info is not None
        else CustomDatasetInfo(
            label_names=names,
            label_descriptions=descriptions,
        )
    )