function fit()

in julia/src/model.jl [363:608]


function fit(self::FeedForward, optimizer::AbstractOptimizer, data::AbstractDataProvider;
             kwargs...)
  opts = TrainingOptions(; kwargs...)

  opts.verbosity >= 1 && info("Start training on $(self.ctx)")

  batch_size  = get_batch_size(data)
  num_dev     = length(self.ctx)
  slices      = _split_inputs(batch_size, num_dev)
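  # illustrative only (not in the original source): with batch_size = 128
  # and num_dev = 2, _split_inputs yields the ranges [1:64, 65:128], so
  # each device processes one slice of every mini-batch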

  # initialize parameters
  opts.verbosity >= 2 && info("Initializing parameters...")
  arg_names, param_names, aux_names = _init_model(self, data, opts.initializer, opts.force_init)

  # set up kvstore
  kvstore = opts.kvstore
  if isa(kvstore, Symbol)
    opts.verbosity >= 2 && info("Creating KVStore...")
    kvstore = _create_kvstore(kvstore, length(self.ctx), self.arg_params, opts.verbosity)
  end

  update_on_kvstore = true
  if isa(kvstore, Void) || ismatch(r"local_allreduce", string(get_type(kvstore)))
    update_on_kvstore = false
  end
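  # note (added): when update_on_kvstore is true, the kvstore applies the
  # optimizer itself and devices pull back updated weights; otherwise the
  # kvstore (if any) only aggregates gradients and the update is applied
  # locally on each device via `updater` below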

  # collect "grad" attributes to allow freezing parameters
  freeze_names = Symbol[]
  for (attr, value) in list_all_attr(self.arch)
    sattr = string(attr)
    if endswith(sattr, "grad") && value == "freeze"
      push!(freeze_names, Symbol(sattr[1:end-5]))
    end
  end
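  # illustrative only: an attribute such as :fc1_weight_grad => "freeze" on
  # the symbolic graph marks the (hypothetical) parameter :fc1_weight as frozen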
  # freeze_idx indexes into param_names, matching the update loop below
  # (which runs idx = 1:length(param_names))
  freeze_idx = filter(i -> in(param_names[i], freeze_names), 1:length(param_names))

  # set up grad_req as a dictionary
  grad_req = Dict{Symbol,GRAD_REQ}()
  for param in param_names
    if in(param, freeze_names)
      grad_req[param] = GRAD_NOP
    else
      grad_req[param] = GRAD_WRITE
    end
  end

  train_execs = Array{Executor}(num_dev)
  for i = 1:num_dev
    data_shapes = Dict(map((x) -> x[1] => tuple(x[2][1:end-1]...,length(slices[i])), provide_data(data)))
    label_shapes = Dict(map((x) -> x[1] => tuple(x[2][1:end-1]...,length(slices[i])), provide_label(data)))
    train_execs[i] = simple_bind(self.arch, self.ctx[i]; grad_req=grad_req, data_shapes..., label_shapes...)
    dbg_str = mx.debug_str(train_execs[i])
    opts.verbosity >= 2 && info(string("TempSpace: ", split(dbg_str, ['\n'])[end-2]..., " on ", self.ctx[i]))

    copy_params_from(train_execs[i], self.arg_params, self.aux_params)
  end

  # set up input data structures
  data_names   = [x[1] for x in provide_data(data)]
  label_names  = [x[1] for x in provide_label(data)]

  data_arrays  = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(train_execs)]
                  for name in data_names]
  label_arrays = [SlicedNDArray[(slices[i], exec.arg_dict[name]) for (i,exec) in enumerate(train_execs)]
                  for name in label_names]

  param_idx    = filter(i -> in(arg_names[i], param_names), 1:length(arg_names))

  param_arrays = [NDArray[exec.arg_arrays[i] for exec in train_execs] for i in param_idx]
  grad_arrays  = [NDArray[exec.grad_arrays[i] for exec in train_execs] for i in param_idx]
  aux_arrays   = [NDArray[exec.aux_arrays[i] for exec in train_execs] for i = 1:length(aux_names)]
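  # layout (added note): param_arrays[p][d], grad_arrays[p][d] and
  # aux_arrays[a][d] hold one NDArray per device, so the first index runs
  # over names and the second over devices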

  op_state = OptimizationState(batch_size)
  # set up the gradient rescaling if the user has not set it
  iszero(optimizer.scale) && (optimizer.scale = 1 / batch_size)
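  # e.g. (added): with batch_size = 128 and optimizer.scale left at 0,
  # gradients are rescaled by 1/128, i.e. averaged over the mini-batch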

  if !update_on_kvstore
    updater = getupdater(optimizer)
  end

  if !isa(kvstore, Void)
    if update_on_kvstore
      set_optimizer(kvstore, optimizer)
    end

    opts.verbosity >= 2 && info("Initializing KVStore...")
    # initialize the kvstore with the current parameter values
    for idx = 1:length(param_arrays)
      param_on_devs = param_arrays[idx]

      init!(kvstore, idx, self.arg_params[param_names[idx]])

      if update_on_kvstore
        # pull back the freshly initialized weights to every device
        pull!(kvstore, idx, param_on_devs, priority=-idx)
      end
    end
  end

  # set up outputs and labels in CPU for the evaluation metric
  output_shapes = [tuple(size(x)[1:end-1]...,batch_size) for x in train_execs[1].outputs]
  cpu_dev = Context(CPU)
  cpu_output_arrays = [empty(shape, cpu_dev) for shape in output_shapes]
  cpu_label_arrays  = [empty(shape, cpu_dev) for (name,shape) in provide_label(data)]

  # invoke callbacks on epoch 0
  _invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback)

  opts.verbosity >= 2 && info("Start training...")
  for i_epoch = 1:opts.n_epoch
    time_start = time()
    reset!(opts.eval_metric)

    op_state.curr_epoch = i_epoch
    op_state.curr_batch = 0

    # invoke callbacks on iteration 0
    _invoke_callbacks(self, opts.callbacks, op_state, AbstractBatchCallback)

    for batch in eachbatch(data)
      load_data!(data, batch, data_arrays)
      load_label!(data, batch, label_arrays)

      # forward and backward
      for (texec, islice) in zip(train_execs, slices)
        forward(texec, is_train=true)

        # copy outputs into CPU NDArrays for the evaluation metric
        for (cpu_out, dev_out) in zip(cpu_output_arrays, texec.outputs)
          copy!(slice(cpu_out, islice), dev_out)
        end

        backward(texec)
      end

      op_state.curr_iter  += 1
      op_state.curr_batch += 1

      # update parameters
      for idx = 1:length(param_names)
        if in(idx, freeze_idx)
          continue # skip the update entirely for frozen parameters
        end

        # gradient synchronization via the kvstore
        if !isa(kvstore, Void)
          # push gradients; priority is the negative index
          push!(kvstore, idx, grad_arrays[idx], priority=-idx)
          if update_on_kvstore
            # pull back the updated weights, keeping all devices consistent
            pull!(kvstore, idx, param_arrays[idx], priority=-idx)
          else
            # pull back the aggregated gradients into the same locations
            pull!(kvstore, idx, grad_arrays[idx], priority=-idx)
          end
        end

        if !update_on_kvstore
          # apply the updater manually on each device
          for i_dev = 1:num_dev
            # use a fake index so the updater keeps separate optimizer
            # state for each (parameter, device) pair
            fake_idx = idx * num_dev + i_dev
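            # e.g. (added): with num_dev = 2, parameter idx = 3 yields
            # fake_idx 7 and 8, one updater state per device copy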
            updater(fake_idx, grad_arrays[idx][i_dev], param_arrays[idx][i_dev])
          end
        end
      end

      # trigger per-batch learning rate decay
      opts.η_decay == :batch && update!(optimizer.η_sched)

      # invoke callbacks after finishing each iteration
      _invoke_callbacks(self, opts.callbacks, op_state, AbstractBatchCallback)

      # update the evaluation metric on the training batch
      load_label!(data, batch, cpu_label_arrays)
      update!(opts.eval_metric, cpu_label_arrays, cpu_output_arrays)
    end # end of one epoch

    time_stop = time()
    metric = get(opts.eval_metric)
    opts.verbosity >= 2 && info(format("== Epoch {1:0>3d}/{2:0>3d} ==========", i_epoch, opts.n_epoch))
    if opts.verbosity >= 3
        info("## Training summary")
        for (name, value) in metric
            info(format("{1:>18s} = {2:.4f}", string(name), value))
        end
        info(format("{1:>18s} = {2:.4f} seconds", "time", time_stop-time_start))
    end

    # evaluation on the validation set
    if !isa(opts.eval_data, Void)
      # because we re-use the memory allocated for the training network,
      # the validation batch size must match the training batch size
      @assert(get_batch_size(opts.eval_data) == batch_size)

      reset!(opts.eval_metric)
      for batch in eachbatch(opts.eval_data)
        load_data!(opts.eval_data, batch, data_arrays)

        # forward pass only; no backward needed for evaluation
        for (texec, islice) in zip(train_execs, slices)
          forward(texec, is_train=true)

          # copy outputs into CPU NDArrays for the evaluation metric
          for (cpu_out, dev_out) in zip(cpu_output_arrays, texec.outputs)
            copy!(slice(cpu_out, islice), dev_out)
          end
        end
        load_label!(opts.eval_data, batch, cpu_label_arrays)
        update!(opts.eval_metric, cpu_label_arrays, cpu_output_arrays)
      end

      if opts.verbosity >= 3
          info("## Validation summary")
          for (name, value) in get(opts.eval_metric)
            info(format("{1:>18s} = {2:.4f}", string(name), value))
          end
      end
    end

    if i_epoch == opts.n_epoch || any(x->isa(x, AbstractEpochCallback), opts.callbacks)
      # copy parameters back to CPU
      for (name, weights) in zip(param_names, param_arrays)
        # average the per-device copies of each parameter
        weight = +([copy(w, cpu()) for w in weights]...) / length(weights)
        copy!(self.arg_params[name], weight)
      end
      for (name, aux_devs) in zip(aux_names, aux_arrays)
        aux_avg = +([copy(aux, cpu()) for aux in aux_devs]...) / length(aux_devs)
        copy!(self.aux_params[name], aux_avg)
      end
    end

    # trigger per-epoch learning rate decay
    opts.η_decay == :epoch && update!(optimizer.η_sched)

    _invoke_callbacks(self, opts.callbacks, op_state, AbstractEpochCallback; metric=metric)
  end # end of all epochs

  opts.verbosity >= 1 && info("Finish training on $(self.ctx)")
  nothing
end
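
For context, a minimal sketch of how this method is typically invoked follows. The network, data providers, and hyperparameters (net, train_provider, eval_provider, the layer sizes, learning rate, and epoch count) are placeholders rather than anything taken from model.jl; the mx.* constructors follow the usual MXNet.jl API of the same era as this code (the η-based optimizer keywords).

using MXNet

# a small MLP; layer sizes are arbitrary placeholder values
net = @mx.chain mx.Variable(:data)             =>
      mx.FullyConnected(name=:fc1, num_hidden=64) =>
      mx.Activation(name=:relu1, act_type=:relu)  =>
      mx.FullyConnected(name=:fc2, num_hidden=10) =>
      mx.SoftmaxOutput(name=:softmax)

model     = mx.FeedForward(net, context=mx.cpu())
optimizer = mx.SGD(η=0.01, μ=0.9)  # scale left at 0, so fit() sets it to 1/batch_size

# train_provider / eval_provider are AbstractDataProvider instances, e.g.
# mx.ArrayDataProvider(:data => X, :softmax_label => y, batch_size=128)
mx.fit(model, optimizer, train_provider,
       n_epoch=10,
       eval_data=eval_provider,
       callbacks=[mx.speedometer()])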