in code/pretrained_model_with_debugger_hook.py [0:0]
def model_fn(model_dir: str) -> ModelWithDebugHook:
    """Build the traffic-sign classifier, load its weights, and optionally attach a debugger hook.

    Presumably the SageMaker inference ``model_fn`` entry point (confirm); the
    single-parameter signature is kept on purpose so the hosting code can keep
    calling it as before.

    Args:
        model_dir: Directory holding the trained weights at
            ``<model_dir>/model/model.pt``.

    Returns:
        A ``ModelWithDebugHook`` wrapping the model (in eval mode, on CPU) and
        the hook. The hook is ``None`` when the ``tensors_output`` environment
        variable is unset; in that case no tensors are exported.
    """
    # Create a ResNet-18 and replace its classifier head:
    # the traffic sign dataset has 43 classes.
    model = models.resnet18()
    nfeatures = model.fc.in_features
    model.fc = nn.Linear(nfeatures, 43)

    # Load the trained weights, mapping every tensor onto the CPU regardless of
    # the device it was saved from.
    # NOTE(review): torch.load unpickles the file, so only load trusted model
    # artifacts; on torch>=1.13 consider passing weights_only=True.
    weights_path = os.path.join(model_dir, "model", "model.pt")
    weights = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(weights)

    model.eval()
    model.cpu()

    # Hook configuration: tensor export is opt-in via the env var, whose value
    # is handed to CustomHook (presumably the S3 output location — confirm).
    tensors_output_s3uri = os.environ.get('tensors_output')
    if tensors_output_s3uri is None:
        logger.warning(
            'WARN: Skipping hook configuration as no tensors_output env var provided. '
            'Tensors will not be exported'
        )
        # Return without touching a hook: there is none to register, and
        # calling register_module/set_mode on None would raise AttributeError.
        return ModelWithDebugHook(model, None)

    # Save the selected tensors on every step (save_interval=1) in PREDICT mode.
    save_config = smd.SaveConfig(mode_save_configs={
        smd.modes.PREDICT: smd.SaveConfigMode(save_interval=1),
    })
    # Only tensors whose names match this regex are collected.
    hook = CustomHook(
        tensors_output_s3uri,
        save_config=save_config,
        include_regex='.*bn|.*bias|.*downsample|.*ResNet_input|.*image|.*fc_output',
    )

    # Attach the hook to the model and switch it to prediction mode so the
    # PREDICT save config above is the one in effect.
    hook.register_module(model)
    hook.set_mode(smd.modes.PREDICT)
    return ModelWithDebugHook(model, hook)