in sagemaker-python-sdk/dgl_gcmc/train.py [0:0]
def train(args):
    """Train the GCMC model on MovieLens and track train/valid/test RMSE.

    Expects ``args`` to carry the dataset/model/optimizer hyper-parameters
    (data_name, ctx, train_lr, train_max_iter, the *_interval and patience
    settings, save_dir/save_id, ...).  Side effects: writes per-split CSV
    metric logs, a net-info text file, and the best model parameters under
    ``args.save_dir``; also mutates ``args`` with the inferred input feature
    sizes and the possible rating values before building the net.
    """
    print(args)
    dataset = MovieLens(
        args.data_name,
        args.ctx,
        use_one_hot_fea=args.use_one_hot_fea,
        symm=args.gcn_agg_norm_symm,
        test_ratio=args.data_test_ratio,
        valid_ratio=args.data_valid_ratio,
    )
    print("Loading data finished ...\n")
    # Encoder/decoder input sizes depend on the loaded feature matrices, so
    # they are attached to args before the net is constructed.
    args.src_in_units = dataset.user_feature_shape[1]
    args.dst_in_units = dataset.movie_feature_shape[1]
    args.rating_vals = dataset.possible_rating_values
    ### build the net
    net = Net(args=args)
    net.initialize(init=mx.init.Xavier(factor_type="in"), ctx=args.ctx)
    net.hybridize()
    nd_possible_rating_values = mx.nd.array(
        dataset.possible_rating_values, ctx=args.ctx, dtype=np.float32
    )
    rating_loss_net = gluon.loss.SoftmaxCELoss()
    rating_loss_net.hybridize()
    trainer = gluon.Trainer(
        net.collect_params(), args.train_optimizer, {"learning_rate": args.train_lr}
    )
    print("Loading network finished ...\n")
    ### prepare training data
    train_gt_labels = dataset.train_labels
    train_gt_ratings = dataset.train_truths
    ### prepare the loggers (one CSV per split)
    train_loss_logger = MetricLogger(
        ["iter", "loss", "rmse"],
        ["%d", "%.4f", "%.4f"],
        os.path.join(args.save_dir, "train_loss%d.csv" % args.save_id),
    )
    valid_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "valid_loss%d.csv" % args.save_id),
    )
    test_loss_logger = MetricLogger(
        ["iter", "rmse"],
        ["%d", "%.4f"],
        os.path.join(args.save_dir, "test_loss%d.csv" % args.save_id),
    )
    ### declare the loss information
    best_valid_rmse = np.inf
    # BUGFIX: best_test_rmse used to be assigned only inside the
    # "validation improved" branch, so the final summary print raised
    # NameError whenever no validation iteration ran
    # (train_max_iter <= train_valid_interval).
    best_test_rmse = np.inf
    no_better_valid = 0
    best_iter = -1
    avg_gnorm = 0
    count_rmse = 0
    count_num = 0
    count_loss = 0
    print("Start training ...")
    # Per-iteration wall-clock times; the first 3 iterations are skipped as
    # warm-up (hybridization/compilation noise).
    dur = []
    for iter_idx in range(1, args.train_max_iter):
        if iter_idx > 3:
            t0 = time.time()
        with mx.autograd.record():
            pred_ratings = net(
                dataset.train_enc_graph,
                dataset.train_dec_graph,
                dataset.user_feature,
                dataset.movie_feature,
            )
            loss = rating_loss_net(pred_ratings, train_gt_labels).mean()
            loss.backward()
        count_loss += loss.asscalar()
        gnorm = params_clip_global_norm(net.collect_params(), args.train_grad_clip, args.ctx)
        avg_gnorm += gnorm
        trainer.step(1.0)
        if iter_idx > 3:
            dur.append(time.time() - t0)
        if iter_idx == 1:
            print("Total #Param of net: %d" % (gluon_total_param_num(net)))
            print(
                gluon_net_info(
                    net, save_path=os.path.join(args.save_dir, "net%d.txt" % args.save_id)
                )
            )
        # Expected rating = sum over rating classes of P(class) * class value.
        real_pred_ratings = (
            mx.nd.softmax(pred_ratings, axis=1) * nd_possible_rating_values.reshape((1, -1))
        ).sum(axis=1)
        rmse = mx.nd.square(real_pred_ratings - train_gt_ratings).sum()
        count_rmse += rmse.asscalar()
        count_num += pred_ratings.shape[0]
        # BUGFIX: logging_str used to be created only inside the log-interval
        # branch but appended to in the valid-interval branch below, raising
        # NameError (or silently reusing a stale, already-printed string)
        # whenever the two intervals did not coincide.
        logging_str = ""
        if iter_idx % args.train_log_interval == 0:
            # BUGFIX: the CSV logger divided by (iter_idx + 1) while the
            # console print divided by iter_idx; the loop starts at 1, so
            # count_loss holds exactly iter_idx loss samples here.
            train_loss_logger.log(
                iter=iter_idx, loss=count_loss / iter_idx, rmse=count_rmse / count_num
            )
            logging_str = "Iter={}, gnorm={:.3f}, loss={:.4f}, rmse={:.4f}, time={:.4f}".format(
                iter_idx,
                avg_gnorm / args.train_log_interval,
                count_loss / iter_idx,
                count_rmse / count_num,
                # dur is empty until iteration 4; np.average([]) would be NaN.
                np.average(dur) if dur else 0.0,
            )
            avg_gnorm = 0
            count_rmse = 0
            count_num = 0
        if iter_idx % args.train_valid_interval == 0:
            valid_rmse = evaluate(args=args, net=net, dataset=dataset, segment="valid")
            valid_loss_logger.log(iter=iter_idx, rmse=valid_rmse)
            logging_str += ",\tVal RMSE={:.4f}".format(valid_rmse)
            if valid_rmse < best_valid_rmse:
                best_valid_rmse = valid_rmse
                no_better_valid = 0
                best_iter = iter_idx
                net.save_parameters(
                    filename=os.path.join(
                        args.save_dir, "best_valid_net{}.params".format(args.save_id)
                    )
                )
                # Test RMSE is only (re-)evaluated when validation improves,
                # so best_test_rmse always tracks the best-valid checkpoint.
                test_rmse = evaluate(args=args, net=net, dataset=dataset, segment="test")
                best_test_rmse = test_rmse
                test_loss_logger.log(iter=iter_idx, rmse=test_rmse)
                logging_str += ", Test RMSE={:.4f}".format(test_rmse)
            else:
                no_better_valid += 1
                # Early stop only once the LR has already been decayed down
                # to the configured minimum.
                if (
                    no_better_valid > args.train_early_stopping_patience
                    and trainer.learning_rate <= args.train_min_lr
                ):
                    logging.info("Early stopping threshold reached. Stop training.")
                    break
                if no_better_valid > args.train_decay_patience:
                    new_lr = max(
                        trainer.learning_rate * args.train_lr_decay_factor, args.train_min_lr
                    )
                    if new_lr < trainer.learning_rate:
                        logging.info("\tChange the LR to %g" % new_lr)
                        trainer.set_learning_rate(new_lr)
                        no_better_valid = 0
        if iter_idx % args.train_log_interval == 0:
            print(logging_str)
    print(
        "Best Iter Idx={}, Best Valid RMSE={:.4f}, Best Test RMSE={:.4f}".format(
            best_iter, best_valid_rmse, best_test_rmse
        )
    )
    train_loss_logger.close()
    valid_loss_logger.close()
    test_loss_logger.close()