in sagemaker/source/dl_utils/inference.py [0:0]
def model_fn(model_dir):
    """Rebuild the trained network from its checkpoint for SageMaker inference.

    Reads ``<model_dir>/output/net.pth``. Builds a ``Network`` from the
    architecture hyper-parameters stored there, loads the saved weights into
    it, and moves it to ``device``.

    The checkpoint must be a dict with the keys ``"net"`` (a state dict),
    ``"sensor_headers"``, ``"fc_hidden_units"`` and ``"conv_channels"``.
    Weight names in ``"net"`` may carry a ``"module."`` prefix. This presumably
    comes from the model being wrapped (e.g. ``torch.nn.DataParallel``) during
    training; both plain and prefixed names are accepted.

    Args:
        model_dir: Directory containing the ``output/net.pth`` checkpoint.

    Returns:
        The ``Network`` instance holding the trained weights, placed on
        ``device``.

    Raises:
        KeyError: If the checkpoint lacks one of the keys above, or has no
            weight (plain or ``"module."``-prefixed) for one of the network's
            parameters/buffers.
    """
    checkpoint_path = os.path.join(model_dir, "output", "net.pth")
    with open(checkpoint_path, 'rb') as f:
        # Load onto the CPU first so a checkpoint saved from a GPU can still
        # be restored on a CPU-only host; the whole network is moved to
        # `device` below.
        model_info = torch.load(f, map_location="cpu")
    pre_trained_model = model_info["net"]
    sensor_headers = model_info["sensor_headers"]
    fc_hidden_units = model_info["fc_hidden_units"]
    conv_channels = model_info["conv_channels"]
    net = Network(num_features=len(sensor_headers),
                  fc_hidden_units=fc_hidden_units,
                  conv_channels=conv_channels,
                  dropout_strength=0)

    # Map each of the fresh network's parameter/buffer names to its saved
    # tensor, falling back to the "module."-prefixed name. The keys must stay
    # un-prefixed so that load_state_dict matches them.
    weight_dict = {}
    for key in net.state_dict():
        source_key = key if key in pre_trained_model else "module." + key
        weight_dict[key] = pre_trained_model[source_key]
    # Assigning into the dict returned by state_dict() does not modify the
    # model; load_state_dict copies the values into its tensors.
    net.load_state_dict(weight_dict)
    print("Net loaded")
    # NOTE(review): nn.Module starts in training mode and this function does
    # not call net.eval(). Confirm the caller does, if Network contains layers
    # (e.g. BatchNorm) whose behaviour differs between training and eval.
    net = net.to(device)
    return net