perl-package/AI-MXNet/lib/AI/MXNet/Module.pm

## TODO ##
## this class is here because of https://github.com/gfx/p5-Mouse/pull/67
## once the 2.4.7 version of Mouse lands in Ubuntu for the affected Perl versions,
## these accessors should be merged into the main class
package AI::MXNet::Module::Private;
use Mouse;
has [qw/_param_names _fixed_param_names
        _aux_names _data_names _label_names _state_names
        _output_names _arg_params _aux_params
        _params_dirty _optimizer _kvstore _update_on_kvstore
        _updater _work_load_list _preload_opt_states _exec_group
        _data_shapes _label_shapes _context _grad_req/
] => (is => 'rw', init_arg => undef);

package AI::MXNet::Module;
use AI::MXNet::Base;
use AI::MXNet::Function::Parameters;
use List::Util qw(max);
use Data::Dumper ();
use Mouse;

func _create_kvstore(
    Maybe[Str|AI::MXNet::KVStore] $kvstore,
    Int                           $num_device,
    HashRef[AI::MXNet::NDArray]   $arg_params
)
{
    my $update_on_kvstore = 1;
    my $kv;
    if(defined $kvstore)
    {
        if(blessed $kvstore)
        {
            $kv = $kvstore;
        }
        else
        {
            # create kvstore using the string type
            if($num_device == 1 and $kvstore !~ /dist/)
            {
                # no need to use kv for single device and single machine
            }
            else
            {
                $kv = AI::MXNet::KVStore->create($kvstore);
                if($kvstore eq 'local')
                {
                    # automatically select a proper local kvstore type
                    my $max_size = max(map { product(@{ $_->shape }) } values %{ $arg_params });
                    if($max_size > 1024 * 1024 * 16)
                    {
                        $update_on_kvstore = 0;
                    }
                }
            }
        }
    }
    $update_on_kvstore = 0 if not $kv;
    return ($kv, $update_on_kvstore);
}

func _initialize_kvstore(
    AI::MXNet::KVStore          :$kvstore,
    HashRef[AI::MXNet::NDArray] :$arg_params,
    ArrayRef[Str]               :$param_names,
    Bool                        :$update_on_kvstore,
    ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] :$param_arrays
)
{
    enumerate(sub{
        my ($idx, $param_on_devs) = @_;
        my $name = $param_names->[$idx];
        $kvstore->init($name, $arg_params->{ $name });
        if($update_on_kvstore)
        {
            $kvstore->pull($name, out => $param_on_devs, priority => -$idx);
        }
    }, $param_arrays);
}

func _update_params_on_kvstore(
    ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] $param_arrays,
    ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] $grad_arrays,
    AI::MXNet::KVStore                                                  $kvstore,
    ArrayRef[Str]                                                       $param_names
)
{
    enumerate(sub{
        my ($index, $arg_list, $grad_list) = @_;
        if(ref $grad_list eq 'ARRAY' and not defined $grad_list->[0])
        {
            return;
        }
        my $name = $param_names->[$index];
        # push gradient, priority is negative index
        $kvstore->push($name, $grad_list, priority => -$index);
        # pull back the weights
        $kvstore->pull($name, out => $arg_list, priority => -$index);
    }, $param_arrays, $grad_arrays);
}

func _update_params(
    ArrayRef[ArrayRef[AI::MXNet::NDArray]] $param_arrays,
    ArrayRef[ArrayRef[AI::MXNet::NDArray]] $grad_arrays,
    AI::MXNet::Updater                     $updater,
    Int                                    $num_device,
    Maybe[AI::MXNet::KVStore]              $kvstore=,
    Maybe[ArrayRef[Str]]                   $param_names=
)
{
    enumerate(sub{
        my ($index, $arg_list, $grad_list) = @_;
        if(not defined $grad_list->[0])
        {
            return;
        }
        if($kvstore)
        {
            my $name = $param_names->[$index];
            # push gradient, priority is negative index
            $kvstore->push($name, $grad_list, priority => -$index);
            # pull back the summed gradients, to the same locations
            $kvstore->pull($name, out => $grad_list, priority => -$index);
        }
        enumerate(sub {
            my ($k, $w, $g) = @_;
            # faked an index here, to make optimizer create diff
            # state for the same index but on diff devs, TODO(mli)
            # use a better solution later
            &{$updater}($index*$num_device+$k, $g, $w);
        }, $arg_list, $grad_list);
    }, $param_arrays, $grad_arrays);
}
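
=head2 _create_kvstore (usage note)

    An illustrative sketch of the kvstore selection performed by the internal
    helper above; the parameter name and shape are assumptions, and calling
    the fully qualified private function is shown here only to clarify the
    heuristic, not as a public API:

        use AI::MXNet qw(mx);
        # one machine, two devices, 'local' kvstore: updates stay on the
        # kvstore unless the largest parameter exceeds 16M elements
        my ($kv, $update_on_kvstore) = AI::MXNet::Module::_create_kvstore(
            'local', 2, { fc1_weight => mx->nd->zeros([512, 512]) }
        );
        # 512*512 <= 16*1024*1024, so $update_on_kvstore stays 1 here

=cut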

method load_checkpoint(Str $prefix, Int $epoch)
{
    my $symbol = AI::MXNet::Symbol->load("$prefix-symbol.json");
    my %save_dict = %{ AI::MXNet::NDArray->load(sprintf('%s-%04d.params', $prefix, $epoch)) };
    my %arg_params;
    my %aux_params;
    while(my ($k, $v) = each %save_dict)
    {
        my ($tp, $name) = split(/:/, $k, 2);
        if($tp eq 'arg')
        {
            $arg_params{$name} = $v;
        }
        if($tp eq 'aux')
        {
            $aux_params{$name} = $v;
        }
    }
    return ($symbol, \%arg_params, \%aux_params);
}

=head1 NAME

AI::MXNet::Module - FeedForward interface of MXNet.
See AI::MXNet::Module::Base for the details.

=cut

extends 'AI::MXNet::Module::Base';
has '_symbol'           => (is => 'ro', init_arg => 'symbol', isa => 'AI::MXNet::Symbol', required => 1);
has '_data_names'       => (is => 'ro', init_arg => 'data_names', isa => 'ArrayRef[Str]');
has '_label_names'      => (is => 'ro', init_arg => 'label_names', isa => 'Maybe[ArrayRef[Str]]');
has 'work_load_list'    => (is => 'rw', isa => 'Maybe[ArrayRef[Int]]');
has 'fixed_param_names' => (is => 'rw', isa => 'Maybe[ArrayRef[Str]]');
has 'state_names'       => (is => 'rw', isa => 'Maybe[ArrayRef[Str]]');
has 'logger'            => (is => 'ro', default => sub { AI::MXNet::Logging->get_logger });
has '_p'                => (is => 'rw', init_arg => undef);
has 'context'           => (
    is      => 'ro',
    isa     => 'AI::MXNet::Context|ArrayRef[AI::MXNet::Context]',
    default => sub { AI::MXNet::Context->cpu }
);

around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    if(@_%2)
    {
        my $symbol = shift;
        return $class->$orig(symbol => $symbol, @_);
    }
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    $self->_p(AI::MXNet::Module::Private->new);
    my $context = $self->context;
    if(blessed $context)
    {
        $context = [$context];
    }
    $self->_p->_context($context);
    my $work_load_list = $self->work_load_list;
    if(not defined $work_load_list)
    {
        $work_load_list = [(1)x@{$self->_p->_context}];
    }
    assert(@{ $work_load_list } == @{ $self->_p->_context });
    $self->_p->_work_load_list($work_load_list);
    my @data_names  = @{ $self->_data_names//['data'] };
    my @label_names = @{ $self->_label_names//['softmax_label'] };
    my @state_names = @{ $self->state_names//[] };
    my $arg_names   = $self->_symbol->list_arguments;
    my @input_names = (@data_names, @label_names, @state_names);
    my %input_names = map { $_ => 1 } @input_names;
    $self->_p->_param_names([grep { not exists $input_names{$_} } @{ $arg_names }]);
    $self->_p->_fixed_param_names($self->fixed_param_names//[]);
    $self->_p->_state_names(\@state_names);
    $self->_p->_aux_names($self->_symbol->list_auxiliary_states);
    $self->_p->_data_names(\@data_names);
    $self->_p->_label_names(\@label_names);
    $self->_p->_output_names($self->_symbol->list_outputs);
    $self->_p->_params_dirty(0);
    $self->_check_input_names($self->_symbol, $self->_p->_data_names, "data", 1);
    $self->_check_input_names($self->_symbol, $self->_p->_label_names, "label", 0);
    $self->_check_input_names($self->_symbol, $self->_p->_state_names, "state", 1);
    $self->_check_input_names($self->_symbol, $self->_p->_fixed_param_names, "fixed_param", 1);
}

method Module(@args)          { return @args ? __PACKAGE__->new(@args) : __PACKAGE__ }
method BucketingModule(@args) { return AI::MXNet::Module::Bucketing->new(@args) }

=head2 load

    Creates a model from a previously saved checkpoint.

    Parameters
    ----------
    prefix : str
        path prefix of saved model files. You should have
        "prefix-symbol.json", "prefix-xxxx.params", and
        optionally "prefix-xxxx.states", where xxxx is the epoch number.
    epoch : int
        epoch to load.
    load_optimizer_states : bool
        whether to load optimizer states. The checkpoint needs
        to have been made with save_optimizer_states=True.
    data_names : array ref of str
        Default is ['data'] for a typical model used in image classification.
    label_names : array ref of str
        Default is ['softmax_label'] for a typical model used in
        image classification.
    logger : Logger
        Default is AI::MXNet::Logging.
    context : Context or list of Context
        Default is cpu(0).
    work_load_list : array ref of number
        Default is undef, indicating a uniform workload.
    fixed_param_names: array ref of str
        Default is undef, indicating no network parameters are fixed.
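
    Example
    -------

    A minimal usage sketch; the 'mymodel' prefix, the epoch number, and the
    presence of the checkpoint files are assumptions for illustration:

        my $mod = AI::MXNet::Module->load('mymodel', 2, 1,
            context => AI::MXNet::Context->cpu
        );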

=cut

method load(
    Str  $prefix,
    Int  $epoch,
    Bool $load_optimizer_states=0,
    %kwargs
)
{
    my ($sym, $args, $auxs) = __PACKAGE__->load_checkpoint($prefix, $epoch);
    my $mod = $self->new(symbol => $sym, %kwargs);
    $mod->_p->_arg_params($args);
    $mod->_p->_aux_params($auxs);
    $mod->params_initialized(1);
    if($load_optimizer_states)
    {
        $mod->_p->_preload_opt_states(sprintf('%s-%04d.states', $prefix, $epoch));
    }
    return $mod;
}

=head2 save_checkpoint

    Saves the current progress to a checkpoint.
    Use mx->callback->module_checkpoint as epoch_end_callback to save during training.

    Parameters
    ----------
    prefix : str
        The file prefix to checkpoint to
    epoch : int
        The current epoch number
    save_optimizer_states : bool
        Whether to save optimizer states for later training
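
    Example
    -------

    A short sketch; the prefix and epoch number are assumptions, and $mod is
    taken to be a bound module with an initialized optimizer:

        $mod->save_checkpoint('mymodel', 2, 1);
        # writes mymodel-symbol.json, mymodel-0002.params and mymodel-0002.states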

=cut

method save_checkpoint(Str $prefix, Int $epoch, Bool $save_optimizer_states=0)
{
    $self->_symbol->save("$prefix-symbol.json");
    my $param_name = sprintf('%s-%04d.params', $prefix, $epoch);
    $self->save_params($param_name);
    AI::MXNet::Logging->info('Saved checkpoint to "%s"', $param_name);
    if($save_optimizer_states)
    {
        my $state_name = sprintf('%s-%04d.states', $prefix, $epoch);
        $self->save_optimizer_states($state_name);
        AI::MXNet::Logging->info('Saved optimizer state to "%s"', $state_name);
    }
}

=head2 model_save_checkpoint

    Checkpoints the model data into a file.

    Parameters
    ----------
    prefix : str
        Prefix of the model name.
    epoch : int
        The epoch number of the model.
    symbol : AI::MXNet::Symbol
        The input symbol
    arg_params : hash ref of str to AI::MXNet::NDArray
        Model parameters, hash ref of name to AI::MXNet::NDArray of the net's weights.
    aux_params : hash ref of str to AI::MXNet::NDArray
        Model parameters, hash ref of name to AI::MXNet::NDArray of the net's auxiliary states.

    Notes
    -----
    - prefix-symbol.json will be saved for the symbol.
    - prefix-epoch.params will be saved for the parameters.

=cut

method model_save_checkpoint(
    Str                         $prefix,
    Int                         $epoch,
    Maybe[AI::MXNet::Symbol]    $symbol,
    HashRef[AI::MXNet::NDArray] $arg_params,
    HashRef[AI::MXNet::NDArray] $aux_params
)
{
    if(defined $symbol)
    {
        $symbol->save("$prefix-symbol.json");
    }
    my $param_name = sprintf('%s-%04d.params', $prefix, $epoch);
    $self->save_params($param_name, $arg_params, $aux_params);
    AI::MXNet::Logging->info('Saved checkpoint to "%s"', $param_name);
}

# Internal function to reset the bound state.
method _reset_bind()
{
    $self->binded(0);
    $self->_p->_exec_group(undef);
    $self->_p->_data_shapes(undef);
    $self->_p->_label_shapes(undef);
}

method data_names()
{
    return $self->_p->_data_names;
}

method label_names()
{
    return $self->_p->_label_names;
}

method output_names()
{
    return $self->_p->_output_names;
}

method data_shapes()
{
    assert($self->binded);
    return $self->_p->_data_shapes;
}

method label_shapes()
{
    assert($self->binded);
    return $self->_p->_label_shapes;
}

method output_shapes()
{
    assert($self->binded);
    return $self->_p->_exec_group->get_output_shapes;
}

method get_params()
{
    assert($self->binded and $self->params_initialized);
    if($self->_p->_params_dirty)
    {
        $self->_sync_params_from_devices();
    }
    return ($self->_p->_arg_params, $self->_p->_aux_params);
}

method init_params(
    Maybe[AI::MXNet::Initializer]      :$initializer=AI::MXNet::Initializer->Uniform(scale => 0.01),
    Maybe[HashRef[AI::MXNet::NDArray]] :$arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]] :$aux_params=,
    Bool                               :$allow_missing=0,
    Bool                               :$force_init=0,
    Bool                               :$allow_extra=0
)
{
    if($self->params_initialized and not $force_init)
    {
        AI::MXNet::Logging->warning(
            "Parameters already initialized and force_init=0. "
            ."init_params call ignored."
        );
        return;
    }
    assert($self->binded, 'call bind before initializing the parameters');
    # Internal helper for parameter initialization
    my $_impl = sub {
        my ($name, $arr, $cache) = @_;
        if(defined $cache)
        {
            if(exists $cache->{$name})
            {
                my $cache_arr = $cache->{$name};
                # just in case the cached array is just the target itself
                if($cache_arr->handle ne $arr->handle)
                {
                    $cache_arr->copyto($arr);
                }
            }
            else
            {
                if(not $allow_missing)
                {
                    confess("$name is not presented");
                }
                if(defined $initializer)
                {
                    &{$initializer}($name, $arr);
                }
            }
        }
        else
        {
            &{$initializer}($name, $arr) if defined $initializer;
        }
    };
    my $attrs = $self->_symbol->attr_dict;
    while(my ($name, $arr) = each %{ $self->_p->_arg_params })
    {
        $_impl->(
            AI::MXNet::InitDesc->new(
                name => $name,
                ($attrs->{$name} ? (attrs => $attrs->{$name}) : ())
            ),
            $arr, $arg_params
        );
    }
    while(my ($name, $arr) = each %{ $self->_p->_aux_params })
    {
        $_impl->(
            AI::MXNet::InitDesc->new(
                name => $name,
                ($attrs->{$name} ? (attrs => $attrs->{$name}) : ())
            ),
            $arr, $aux_params
        );
    }
    $self->params_initialized(1);
    $self->_p->_params_dirty(0);
    # copy the initialized parameters to the devices
    $self->_p->_exec_group->set_params($self->_p->_arg_params, $self->_p->_aux_params, $allow_extra);
}

method set_params(
    HashRef[AI::MXNet::NDArray] $arg_params,
    HashRef[AI::MXNet::NDArray] $aux_params,
    Bool                        :$allow_missing=0,
    Bool                        :$force_init=1,
    Bool                        :$allow_extra=0
)
{
    if(not $allow_missing)
    {
        $self->init_params(
            arg_params    => $arg_params,
            aux_params    => $aux_params,
            allow_missing => $allow_missing,
            force_init    => $force_init,
            allow_extra   => $allow_extra
        );
        return;
    }
    if($self->params_initialized and not $force_init)
    {
        AI::MXNet::Logging->warning(
            "Parameters already initialized and force_init=0. "
            ."set_params call ignored."
        );
        return;
    }
    $self->_p->_exec_group->set_params($arg_params, $aux_params, $allow_extra);
    $self->_p->_params_dirty(1);
    $self->params_initialized(1);
}

=head2 bind

    Binds the symbols in order to construct the executors. This is necessary
    before any computation with the module can be performed.

    Parameters
    ----------
    :$data_shapes : ArrayRef[AI::MXNet::DataDesc|NameShape]
        Typically is $data_iter->provide_data.
    :$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically is $data_iter->provide_label.
    :$for_training : bool
        Default is 1. Whether the executors should be bound for training.
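
    Example
    -------

    A minimal sketch of binding against a data iterator; $train_iter and
    $mod are assumptions for illustration:

        $mod->bind(
            data_shapes  => $train_iter->provide_data,
            label_shapes => $train_iter->provide_label,
            for_training => 1
        );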
    :$inputs_need_grad : bool
        Default is 0. Whether the gradients to the input data need to be computed.
        Typically this is not needed. But this might be needed when implementing
        composition of modules.
    :$force_rebind : bool
        Default is 0. This function does nothing if the executors are already
        bound. But with this set to 1, the executors will be forced to rebind.
    :$shared_module : Module
        Default is undef. This is used in bucketing. When not undef, the shared
        module essentially corresponds to a different bucket -- a module with a
        different symbol but with the same sets of parameters (e.g. unrolled
        RNNs with different lengths).

=cut

method bind(
    ArrayRef[AI::MXNet::DataDesc|NameShape]        :$data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$label_shapes=,
    Bool                                           :$for_training=1,
    Bool                                           :$inputs_need_grad=0,
    Bool                                           :$force_rebind=0,
    Maybe[AI::MXNet::Module]                       :$shared_module=,
    GradReq|HashRef[GradReq]|ArrayRef[GradReq]     :$grad_req='write',
    Maybe[ArrayRef[Str]]                           :$state_names=$self->_p->_state_names
)
{
    # force rebinding is typically used when one wants to switch from
    # the training to the prediction phase.
    if($force_rebind)
    {
        $self->_reset_bind();
    }
    if($self->binded)
    {
        $self->logger->warning('Already bound, ignoring bind()');
        return;
    }
    $self->for_training($for_training);
    $self->inputs_need_grad($inputs_need_grad);
    $self->binded(1);
    $self->_p->_grad_req($grad_req);
    if(not $for_training)
    {
        assert(not $inputs_need_grad);
    }
    ($data_shapes, $label_shapes) = $self->_parse_data_desc(
        $self->data_names, $self->label_names, $data_shapes, $label_shapes
    );
    $self->_p->_data_shapes($data_shapes);
    $self->_p->_label_shapes($label_shapes);
    my $shared_group;
    if($shared_module)
    {
        assert($shared_module->binded and $shared_module->params_initialized);
        $shared_group = $shared_module->_p->_exec_group;
    }
    $self->_p->_exec_group(
        AI::MXNet::DataParallelExecutorGroup->new(
            symbol            => $self->_symbol,
            contexts          => $self->_p->_context,
            workload          => $self->_p->_work_load_list,
            data_shapes       => $self->_p->_data_shapes,
            label_shapes      => $self->_p->_label_shapes,
            param_names       => $self->_p->_param_names,
            state_names       => $state_names,
            for_training      => $for_training,
            inputs_need_grad  => $inputs_need_grad,
            shared_group      => $shared_group,
            logger            => $self->logger,
            fixed_param_names => $self->_p->_fixed_param_names,
            grad_req          => $grad_req
        )
    );
    if($shared_module)
    {
        $self->params_initialized(1);
        $self->_p->_arg_params($shared_module->_p->_arg_params);
        $self->_p->_aux_params($shared_module->_p->_aux_params);
    }
    elsif($self->params_initialized)
    {
        # if the parameters are already initialized, we are re-binding,
        # so automatically copy the already initialized params
        $self->_p->_exec_group->set_params($self->_p->_arg_params, $self->_p->_aux_params);
    }
    else
    {
        assert(not defined $self->_p->_arg_params and not $self->_p->_aux_params);
        my @param_arrays = (
            map {
                AI::MXNet::NDArray->zeros($_->[0]->shape, dtype => $_->[0]->dtype)
            } @{ $self->_p->_exec_group->_p->param_arrays }
        );
        my %arg_params;
        @arg_params{ @{ $self->_p->_param_names } } = @param_arrays;
        $self->_p->_arg_params(\%arg_params);
        my @aux_arrays = (
            map {
                AI::MXNet::NDArray->zeros($_->[0]->shape, dtype => $_->[0]->dtype)
            } @{ $self->_p->_exec_group->_p->aux_arrays }
        );
        my %aux_params;
        @aux_params{ @{ $self->_p->_aux_names } } = @aux_arrays;
        $self->_p->_aux_params(\%aux_params);
    }
    if($shared_module and $shared_module->optimizer_initialized)
    {
        $self->borrow_optimizer($shared_module);
    }
}

=head2 reshape

    Reshapes the module for new input shapes.

    Parameters
    ----------
    :$data_shapes : ArrayRef[AI::MXNet::DataDesc]
        Typically is $data_iter->provide_data.
    :$label_shapes= : Maybe[ArrayRef[AI::MXNet::DataDesc]]
        Typically is $data_iter->provide_label.
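
    Example
    -------

    A minimal sketch; the new batch size of 128 and the input dimensions are
    assumptions, and $mod is taken to be an already bound module:

        $mod->reshape(
            data_shapes => [
                AI::MXNet::DataDesc->new(name => 'data', shape => [128, 1, 28, 28])
            ]
        );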

=cut

method reshape(
    ArrayRef[AI::MXNet::DataDesc|NameShape]        :$data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$label_shapes=
)
{
    assert($self->binded);
    ($data_shapes, $label_shapes) = $self->_parse_data_desc(
        $self->data_names, $self->label_names, $data_shapes, $label_shapes
    );
    $self->_p->_data_shapes($data_shapes);
    $self->_p->_label_shapes($label_shapes);
    $self->_p->_exec_group->reshape($self->_p->_data_shapes, $self->_p->_label_shapes);
}

method init_optimizer(
    Str|AI::MXNet::KVStore :$kvstore='local',
    Optimizer              :$optimizer='sgd',
    HashRef                :$optimizer_params={ learning_rate => 0.01 },
    Bool                   :$force_init=0
)
{
    assert($self->binded and $self->params_initialized);
    if($self->optimizer_initialized and not $force_init)
    {
        $self->logger->warning('optimizer already initialized, ignoring...');
        return;
    }
    if($self->_p->_params_dirty)
    {
        $self->_sync_params_from_devices;
    }
    my ($kvstore, $update_on_kvstore) = _create_kvstore(
        $kvstore,
        scalar(@{$self->_p->_context}),
        $self->_p->_arg_params
    );
    my $batch_size = $self->_p->_exec_group->_p->batch_size;
    if($kvstore and $kvstore->type =~ /dist/ and $kvstore->type =~ /_sync/)
    {
        $batch_size *= $kvstore->num_workers;
    }
    my $rescale_grad = 1/$batch_size;
    if(not blessed $optimizer)
    {
        my %idx2name;
        if($update_on_kvstore)
        {
            @idx2name{ 0..@{$self->_p->_exec_group->param_names}-1 }
                = @{$self->_p->_exec_group->param_names};
        }
        else
        {
            for my $k (0..@{$self->_p->_context}-1)
            {
                @idx2name{ map { $_ + $k } 0..@{$self->_p->_exec_group->param_names}-1 }
                    = @{$self->_p->_exec_group->param_names};
            }
        }
        if(not exists $optimizer_params->{rescale_grad})
        {
            $optimizer_params->{rescale_grad} = $rescale_grad;
        }
        $optimizer = AI::MXNet::Optimizer->create(
            $optimizer,
            sym            => $self->symbol,
            param_idx2name => \%idx2name,
            %{ $optimizer_params }
        );
        if($optimizer->rescale_grad != $rescale_grad)
        {
            AI::MXNet::Logging->warning(
                "Optimizer created manually outside Module but rescale_grad "
                ."is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "
                ."Is this intended?",
                $optimizer->rescale_grad, $rescale_grad
            );
        }
    }
    $self->_p->_optimizer($optimizer);
    $self->_p->_kvstore($kvstore);
    $self->_p->_update_on_kvstore($update_on_kvstore);
    $self->_p->_updater(undef);
    if($kvstore)
    {
        # copy the initialized local parameters to the kvstore
        _initialize_kvstore(
            kvstore           => $kvstore,
            param_arrays      => $self->_p->_exec_group->_p->param_arrays,
            arg_params        => $self->_p->_arg_params,
            param_names       => $self->_p->_param_names,
            update_on_kvstore => $update_on_kvstore
        );
    }
    if($update_on_kvstore)
    {
        $kvstore->set_optimizer($self->_p->_optimizer);
    }
    else
    {
        $self->_p->_updater(AI::MXNet::Optimizer->get_updater($optimizer));
    }
    $self->optimizer_initialized(1);
    if($self->_p->_preload_opt_states)
    {
        $self->load_optimizer_states($self->_p->_preload_opt_states);
        $self->_p->_preload_opt_states(undef);
    }
}
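
=head2 init_optimizer (usage note)

    A hedged sketch of typical optimizer initialization, called after bind
    and init_params; the learning rate and momentum values are assumptions
    for illustration:

        $mod->init_optimizer(
            optimizer        => 'sgd',
            optimizer_params => { learning_rate => 0.1, momentum => 0.9 }
        );

=cut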

=head2 borrow_optimizer

    Borrows an optimizer from a shared module. Used in bucketing, where exactly
    the same optimizer (esp. the kvstore) is used.

    Parameters
    ----------
    shared_module : AI::MXNet::Module

=cut

method borrow_optimizer(AI::MXNet::Module $shared_module)
{
    assert($shared_module->optimizer_initialized);
    $self->_p->_optimizer($shared_module->_p->_optimizer);
    $self->_p->_kvstore($shared_module->_p->_kvstore);
    $self->_p->_update_on_kvstore($shared_module->_p->_update_on_kvstore);
    $self->_p->_updater($shared_module->_p->_updater);
    $self->optimizer_initialized(1);
}

method forward(
    AI::MXNet::DataBatch $data_batch,
    Maybe[Bool]          :$is_train=
)
{
    assert($self->binded and $self->params_initialized);
    # If starting to do inference, the module must be rebound first.
    if($self->label_shapes and not $data_batch->label)
    {
        confess(
            "If you are trying to do inference, rebind module ".
            "with 'force_rebind=True' and 'for_training=False'"
        );
    }
    my @curr_data_shapes = map { $_->shape } @{ $self->data_shapes };
    my @new_data_shapes  = map { $_->shape } @{ $data_batch->data };
    if(Data::Dumper->Dump(\@curr_data_shapes) ne Data::Dumper->Dump(\@new_data_shapes))
    {
        my $new_dshape;
        if($data_batch->can('provide_data') and $data_batch->provide_data)
        {
            $new_dshape = $data_batch->provide_data;
        }
        else
        {
            $new_dshape = [];
            zip(sub {
                my ($i, $shape) = @_;
                push @{ $new_dshape }, AI::MXNet::DataDesc->new(
                    $i->name, $shape, $i->dtype, $i->layout
                );
            }, $self->data_shapes, \@new_data_shapes);
        }
        my $new_lshape;
        if($data_batch->can('provide_label') and $data_batch->provide_label)
        {
            $new_lshape = $data_batch->provide_label;
        }
        elsif($data_batch->can('label') and $data_batch->label)
        {
            $new_lshape = [];
            zip(sub {
                my ($i, $j) = @_;
                push @{ $new_lshape }, AI::MXNet::DataDesc->new(
                    $i->name, $j->shape, $i->dtype, $i->layout
                );
            }, $self->label_shapes, $data_batch->label);
        }
        $self->reshape(data_shapes => $new_dshape, label_shapes => $new_lshape);
    }
    $self->_p->_exec_group->forward($data_batch, $is_train);
}

method backward(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $out_grads=)
{
    assert($self->binded and $self->params_initialized);
    $self->_p->_exec_group->backward($out_grads);
}

method update()
{
    assert($self->binded and $self->params_initialized and $self->optimizer_initialized);
    $self->_p->_params_dirty(1);
    if($self->_p->_update_on_kvstore)
    {
        _update_params_on_kvstore(
            $self->_p->_exec_group->_p->param_arrays,
            $self->_p->_exec_group->_p->grad_arrays,
            $self->_p->_kvstore,
            $self->_p->_exec_group->param_names
        );
    }
    else
    {
        _update_params(
            $self->_p->_exec_group->_p->param_arrays,
            $self->_p->_exec_group->_p->grad_arrays,
            $self->_p->_updater,
            scalar(@{ $self->_p->_context }),
            $self->_p->_kvstore,
            $self->_p->_exec_group->param_names
        );
    }
}

method get_outputs(Bool $merge_multi_context=1)
{
    assert($self->binded and $self->params_initialized);
    return $self->_p->_exec_group->get_outputs($merge_multi_context);
}

method get_input_grads(Bool $merge_multi_context=1)
{
    assert($self->binded and $self->params_initialized and $self->inputs_need_grad);
    return $self->_p->_exec_group->get_input_grads($merge_multi_context);
}

method get_states(Bool $merge_multi_context=1)
{
    assert($self->binded and $self->params_initialized);
    return $self->_p->_exec_group->get_states($merge_multi_context);
}

method set_states(:$states=, :$value=)
{
    assert($self->binded and $self->params_initialized);
    return $self->_p->_exec_group->set_states($states, $value);
}

method update_metric(
    AI::MXNet::EvalMetric        $eval_metric,
    ArrayRef[AI::MXNet::NDArray] $labels
)
{
    $self->_p->_exec_group->update_metric($eval_metric, $labels);
}
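
=head2 forward/backward/update (usage note)

    A hedged sketch of the low-level training loop these methods implement;
    $train_iter is assumed to be an AI::MXNet data iterator (which overloads
    <> to yield batches), and the usual higher-level entry point is $mod->fit:

        my $metric = mx->metric->create('acc');
        while(my $batch = <$train_iter>)
        {
            $mod->forward($batch, is_train => 1);   # compute the outputs
            $mod->update_metric($metric, $batch->label);
            $mod->backward();                       # compute the gradients
            $mod->update();                         # apply the optimizer step
        }

=cut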

=head2 _sync_params_from_devices

    Synchronizes parameters from the devices to the CPU. This function should
    be called after calling 'update', which updates the parameters on the
    devices, before one can read the latest parameters from $self->_arg_params
    and $self->_aux_params.

=cut

method _sync_params_from_devices()
{
    $self->_p->_exec_group->get_params($self->_p->_arg_params, $self->_p->_aux_params);
    $self->_p->_params_dirty(0);
}

method save_optimizer_states(Str $fname)
{
    assert($self->optimizer_initialized);
    if($self->_p->_update_on_kvstore)
    {
        $self->_p->_kvstore->save_optimizer_states($fname);
    }
    else
    {
        open(F, ">:raw", "$fname") or confess("can't open $fname for writing: $!");
        print F $self->_p->_updater->get_states();
        close(F);
    }
}

method load_optimizer_states(Str $fname)
{
    assert($self->optimizer_initialized);
    if($self->_p->_update_on_kvstore)
    {
        $self->_p->_kvstore->load_optimizer_states($fname);
    }
    else
    {
        open(F, "<:raw", "$fname") or confess("can't open $fname for reading: $!");
        my $data;
        {
            local($/) = undef;
            $data = <F>;
        }
        close(F);
        $self->_p->_updater->set_states($data);
    }
}

method install_monitor(AI::MXNet::Monitor $mon)
{
    assert($self->binded);
    $self->_p->_exec_group->install_monitor($mon);
}

method _updater()
{
    $self->_p->_updater;
}

method _kvstore()
{
    $self->_p->_kvstore;
}

1;