# 基础教程/A2-神经网络基本原理/第7步 - 深度神经网络/src/ch15-DnnOptimization/MiniFramework/NeuralNet_4_1.py
import math
import time
# TrainingHistory_2_4, LossFunction_1_1 and OptimizerName are defined elsewhere
# in the MiniFramework package; their module-level imports are omitted here.

def train(self, dataReader, checkpoint=0.1, need_test=True):
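    """Mini-batch training loop.

    dataReader: data source exposing num_train, GetBatchTrainSamples() and Shuffle()
    checkpoint: fraction of an epoch between error/loss checks (0.1 = 10 checks per epoch)
    need_test:  if True, run Test(dataReader) after training and print the accuracy
    """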
    t0 = time.time()
    self.loss_trace = TrainingHistory_2_4()
    self.lossFunc = LossFunction_1_1(self.hp.net_type)
    # e.g. if num_train=200 and batch_size=10, then max_iteration=200/10=20
    if self.hp.batch_size == -1 or self.hp.batch_size > dataReader.num_train:
        # -1 (or an oversized value) means full-batch training
        self.hp.batch_size = dataReader.num_train
    # end if
    max_iteration = math.ceil(dataReader.num_train / self.hp.batch_size)
    # clamp to at least 1 so the modulo check below cannot divide by zero
    # when max_iteration * checkpoint < 1
    checkpoint_iteration = max(1, int(max_iteration * checkpoint))
    need_stop = False
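    # Worked example of the checkpoint arithmetic (numbers from the comment
    # above): num_train=200, batch_size=10, checkpoint=0.1 gives
    # max_iteration = ceil(200/10) = 20 and checkpoint_iteration = 2,
    # i.e. error/loss is checked every 2 iterations (10 times per epoch).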
    for epoch in range(self.hp.max_epoch):
        for iteration in range(max_iteration):
            # get the x and y values for one mini-batch of samples
            batch_x, batch_y = dataReader.GetBatchTrainSamples(self.hp.batch_size, iteration)
            # for optimizers which need the weights pre-updated before the
            # forward pass
            if self.hp.optimizer_name == OptimizerName.Nag:
                self.__pre_update()
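            # NAG (Nesterov accelerated gradient) evaluates the gradient at a
            # look-ahead point, i.e. the weights shifted by the momentum term,
            # rather than at the current weights; __pre_update() is assumed to
            # perform that shift so the forward/backward pass below sees the
            # look-ahead weights.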
            # forward pass: compute the network output z from batch_x
            self.__forward(batch_x, train=True)
            # backward pass: compute the gradients of w and b
            self.__backward(batch_x, batch_y)
            # finally apply the optimizer update to w and b
            self.__update()

            total_iteration = epoch * max_iteration + iteration
            if (total_iteration + 1) % checkpoint_iteration == 0:
                #self.save_parameters()
                need_stop = self.CheckErrorAndLoss(dataReader, batch_x, batch_y, epoch, total_iteration)
                if need_stop:
                    break
                # end if
        # end for
        # Saving parameters here would noticeably hurt performance because of
        # frequent disk I/O, and file writes could occasionally fail:
        #self.save_parameters()
        dataReader.Shuffle()  # reshuffle the training set once per epoch
        if need_stop:
            break
        # end if
    # end for
    # record the error/loss state once more after the final iteration
    self.CheckErrorAndLoss(dataReader, batch_x, batch_y, epoch, total_iteration)
    t1 = time.time()
    print("time used:", t1 - t0)
    self.save_parameters()
    if need_test:
        print("testing...")
        accuracy = self.Test(dataReader)
        print(accuracy)
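
# ---------------------------------------------------------------------------
# Standalone sketch of the same loop skeleton (epoch loop, mini-batch loop,
# periodic checkpoint, per-epoch shuffle, early stop) on a toy linear
# regression. This is NOT part of NeuralNet_4_1; every name below is invented
# for the illustration, and plain numpy stands in for the framework.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 1))
    Y = 3.0 * X + 1.0 + rng.normal(scale=0.1, size=(200, 1))

    w, b = 0.0, 0.0
    eta, batch_size, max_epoch, checkpoint = 0.1, 10, 50, 0.1
    max_iteration = math.ceil(X.shape[0] / batch_size)
    checkpoint_iteration = max(1, int(max_iteration * checkpoint))
    need_stop = False
    for epoch in range(max_epoch):
        for iteration in range(max_iteration):
            start = iteration * batch_size
            bx = X[start:start + batch_size]
            by = Y[start:start + batch_size]
            z = w * bx + b                        # forward pass
            dz = (z - by) / bx.shape[0]           # gradient of the MSE/2 loss
            w -= eta * float(np.sum(dz * bx))     # update w
            b -= eta * float(np.sum(dz))          # update b
            total_iteration = epoch * max_iteration + iteration
            if (total_iteration + 1) % checkpoint_iteration == 0:
                loss = float(np.mean((w * X + b - Y) ** 2) / 2)
                if loss < 1e-2:                   # simple early-stop rule
                    need_stop = True
                    break
        perm = rng.permutation(X.shape[0])        # reshuffle once per epoch
        X, Y = X[perm], Y[perm]
        if need_stop:
            break
    print("w=%.3f  b=%.3f" % (w, b))              # should approach w=3, b=1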