in julia/src/model.jl [363:608]
function fit(self::FeedForward, optimizer::AbstractOptimizer, data::AbstractDataProvider;
             kwargs...)
  opts = TrainingOptions(; kwargs...)

  opts.verbosity >= 1 && info("Start training on $(self.ctx)")

  batch_size = get_batch_size(data)
  num_dev    = length(self.ctx)
  slices     = _split_inputs(batch_size, num_dev)
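  # initialize model parameters, collecting the argument, trainable-parameter
  # and auxiliary-state names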
opts.verbosity >= 2 && info("Initializing parameters...")
arg_names, param_names, aux_names = _init_model(self, data, opts.initializer, opts.force_init)
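  # set up the KVStore used to synchronize parameters across devices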
  kvstore = opts.kvstore
  if isa(kvstore, Symbol)
    opts.verbosity >= 2 && info("Creating KVStore...")
    kvstore = _create_kvstore(kvstore, length(self.ctx), self.arg_params, opts.verbosity)
  end
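  # a local_allreduce KVStore (or no KVStore at all) only aggregates
  # gradients; the parameter updates then run locally on each device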
  update_on_kvstore = true
  if isa(kvstore, Void) || ismatch(r"local_allreduce", string(get_type(kvstore)))
    update_on_kvstore = false
  end
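  # collect the parameters frozen via the symbol attribute :name_grad => "freeze"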
  freeze_names = Symbol[]
  for (attr, value) in list_all_attr(self.arch)
    sattr = string(attr)
    if endswith(sattr, "grad") && value == "freeze"
      push!(freeze_names, Symbol(sattr[1:end-5]))  # strip the trailing "_grad"
    end
  end
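  # freeze_idx must line up with idx in the parameter-update loop below;
  # frozen parameters get GRAD_NOP so no gradient is computed for them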
  freeze_idx = filter(i -> in(param_names[i], freeze_names), 1:length(param_names))

  grad_req = Dict{Symbol,GRAD_REQ}()
  for param in param_names
    if in(param, freeze_names)
      grad_req[param] = GRAD_NOP
    else
      grad_req[param] = GRAD_WRITE
    end
  end
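  # bind one executor per device; each one consumes its slice of the mini-batch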
  train_execs = Array{Executor}(num_dev)
  for i = 1:num_dev
    data_shapes  = Dict(map(x -> x[1] => tuple(x[2][1:end-1]..., length(slices[i])), provide_data(data)))
    label_shapes = Dict(map(x -> x[1] => tuple(x[2][1:end-1]..., length(slices[i])), provide_label(data)))
    train_execs[i] = simple_bind(self.arch, self.ctx[i]; grad_req=grad_req, data_shapes..., label_shapes...)

    dbg_str = mx.debug_str(train_execs[i])
    opts.verbosity >= 2 && info(string("TempSpace: ", split(dbg_str, ['\n'])[end-2]..., " on ", self.ctx[i]))

    copy_params_from(train_execs[i], self.arg_params, self.aux_params)
  end
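  # set up the input data structures: for every input and label, a list of
  # (batch-slice, device NDArray) pairs, one entry per device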
  data_names  = [x[1] for x in provide_data(data)]
  label_names = [x[1] for x in provide_label(data)]

  data_arrays  = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i, exec) in enumerate(train_execs)]
                  for name in data_names]
  label_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i, exec) in enumerate(train_execs)]
                  for name in label_names]

  param_idx = filter(i -> in(arg_names[i], param_names), 1:length(arg_names))

  param_arrays = [NDArray[exec.arg_arrays[i] for exec in train_execs] for i in param_idx]
  grad_arrays  = [NDArray[exec.grad_arrays[i] for exec in train_execs] for i in param_idx]
  aux_arrays   = [NDArray[exec.aux_arrays[i] for exec in train_execs] for i = 1:length(aux_names)]
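  # default the gradient rescaling to 1/batch_size unless the user set it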
  op_state = OptimizationState(batch_size)
  iszero(optimizer.scale) && (optimizer.scale = 1 / batch_size)

  if !update_on_kvstore
    updater = getupdater(optimizer)
  end
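  # seed the KVStore with the initial weights; when it also performs the
  # updates, hand it the optimizer and pull the weights back to the devices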
  if !isa(kvstore, Void)
    if update_on_kvstore
      set_optimizer(kvstore, optimizer)
    end

    opts.verbosity >= 2 && info("Initializing KVStore...")
    for idx = 1:length(param_arrays)
      param_on_devs = param_arrays[idx]

      init!(kvstore, idx, self.arg_params[param_names[idx]])

      if update_on_kvstore
        pull!(kvstore, idx, param_on_devs, priority=-idx)
      end
    end
  end
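  # allocate CPU buffers for outputs and labels; the evaluation metric is
  # computed on the CPU over the full batch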
  output_shapes = [tuple(size(x)[1:end-1]..., batch_size) for x in train_execs[1].outputs]
  cpu_dev = Context(CPU)
  cpu_output_arrays = [empty(shape, cpu_dev) for shape in output_shapes]
  cpu_label_arrays  = [empty(shape, cpu_dev) for (name, shape) in provide_label(data)]
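  # invoke callbacks on epoch 0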
  _invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback)

  opts.verbosity >= 2 && info("Start training...")
  for i_epoch = 1:opts.n_epoch
    time_start = time()
    reset!(opts.eval_metric)

    op_state.curr_epoch = i_epoch
    op_state.curr_batch = 0
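    # invoke callbacks on iteration 0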
    _invoke_callbacks(self, opts.callbacks, op_state, AbstractBatchCallback)

    for batch in eachbatch(data)
      load_data!(data, batch, data_arrays)
      load_label!(data, batch, label_arrays)
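      # forward on each device, copy the outputs to the CPU buffers for the
      # evaluation metric, then run backward to compute gradients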
      for (texec, islice) in zip(train_execs, slices)
        forward(texec, is_train=true)

        for (cpu_out, dev_out) in zip(cpu_output_arrays, texec.outputs)
          copy!(slice(cpu_out, islice), dev_out)
        end

        backward(texec)
      end

      op_state.curr_iter  += 1
      op_state.curr_batch += 1
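      # update parameters, skipping the frozen ones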
      for idx = 1:length(param_names)
        if in(idx, freeze_idx)
          continue
        end
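        # synchronize gradients: push them to the KVStore, then either pull
        # back the updated weights, or pull back the aggregated gradients
        # and update locally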
        if !isa(kvstore, Void)
          push!(kvstore, idx, grad_arrays[idx], priority=-idx)
          if update_on_kvstore
            pull!(kvstore, idx, param_arrays[idx], priority=-idx)
          else
            pull!(kvstore, idx, grad_arrays[idx], priority=-idx)
          end
        end
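        # manual update on every device; fake_idx gives the updater a
        # distinct state entry per (parameter, device) pair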
        if !update_on_kvstore
          for i_dev = 1:num_dev
            fake_idx = idx * num_dev + i_dev
            updater(fake_idx, grad_arrays[idx][i_dev], param_arrays[idx][i_dev])
          end
        end
      end
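      # per-batch learning-rate decay, then the per-iteration callbacks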
      opts.η_decay == :batch && update!(optimizer.η_sched)

      _invoke_callbacks(self, opts.callbacks, op_state, AbstractBatchCallback)
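      # accumulate the evaluation metric on this training batch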
      load_label!(data, batch, cpu_label_arrays)
      update!(opts.eval_metric, cpu_label_arrays, cpu_output_arrays)
    end # end of one epoch

    time_stop = time()
    metric = get(opts.eval_metric)

    opts.verbosity >= 2 && info(format("== Epoch {1:0>3d}/{2:0>3d} ==========", i_epoch, opts.n_epoch))
    if opts.verbosity >= 3
      info("## Training summary")
      for (name, value) in metric
        info(format("{1:>18s} = {2:.4f}", string(name), value))
      end
      info(format("{1:>18s} = {2:.4f} seconds", "time", time_stop - time_start))
    end
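    # evaluate on the validation set, reusing the training executors; this
    # is why the validation provider must supply the same batch size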
    if !isa(opts.eval_data, Void)
      @assert(get_batch_size(opts.eval_data) == batch_size)

      reset!(opts.eval_metric)
      for batch in eachbatch(opts.eval_data)
        load_data!(opts.eval_data, batch, data_arrays)

        # forward only, in inference mode (is_train=false disables dropout
        # and makes batch-norm use its running statistics); there is no
        # backward pass during evaluation
        for (texec, islice) in zip(train_execs, slices)
          forward(texec, is_train=false)

          for (cpu_out, dev_out) in zip(cpu_output_arrays, texec.outputs)
            copy!(slice(cpu_out, islice), dev_out)
          end
        end

        load_label!(opts.eval_data, batch, cpu_label_arrays)
        update!(opts.eval_metric, cpu_label_arrays, cpu_output_arrays)
      end

      if opts.verbosity >= 3
        info("## Validation summary")
        for (name, value) in get(opts.eval_metric)
          info(format("{1:>18s} = {2:.4f}", string(name), value))
        end
      end
    end
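    # on the last epoch, or on every epoch when an epoch callback is
    # registered, copy the parameters back to the CPU, averaging across devices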
    if i_epoch == opts.n_epoch || any(x -> isa(x, AbstractEpochCallback), opts.callbacks)
      for (name, weights) in zip(param_names, param_arrays)
        weight = +([copy(w, cpu()) for w in weights]...) / length(weights)
        copy!(self.arg_params[name], weight)
      end
      for (name, aux_devs) in zip(aux_names, aux_arrays)
        aux_avg = +([copy(aux, cpu()) for aux in aux_devs]...) / length(aux_devs)
        copy!(self.aux_params[name], aux_avg)
      end
    end
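    # per-epoch learning-rate decay, then the epoch callbacks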
    opts.η_decay == :epoch && update!(optimizer.η_sched)

    _invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback; metric=metric)
  end # end of all epochs

  opts.verbosity >= 1 && info("Finish training on $(self.ctx)")
  nothing
end