in torchrecipes/text/doc_classification/train_app.py [0:0]
def get_lightning_module(self) -> LightningModule:
    """Instantiate the lightning module described by ``self.module_conf``.

    The module is built through ``hydra.utils.instantiate`` and receives two
    keyword arguments derived from ``self.transform_conf``:

    * ``transform``: the transform config to hand to the module.
    * ``num_classes``: the length of the configured ``label_names``.

    Returns:
        Whatever ``hydra.utils.instantiate`` produces for
        ``self.module_conf`` (annotated as a ``LightningModule``).
    """
    conf = self.transform_conf

    # Two config layouts are supported. The OSS TransformConf carries a
    # `label_transform` field, while the internal transforms do not, so the
    # presence of that field tells us which layout we were given.
    if not hasattr(conf, "label_transform"):
        # Internal layout: label names live directly on the transform config,
        # and the config itself is what gets passed to the module.
        # pyre-ignore[16]: Subclass of `TransformConf` has relevant attribute
        class_count = len(conf.label_names)
        module_transform = conf
    else:
        # OSS layout: label names sit under `label_transform`, and the
        # transform handed to the module is the nested `transform` field.
        # pyre-ignore[16]: Subclass of `TransformConf` has relevant attribute
        class_count = len(conf.label_transform.label_names)
        # pyre-ignore[16]: Subclass of `TransformConf` has relevant attribute
        module_transform = conf.transform

    # `_recursive_=False` is forwarded to Hydra unchanged; per Hydra's
    # documentation this disables recursive instantiation of nested configs.
    return hydra.utils.instantiate(
        self.module_conf,
        transform=module_transform,
        num_classes=class_count,
        _recursive_=False,
    )