in python/flexflow/keras/models/base_model.py [0:0]
def _train(self, epochs, callbacks, eval=False):
    """Run the training (or evaluation) loop over the bound dataloaders.

    Each epoch resets every input dataloader, the label dataloader and the
    model's metrics, then steps through ``int(num_samples // batch_size)``
    full batches (a trailing partial batch is not visited). Every batch runs
    between ``begin_trace``/``end_trace`` calls keyed by a tracing id that is
    incremented once per call of this method.

    Args:
      epochs: maximum number of epochs to run.
      callbacks: iterable of callback objects, or None. Callbacks receive
        ``set_model``, ``on_train_begin``, ``on_epoch_begin``,
        ``on_batch_begin``, ``on_batch_end``, ``on_epoch_end`` and
        ``on_train_end`` notifications. If any ``on_epoch_end`` returns a
        value equal to ``True``, no further epochs are started (the remaining
        callbacks still receive ``on_epoch_end`` for the current epoch).
      eval: if True, each batch only runs ``forward`` followed by
        ``compute_metrics``; otherwise it runs ``forward``,
        ``zero_gradients``, ``backward`` and ``update``.

    Prints a summary line with the number of epochs actually run, the elapsed
    time, iterations per epoch, the sample count and the throughput.
    """
    # Normalize once so every notification site is a plain loop.
    cbs = [] if callbacks is None else callbacks

    for cb in cbs:
        cb.set_model(self)
    for cb in cbs:
        cb.on_train_begin()

    ts_start = self._ffconfig.get_current_time()
    # The batch count depends only on dataset and batch size, so compute it
    # once. Doing it before the loop also keeps it defined when ``epochs`` is
    # 0, where it used to be unbound at the summary print below.
    iterations = int(self._num_samples // self._ffconfig.batch_size)
    epoch = 0
    stop_requested = False
    self.__tracing_id += 1

    while epoch < epochs and not stop_requested:
        for cb in cbs:
            cb.on_epoch_begin(epoch)

        for dataloader in self._input_dataloaders:
            dataloader.reset()
        self._label_dataloader.reset()
        self._ffmodel.reset_metrics()

        for batch_idx in range(iterations):
            for cb in cbs:
                cb.on_batch_begin(batch_idx)

            for dataloader in self._input_dataloaders:
                dataloader.next_batch(self._ffmodel)
            self._label_dataloader.next_batch(self._ffmodel)

            self._ffconfig.begin_trace(self.__tracing_id)
            self._ffmodel.forward()
            if eval == False:
                self._ffmodel.zero_gradients()
                self._ffmodel.backward()
                self._ffmodel.update()
            else:
                self._ffmodel.compute_metrics()
            self._ffconfig.end_trace(self.__tracing_id)

            for cb in cbs:
                cb.on_batch_end(batch_idx)

        for cb in cbs:
            early_stop = cb.on_epoch_end(epoch)
            # Kept as an equality test (not plain truthiness) to preserve the
            # original contract for non-bool return values.
            if early_stop == True:  # noqa: E712
                print("Accuracy reaches, now early stop, epoch: %d" % (epoch))
                stop_requested = True
        epoch += 1

    ts_end = self._ffconfig.get_current_time()
    # Timestamps are scaled by 1e-6 to seconds (presumably microseconds).
    run_time = 1e-6 * (ts_end - ts_start)
    # ``epoch`` now equals the number of epochs actually executed, which is
    # smaller than ``epochs`` after an early stop; using the requested count
    # overstated throughput.
    throughput = self._num_samples * epoch / run_time if run_time > 0 else 0.0
    print(
        "epochs %d, ELAPSED TIME = %.4fs, interations %d, samples %d, THROUGHPUT = %.2f samples/s\n"
        % (epoch, run_time, iterations, self._num_samples, throughput)
    )

    for cb in cbs:
        cb.on_train_end()