def model_fn()

in code/pretrained_model_with_debugger_hook.py [0:0]


def model_fn(model_dir: str) -> ModelWithDebugHook:
    """Load the fine-tuned ResNet-18 and optionally attach a debugger hook.

    Presumably used as the model-loading entry point of a SageMaker PyTorch
    inference script (TODO confirm). Keep the single-argument signature:
    the serving toolkit may decide which arguments to pass based on the
    function's arity, so adding parameters (even with defaults) is unsafe.

    Args:
        model_dir: Directory containing ``model/model.pt``, a state dict for
            a ResNet-18 whose final linear layer has 43 outputs.

    Returns:
        A ``ModelWithDebugHook`` wrapping the model (in eval mode, moved to
        CPU) together with the configured hook. The hook is ``None`` when
        the ``tensors_output`` environment variable is not set.
    """
    # Build the architecture. The traffic-sign dataset has 43 classes, so the
    # default classification head is replaced to match the saved weights.
    model = models.resnet18()
    model.fc = nn.Linear(model.fc.in_features, 43)

    # Load the weights onto CPU storage regardless of the device they were
    # saved from, then switch to inference mode on CPU.
    weights_path = os.path.join(model_dir, 'model', 'model.pt')
    weights = torch.load(weights_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(weights)
    model.eval()
    model.cpu()

    # Tensor export is opt-in: without an output location there is nothing
    # to hook, so the model is returned without a hook.
    tensors_output_s3uri = os.environ.get('tensors_output')
    if tensors_output_s3uri is None:
        logger.warning(
            'WARN: Skipping hook configuration as no tensors_output env var provided. '
            'Tensors will not be exported'
        )
        return ModelWithDebugHook(model, None)

    # Only PREDICT mode is configured, with save_interval=1 (presumably: save
    # the selected tensors at every inference step — see smdebug SaveConfigMode).
    save_config = smd.SaveConfig(mode_save_configs={
        smd.modes.PREDICT: smd.SaveConfigMode(save_interval=1),
    })

    # include_regex selects which named tensors the hook records.
    hook = CustomHook(
        tensors_output_s3uri,
        save_config=save_config,
        include_regex='.*bn|.*bias|.*downsample|.*ResNet_input|.*image|.*fc_output',
    )
    hook.register_module(model)

    # Use the same ``smd.modes`` reference as the SaveConfig above, so the mode
    # set here is guaranteed to match the configured one and no separate
    # ``modes`` import is required.
    hook.set_mode(smd.modes.PREDICT)

    return ModelWithDebugHook(model, hook)