in recommenders/models/deeprec/models/base_model.py [0:0]
def fit(self, train_file, valid_file, test_file=None):
    """Fit the model with `train_file`, evaluating after every epoch.

    For each epoch in ``1..hparams.epochs`` the method runs ``self.train`` on
    every batch yielded by ``self.iterator.load_data_from_file(train_file)``,
    optionally saves a checkpoint, evaluates on `valid_file` with
    ``self.run_eval`` and prints the training loss, evaluation metrics and
    timings. If `test_file` is not None, it is evaluated and reported too.

    Args:
        train_file (str): training data set.
        valid_file (str): validation set.
        test_file (str): test set. Skipped when None.

    Returns:
        object: An instance of self.
    """

    def _format_metrics(metrics):
        # Render a metrics mapping as "name:value" pairs ordered by name.
        return ", ".join(
            str(name) + ":" + str(value)
            for name, value in sorted(metrics.items(), key=lambda x: x[0])
        )

    if self.hparams.write_tfevents:
        self.writer = tf.compat.v1.summary.FileWriter(
            self.hparams.SUMMARIES_DIR, self.sess.graph
        )

    train_sess = self.sess

    # try/finally so the event writer is closed even if training or
    # evaluation raises part-way through.
    try:
        for epoch in range(1, self.hparams.epochs + 1):
            step = 0
            self.hparams.current_epoch = epoch
            epoch_loss = 0
            train_start = time.time()

            # The iterator yields 3-tuples; only the first element is used here.
            for batch_data_input, _, _ in self.iterator.load_data_from_file(
                train_file
            ):
                step_result = self.train(train_sess, batch_data_input)
                (_, _, step_loss, step_data_loss, summary) = step_result
                if self.hparams.write_tfevents:
                    # NOTE(review): `step` restarts at 0 every epoch, so
                    # summaries from different epochs share step indices.
                    # Confirm whether a global step is intended.
                    self.writer.add_summary(summary, step)
                epoch_loss += step_loss
                step += 1
                if step % self.hparams.show_step == 0:
                    print(
                        "step {0:d} , total_loss: {1:.4f}, data_loss: {2:.4f}".format(
                            step, step_loss, step_data_loss
                        )
                    )

            train_time = time.time() - train_start

            if self.hparams.save_model:
                # exist_ok avoids the check-then-create race.
                os.makedirs(self.hparams.MODEL_DIR, exist_ok=True)
                if epoch % self.hparams.save_epoch == 0:
                    save_path_str = join(
                        self.hparams.MODEL_DIR, "epoch_" + str(epoch)
                    )
                    self.saver.save(sess=train_sess, save_path=save_path_str)

            eval_start = time.time()
            eval_res = self.run_eval(valid_file)

            # An epoch with no batches has no mean loss; report NaN instead
            # of failing with ZeroDivisionError.
            mean_loss = epoch_loss / step if step else float("nan")
            train_info = "logloss loss:" + str(mean_loss)
            eval_info = _format_metrics(eval_res)

            report = (
                "at epoch {0:d}".format(epoch)
                + "\ntrain info: "
                + train_info
                + "\neval info: "
                + eval_info
            )
            if test_file is not None:
                test_res = self.run_eval(test_file)
                report += "\ntest info: " + _format_metrics(test_res)

            # Evaluation time covers both validation and test evaluation.
            eval_time = time.time() - eval_start

            print(report)
            print(
                "at epoch {0:d} , train time: {1:.1f} eval time: {2:.1f}".format(
                    epoch, train_time, eval_time
                )
            )
    finally:
        if self.hparams.write_tfevents:
            self.writer.close()

    return self