in lab/02-Training/wind_turbine.py [0:0]
def train(args):
    """Train with K-fold cross-validation and export the best model.

    For every fold yielded by ``KFold`` a fresh model and Adam optimizer are
    created and trained for ``args.num_epochs`` epochs. Each sample is used as
    both input and target (``TensorDataset(X, X)``) and scored with MSE.
    Whenever an epoch's test loss beats the best loss seen so far *across all
    folds*, the model's ``state_dict`` is written to
    ``<args.output_data_dir>/model_state.pth``.

    After the last fold that checkpoint is loaded into a new model, this
    script is copied to ``<args.model_dir>/code/inference.py`` and the whole
    model object is saved to ``<args.model_dir>/model.pth``.

    Args:
        args: Object exposing ``num_epochs``, ``batch_size``, ``train``,
            ``k_fold_splits``, ``k_index_only`` (a value ``>= 0`` restricts
            training to that single fold index), ``dropout_rate``,
            ``learning_rate``, ``output_data_dir`` and ``model_dir``.

    Raises:
        RuntimeError: If no epoch produced a checkpoint (for example
            ``k_index_only`` matches no fold, or every test loss was NaN).
    """
    num_epochs = args.num_epochs
    batch_size = args.batch_size
    checkpoint_path = os.path.join(args.output_data_dir, 'model_state.pth')

    X = load_data(args.train)
    num_features = X.shape[1]
    criterion = nn.MSELoss()
    kf = KFold(n_splits=args.k_fold_splits, shuffle=True)

    # Start from +inf rather than a large magic number so the first finite
    # test loss always yields a checkpoint, whatever the scale of the data.
    best_loss = float('inf')
    checkpoint_saved = False

    for i, (train_index, test_index) in enumerate(kf.split(X)):
        # Skip the other folds when a single fold index was requested.
        if args.k_index_only >= 0 and args.k_index_only != i:
            continue

        # Share of the whole dataset held out for testing in this fold
        # (test / (train + test), not test / train).
        test_fraction = len(test_index) / (len(train_index) + len(test_index))
        print("Test dataset proportion: %.02f%%" % (test_fraction * 100))

        X_train = torch.from_numpy(X[train_index]).float().to(device)
        X_test = torch.from_numpy(X[test_index]).float().to(device)

        # The inputs double as the targets.
        train_dataset = torch.utils.data.TensorDataset(X_train, X_train)
        train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
        test_dataset = torch.utils.data.TensorDataset(X_test, X_test)
        test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)

        # Instantiate a fresh model and optimizer for this fold.
        model = create_model(num_features, args.dropout_rate)
        model = model.to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

        # Training loop
        for epoch in range(num_epochs):
            start_time = time.time()
            train_loss, test_loss = train_epoch(
                optimizer, criterion, epoch, model, train_dataloader, test_dataloader)
            elapsed_time = (time.time() - start_time)
            print("k=%d; epoch=%d; train_loss=%.3f; test_loss=%.3f; elapsed_time=%.3fs" % (i, epoch, train_loss, test_loss, elapsed_time))

            # Keep only the best weights seen so far across all folds.
            if test_loss < best_loss:
                torch.save(model.state_dict(), checkpoint_path)
                best_loss = test_loss
                checkpoint_saved = True

    if not checkpoint_saved:
        # Without this guard the load below would fail with an unrelated
        # "file not found" error, or silently pick up a stale checkpoint
        # left in output_data_dir by a previous run.
        raise RuntimeError(
            "Training produced no checkpoint: check that k_index_only (%r) is "
            "a valid fold index for k_fold_splits=%r, that num_epochs (%r) is "
            "positive and that the test loss is not NaN."
            % (args.k_index_only, args.k_fold_splits, num_epochs))

    print("\nBest model: best_mse=%f;" % best_loss)

    # Rebuild the architecture and restore the best weights into it.
    model = create_model(num_features, args.dropout_rate)
    model.load_state_dict(torch.load(checkpoint_path))

    # Ship this script next to the model as code/inference.py. makedirs with
    # exist_ok avoids FileExistsError when the directory is already there.
    code_dir = os.path.join(args.model_dir, 'code')
    os.makedirs(code_dir, exist_ok=True)
    shutil.copyfile(__file__, os.path.join(args.model_dir, 'code/inference.py'))
    torch.save(model, os.path.join(args.model_dir, "model.pth"))