in example/autoencoder/solver.py [0:0]
def solve(self, xpu, sym, args, args_grad, auxs,
          data_iter, begin_iter, end_iter, args_lrmult=None, debug=False):
    """Bind ``sym`` and run training iterations ``begin_iter`` .. ``end_iter - 1``.

    Input buffers are allocated on ``xpu`` for every entry of
    ``data_iter.provide_data + data_iter.provide_label`` and merged into a
    copy of ``args`` (the caller's dict is not modified). Each iteration
    copies one batch into those buffers, runs forward and backward, applies
    ``self.updater`` to every argument that has a gradient array, and then
    reports to ``self.metric`` / ``self.monitor`` when they are set.

    Parameters
    ----------
    xpu : context used for the input buffers and for ``sym.bind``.
    sym : symbol to bind. With ``debug`` set it is replaced by a group
        built from ``sym.get_internals()``.
    args : mapping of argument name -> array, forwarded to ``sym.bind``
        together with the input buffers.
    args_grad : gradient arrays forwarded to ``sym.bind``.
    auxs : auxiliary states forwarded to ``sym.bind``.
    data_iter : object exposing ``provide_data``, ``provide_label``,
        ``reset()`` and ``next()``; batches must expose ``data`` and
        ``label`` lists.
    begin_iter, end_iter : iteration indices, used as
        ``range(begin_iter, end_iter)``.
    args_lrmult : mapping passed to ``self.optimizer.set_lr_mult``.
        ``None`` (the default) is treated as an empty dict.
    debug : when true, internal outputs that are not already in ``args``
        are bound as extra outputs (wrapped in ``BlockGrad`` unless they are
        one of the original outputs) and handed to
        ``self.monitor.forward_end``.

    Returns
    -------
    None. Returns early if ``self.iter_start_callback(i)`` or
    ``self.iter_end_callback(i)`` returns a truthy value.
    """
    # A fresh dict per call avoids sharing one mutable default object
    # between calls (and with the optimizer it is handed to).
    if args_lrmult is None:
        args_lrmult = {}

    input_desc = data_iter.provide_data + data_iter.provide_label
    input_names = [name for name, _ in input_desc]
    input_buffs = [mx.nd.empty(shape, ctx=xpu) for _, shape in input_desc]
    # New dict: the executor sees the caller's arrays plus the input buffers.
    args = dict(args, **dict(zip(input_names, input_buffs)))

    # Remember the real outputs before ``sym`` is possibly replaced below.
    output_names = sym.list_outputs()
    if debug:
        internals = sym.get_internals()
        sym_group = []
        for idx, blob_name in enumerate(internals.list_outputs()):
            if blob_name in args:
                # Arguments/inputs are already available through ``args``.
                continue
            blob = internals[idx]
            if blob_name not in output_names:
                # Expose the intermediate blob as an output without letting
                # it contribute gradients.
                blob = mx.symbol.BlockGrad(blob, name=blob_name)
            sym_group.append(blob)
        sym = mx.symbol.Group(sym_group)
    exe = sym.bind(xpu, args=args, args_grad=args_grad, aux_states=auxs)

    arg_names = sym.list_arguments()
    assert len(arg_names) == len(exe.grad_arrays)
    # Only arguments that actually have a gradient array get updated.
    update_dict = {name: grad
                   for name, grad in zip(arg_names, exe.grad_arrays)
                   if grad is not None}
    batch_size = input_buffs[0].shape[0]
    self.optimizer.rescale_grad = 1.0 / batch_size
    self.optimizer.set_lr_mult(args_lrmult)

    output_dict = {}    # original outputs: name -> executor output array
    output_buff = {}    # CPU-side copies of those outputs
    internal_dict = dict(zip(input_names, input_buffs))  # passed to monitor
    for key, arr in zip(sym.list_outputs(), exe.outputs):
        if key in output_names:
            output_dict[key] = arr
            output_buff[key] = mx.nd.empty(arr.shape, ctx=mx.cpu())
        else:
            internal_dict[key] = arr

    data_iter.reset()
    for i in range(begin_iter, end_iter):
        if self.iter_start_callback is not None:
            if self.iter_start_callback(i):
                return

        try:
            batch = data_iter.next()
        except StopIteration:
            # End of the data: rewind and keep going so the iteration count
            # may span several passes. Any other error propagates instead of
            # being silently swallowed.
            data_iter.reset()
            batch = data_iter.next()

        for data, buff in zip(batch.data + batch.label, input_buffs):
            data.copyto(buff)

        exe.forward(is_train=True)
        if self.monitor is not None:
            self.monitor.forward_end(i, internal_dict)
        for key, arr in output_dict.items():
            arr.copyto(output_buff[key])

        exe.backward()
        for key, grad in update_dict.items():
            self.updater(key, grad, args[key])

        if self.metric is not None:
            # Last input buffer vs. first original output -- presumably
            # label vs. prediction; confirm against the metric in use.
            self.metric.update([input_buffs[-1]],
                               [output_buff[output_names[0]]])

        if self.monitor is not None:
            self.monitor.backward_end(i, args, update_dict, self.metric)

        if self.iter_end_callback is not None:
            if self.iter_end_callback(i):
                return

        # Wait for the first output before starting the next iteration.
        # NOTE(review): source indentation was lost; this is kept inside the
        # loop as a per-iteration sync -- confirm against the original file.
        exe.outputs[0].wait_to_read()