in sagemaker-voice-classification/notebook/train.py [0:0]
def model_fn(model_dir):
    """Load the trained ``NetM3`` model from ``model_dir`` for inference.

    Reads the state dict stored in ``<model_dir>/model.pth`` and returns the
    model moved to CUDA when available, otherwise CPU. When more than one GPU
    is visible, the model is wrapped in ``nn.DataParallel``.

    The checkpoint loads regardless of whether it was saved from a
    ``DataParallel``-wrapped model (keys prefixed with ``"module."``) or from
    a bare model. The checkpoint also loads regardless of whether this host
    wraps the model.

    Args:
        model_dir: Directory containing the ``model.pth`` state-dict file.

    Returns:
        The loaded model (possibly wrapped in ``nn.DataParallel``) on the
        selected device.

    Raises:
        FileNotFoundError: If ``model.pth`` is missing from ``model_dir``.
        RuntimeError: If the checkpoint's parameters do not match ``NetM3``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = NetM3()
    if torch.cuda.device_count() > 1:
        print("Gpu count: {}".format(torch.cuda.device_count()))
        model = nn.DataParallel(model)

    with open(os.path.join(model_dir, "model.pth"), "rb") as f:
        # map_location lets a checkpoint whose tensors were saved on GPU be
        # deserialized on a CPU-only host (and vice versa).
        state_dict = torch.load(f, map_location=device)

    # The wrapping decision above depends on this host's GPU count, which may
    # differ from the training host's. Load into the underlying network and
    # normalise the checkpoint to un-prefixed keys so either layout works.
    target = model.module if isinstance(model, nn.DataParallel) else model
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        # Saved from a DataParallel-wrapped model. Strip the wrapper prefix in
        # place; popping and re-inserting in the original key order keeps the
        # ordering, and the same mapping object is passed to load_state_dict.
        for key in list(state_dict):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    target.load_state_dict(state_dict)
    return model.to(device)