perl-package/AI-MXNet/lib/AI/MXNet/Module.pm (747 lines of code) (raw):
## TODO
## this class is here because of https://github.com/gfx/p5-Mouse/pull/67
## once 2.4.7 version of Mouse in Ubuntu for affected Perl version
## these accessors should be merged into main class
package AI::MXNet::Module::Private;
use Mouse;
has [qw/_param_names _fixed_param_names
_aux_names _data_names _label_names _state_names
_output_names _arg_params _aux_params
_params_dirty _optimizer _kvstore
_update_on_kvstore _updater _work_load_list
_preload_opt_states _exec_group
_data_shapes _label_shapes _context _grad_req/
] => (is => 'rw', init_arg => undef);
package AI::MXNet::Module;
use AI::MXNet::Base;
use AI::MXNet::Function::Parameters;
use List::Util qw(max);
use Data::Dumper ();
use Mouse;
func _create_kvstore(
Maybe[Str|AI::MXNet::KVStore] $kvstore,
Int $num_device,
HashRef[AI::MXNet::NDArray] $arg_params
)
{
my $update_on_kvstore = 1;
my $kv;
if(defined $kvstore)
{
if(blessed $kvstore)
{
$kv = $kvstore;
}
else
{
# create kvstore using the string type
if($num_device == 1 and $kvstore !~ /dist/)
{
# no need to use kv for single device and single machine
}
else
{
$kv = AI::MXNet::KVStore->create($kvstore);
if($kvstore eq 'local')
{
# automatically select a proper local
my $max_size = max(map { product(@{ $_->shape }) } values %{ $arg_params });
if($max_size > 1024 * 1024 * 16)
{
$update_on_kvstore = 0;
}
}
}
}
}
$update_on_kvstore = 0 if not $kv;
return ($kv, $update_on_kvstore);
}
func _initialize_kvstore(
AI::MXNet::KVStore :$kvstore,
HashRef[AI::MXNet::NDArray] :$arg_params,
ArrayRef[Str] :$param_names,
Bool :$update_on_kvstore,
ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] :$param_arrays
)
{
enumerate(sub{
my ($idx, $param_on_devs) = @_;
my $name = $param_names->[$idx];
$kvstore->init($name, $arg_params->{ $name });
if($update_on_kvstore)
{
$kvstore->pull($name, out => $param_on_devs, priority => -$idx);
}
}, $param_arrays);
}
func _update_params_on_kvstore(
ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] $param_arrays,
ArrayRef[AI::MXNet::NDArray]|ArrayRef[ArrayRef[AI::MXNet::NDArray]] $grad_arrays,
AI::MXNet::KVStore $kvstore,
ArrayRef[Str] $param_names
)
{
enumerate(sub{
my ($index, $arg_list, $grad_list) = @_;
if(ref $grad_list eq 'ARRAY' and not defined $grad_list->[0])
{
return;
}
my $name = $param_names->[$index];
# push gradient, priority is negative index
$kvstore->push($name, $grad_list, priority => -$index);
# pull back the weights
$kvstore->pull($name, out => $arg_list, priority => -$index);
}, $param_arrays, $grad_arrays);
}
func _update_params(
ArrayRef[ArrayRef[AI::MXNet::NDArray]] $param_arrays,
ArrayRef[ArrayRef[AI::MXNet::NDArray]] $grad_arrays,
AI::MXNet::Updater $updater,
Int $num_device,
Maybe[AI::MXNet::KVStore] $kvstore=,
Maybe[ArrayRef[Str]] $param_names=
)
{
enumerate(sub{
my ($index, $arg_list, $grad_list) = @_;
if(not defined $grad_list->[0])
{
return;
}
if($kvstore)
{
my $name = $param_names->[$index];
# push gradient, priority is negative index
$kvstore->push($name, $grad_list, priority => -$index);
# pull back the sum gradients, to the same locations.
$kvstore->pull($name, out => $grad_list, priority => -$index);
}
enumerate(sub {
my ($k, $w, $g) = @_;
# faked an index here, to make optimizer create diff
# state for the same index but on diff devs, TODO(mli)
# use a better solution later
&{$updater}($index*$num_device+$k, $g, $w);
}, $arg_list, $grad_list);
}, $param_arrays, $grad_arrays);
}
method load_checkpoint(Str $prefix, Int $epoch)
{
my $symbol = AI::MXNet::Symbol->load("$prefix-symbol.json");
my %save_dict = %{ AI::MXNet::NDArray->load(sprintf('%s-%04d.params', $prefix, $epoch)) };
my %arg_params;
my %aux_params;
while(my ($k, $v) = each %save_dict)
{
my ($tp, $name) = split(/:/, $k, 2);
if($tp eq 'arg')
{
$arg_params{$name} = $v;
}
if($tp eq 'aux')
{
$aux_params{$name} = $v;
}
}
return ($symbol, \%arg_params, \%aux_params);
}
=head1 NAME
AI::MXNet::Module - FeedForward interface of MXNet.
See AI::MXNet::Module::Base for the details.
=cut
extends 'AI::MXNet::Module::Base';
has '_symbol' => (is => 'ro', init_arg => 'symbol', isa => 'AI::MXNet::Symbol', required => 1);
has '_data_names' => (is => 'ro', init_arg => 'data_names', isa => 'ArrayRef[Str]');
has '_label_names' => (is => 'ro', init_arg => 'label_names', isa => 'Maybe[ArrayRef[Str]]');
has 'work_load_list' => (is => 'rw', isa => 'Maybe[ArrayRef[Int]]');
has 'fixed_param_names' => (is => 'rw', isa => 'Maybe[ArrayRef[Str]]');
has 'state_names' => (is => 'rw', isa => 'Maybe[ArrayRef[Str]]');
has 'logger' => (is => 'ro', default => sub { AI::MXNet::Logging->get_logger });
has '_p' => (is => 'rw', init_arg => undef);
has 'context' => (
is => 'ro',
isa => 'AI::MXNet::Context|ArrayRef[AI::MXNet::Context]',
default => sub { AI::MXNet::Context->cpu }
);
around BUILDARGS => sub {
my $orig = shift;
my $class = shift;
if(@_%2)
{
my $symbol = shift;
return $class->$orig(symbol => $symbol, @_);
}
return $class->$orig(@_);
};
sub BUILD
{
my $self = shift;
$self->_p(AI::MXNet::Module::Private->new);
my $context = $self->context;
if(blessed $context)
{
$context = [$context];
}
$self->_p->_context($context);
my $work_load_list = $self->work_load_list;
if(not defined $work_load_list)
{
$work_load_list = [(1)x@{$self->_p->_context}];
}
assert(@{ $work_load_list } == @{ $self->_p->_context });
$self->_p->_work_load_list($work_load_list);
my @data_names = @{ $self->_data_names//['data'] };
my @label_names = @{ $self->_label_names//['softmax_label'] };
my @state_names = @{ $self->state_names//[] };
my $arg_names = $self->_symbol->list_arguments;
my @input_names = (@data_names, @label_names, @state_names);
my %input_names = map { $_ => 1 } @input_names;
$self->_p->_param_names([grep { not exists $input_names{$_} } @{ $arg_names }]);
$self->_p->_fixed_param_names($self->fixed_param_names//[]);
$self->_p->_state_names(\@state_names);
$self->_p->_aux_names($self->_symbol->list_auxiliary_states);
$self->_p->_data_names(\@data_names);
$self->_p->_label_names(\@label_names);
$self->_p->_output_names($self->_symbol->list_outputs);
$self->_p->_params_dirty(0);
$self->_check_input_names($self->_symbol, $self->_p->_data_names, "data", 1);
$self->_check_input_names($self->_symbol, $self->_p->_label_names, "label", 0);
$self->_check_input_names($self->_symbol, $self->_p->_state_names, "state", 1);
$self->_check_input_names($self->_symbol, $self->_p->_fixed_param_names, "fixed_param", 1);
}
method Module(@args) { return @args ? __PACKAGE__->new(@args) : __PACKAGE__ }
method BucketingModule(@args) { return AI::MXNet::Module::Bucketing->new(@args) }
=head2 load
Create a model from previously saved checkpoint.
Parameters
----------
prefix : str
path prefix of saved model files. You should have
"prefix-symbol.json", "prefix-xxxx.params", and
optionally "prefix-xxxx.states", where xxxx is the
epoch number.
epoch : int
epoch to load.
load_optimizer_states : bool
whether to load optimizer states. Checkpoint needs
to have been made with save_optimizer_states=True.
data_names : array ref of str
Default is ['data'] for a typical model used in image classification.
label_names : array ref of str
Default is ['softmax_label'] for a typical model used in image
classification.
logger : Logger
Default is AI::MXNet::Logging.
context : Context or list of Context
Default is cpu(0).
work_load_list : array ref of number
Default is undef, indicating an uniform workload.
fixed_param_names: array ref of str
Default is undef, indicating no network parameters are fixed.
=cut
method load(
Str $prefix,
Int $epoch,
Bool $load_optimizer_states=0,
%kwargs
)
{
my ($sym, $args, $auxs) = __PACKAGE__->load_checkpoint($prefix, $epoch);
my $mod = $self->new(symbol => $sym, %kwargs);
$mod->_p->_arg_params($args);
$mod->_p->_aux_params($auxs);
$mod->params_initialized(1);
if($load_optimizer_states)
{
$mod->_p->_preload_opt_states(sprintf('%s-%04d.states', $prefix, $epoch));
}
return $mod;
}
=head2 save_checkpoint
Save current progress to a checkpoint.
Use mx->callback->module_checkpoint as epoch_end_callback to save during training.
Parameters
----------
prefix : str
The file prefix to checkpoint to
epoch : int
The current epoch number
save_optimizer_states : bool
Whether to save optimizer states for later training
=cut
method save_checkpoint(Str $prefix, Int $epoch, Bool $save_optimizer_states=0)
{
$self->_symbol->save("$prefix-symbol.json");
my $param_name = sprintf('%s-%04d.params', $prefix, $epoch);
$self->save_params($param_name);
AI::MXNet::Logging->info('Saved checkpoint to "%s"', $param_name);
if($save_optimizer_states)
{
my $state_name = sprintf('%s-%04d.states', $prefix, $epoch);
$self->save_optimizer_states($state_name);
AI::MXNet::Logging->info('Saved optimizer state to "%s"', $state_name);
}
}
=head2 model_save_checkpoint
Checkpoint the model data into file.
Parameters
----------
prefix : str
Prefix of model name.
epoch : int
The epoch number of the model.
symbol : AI::MXNet::Symbol
The input symbol
arg_params : hash ref of str to AI::MXNet::NDArray
Model parameter, hash ref of name to AI::MXNet::NDArray of net's weights.
aux_params : hash ref of str to NDArray
Model parameter, hash ref of name to AI::MXNet::NDArray of net's auxiliary states.
Notes
-----
- prefix-symbol.json will be saved for symbol.
- prefix-epoch.params will be saved for parameters.
=cut
method model_save_checkpoint(
Str $prefix,
Int $epoch,
Maybe[AI::MXNet::Symbol] $symbol,
HashRef[AI::MXNet::NDArray] $arg_params,
HashRef[AI::MXNet::NDArray] $aux_params
)
{
if(defined $symbol)
{
$symbol->save("$prefix-symbol.json");
}
my $param_name = sprintf('%s-%04d.params', $prefix, $epoch);
$self->save_params($param_name, $arg_params, $aux_params);
AI::MXNet::Logging->info('Saved checkpoint to "%s"', $param_name);
}
# Internal function to reset binded state.
method _reset_bind()
{
$self->binded(0);
$self->_p->_exec_group(undef);
$self->_p->_data_shapes(undef);
$self->_p->_label_shapes(undef);
}
method data_names()
{
return $self->_p->_data_names;
}
method label_names()
{
return $self->_p->_label_names;
}
method output_names()
{
return $self->_p->_output_names;
}
method data_shapes()
{
assert($self->binded);
return $self->_p->_data_shapes;
}
method label_shapes()
{
assert($self->binded);
return $self->_p->_label_shapes;
}
method output_shapes()
{
assert($self->binded);
return $self->_p->_exec_group->get_output_shapes;
}
method get_params()
{
assert($self->binded and $self->params_initialized);
if($self->_p->_params_dirty)
{
$self->_sync_params_from_devices();
}
return ($self->_p->_arg_params, $self->_p->_aux_params);
}
method init_params(
Maybe[AI::MXNet::Initializer] :$initializer=AI::MXNet::Initializer->Uniform(scale => 0.01),
Maybe[HashRef[AI::MXNet::NDArray]] :$arg_params=,
Maybe[HashRef[AI::MXNet::NDArray]] :$aux_params=,
Bool :$allow_missing=0,
Bool :$force_init=0,
Bool :$allow_extra=0
)
{
if($self->params_initialized and not $force_init)
{
AI::MXNet::Logging->warning(
"Parameters already initialized and force_init=0. "
."init_params call ignored."
);
return;
}
assert($self->binded, 'call bind before initializing the parameters');
my $_impl = sub {
my ($name, $arr, $cache) = @_;
# Internal helper for parameter initialization
if(defined $cache)
{
if(exists $cache->{$name})
{
my $cache_arr = $cache->{$name};
# just in case the cached array is just the target itself
if($cache_arr->handle ne $arr->handle)
{
$cache_arr->copyto($arr);
}
}
else
{
if(not $allow_missing)
{
confess("$name is not presented");
}
if(defined $initializer)
{
&{$initializer}($name, $arr);
}
}
}
else
{
&{$initializer}($name, $arr) if defined $initializer;
}
};
my $attrs = $self->_symbol->attr_dict;
while(my ($name, $arr) = each %{ $self->_p->_arg_params })
{
$_impl->(
AI::MXNet::InitDesc->new(
name => $name,
($attrs->{$name} ? (attrs => $attrs->{$name}) : ())
),
$arr, $arg_params
);
}
while(my ($name, $arr) = each %{ $self->_p->_aux_params })
{
$_impl->(
AI::MXNet::InitDesc->new(
name => $name,
($attrs->{$name} ? (attrs => $attrs->{$name}) : ())
),
$arr, $aux_params
);
}
$self->params_initialized(1);
$self->_p->_params_dirty(0);
# copy the initialized parameters to devices
$self->_p->_exec_group->set_params($self->_p->_arg_params, $self->_p->_aux_params, $allow_extra);
}
method set_params(
HashRef[AI::MXNet::NDArray] $arg_params,
HashRef[AI::MXNet::NDArray] $aux_params,
Bool :$allow_missing=0,
Bool :$force_init=1,
Bool :$allow_extra=0
)
{
if(not $allow_missing)
{
$self->init_params(
arg_params => $arg_params, aux_params => $aux_params,
allow_missing => $allow_missing, force_init => $force_init,
allow_extra => $allow_extra
);
return;
}
if($self->params_initialized and not $force_init)
{
AI::MXNet::Logging->warning(
"Parameters already initialized and force_init=False. "
."set_params call ignored."
);
return;
}
$self->_p->_exec_group->set_params($arg_params, $aux_params, $allow_extra);
$self->_p->_params_dirty(1);
$self->params_initialized(1);
}
=head2 bind
Bind the symbols to construct executors. This is necessary before one
can perform computation with the module.
Parameters
----------
:$data_shapes : ArrayRef[AI::MXNet::DataDesc|NameShape]
Typically is $data_iter->provide_data.
:$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
Typically is $data_iter->provide_label.
:$for_training : bool
Default is 1. Whether the executors should be bind for training.
:$inputs_need_grad : bool
Default is 0. Whether the gradients to the input data need to be computed.
Typically this is not needed. But this might be needed when implementing composition
of modules.
:$force_rebind : bool
Default is 0. This function does nothing if the executors are already
binded. But with this 1, the executors will be forced to rebind.
:$shared_module : Module
Default is undef. This is used in bucketing. When not undef, the shared module
essentially corresponds to a different bucket -- a module with different symbol
but with the same sets of parameters (e.g. unrolled RNNs with different lengths).
=cut
method bind(
ArrayRef[AI::MXNet::DataDesc|NameShape] :$data_shapes,
Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$label_shapes=,
Bool :$for_training=1,
Bool :$inputs_need_grad=0,
Bool :$force_rebind=0,
Maybe[AI::MXNet::Module] :$shared_module=,
GradReq|HashRef[GradReq]|ArrayRef[GradReq] :$grad_req='write',
Maybe[ArrayRef[Str]] :$state_names=$self->_p->_state_names
)
{
# force rebinding is typically used when one want to switch from
# training to prediction phase.
if($force_rebind)
{
$self->_reset_bind();
}
if($self->binded)
{
$self->logger->warning('Already binded, ignoring bind()');
return;
}
$self->for_training($for_training);
$self->inputs_need_grad($inputs_need_grad);
$self->binded(1);
$self->_p->_grad_req($grad_req);
if(not $for_training)
{
assert(not $inputs_need_grad);
}
($data_shapes, $label_shapes) = $self->_parse_data_desc(
$self->data_names, $self->label_names, $data_shapes, $label_shapes
);
$self->_p->_data_shapes($data_shapes);
$self->_p->_label_shapes($label_shapes);
my $shared_group;
if($shared_module)
{
assert($shared_module->binded and $shared_module->params_initialized);
$shared_group = $shared_module->_p->_exec_group;
}
$self->_p->_exec_group(
AI::MXNet::DataParallelExecutorGroup->new(
symbol => $self->_symbol,
contexts => $self->_p->_context,
workload => $self->_p->_work_load_list,
data_shapes => $self->_p->_data_shapes,
label_shapes => $self->_p->_label_shapes,
param_names => $self->_p->_param_names,
state_names => $state_names,
for_training => $for_training,
inputs_need_grad => $inputs_need_grad,
shared_group => $shared_group,
logger => $self->logger,
fixed_param_names => $self->_p->_fixed_param_names,
grad_req => $grad_req
)
);
if($shared_module)
{
$self->params_initialized(1);
$self->_p->_arg_params($shared_module->_p->_arg_params);
$self->_p->_aux_params($shared_module->_p->_aux_params);
}
elsif($self->params_initialized)
{
# if the parameters are already initialized, we are re-binding
# so automatically copy the already initialized params
$self->_p->_exec_group->set_params($self->_p->_arg_params, $self->_p->_aux_params);
}
else
{
assert(not defined $self->_p->_arg_params and not $self->_p->_aux_params);
my @param_arrays = (
map { AI::MXNet::NDArray->zeros($_->[0]->shape, dtype => $_->[0]->dtype) }
@{ $self->_p->_exec_group->_p->param_arrays }
);
my %arg_params;
@arg_params{ @{ $self->_p->_param_names } } = @param_arrays;
$self->_p->_arg_params(\%arg_params);
my @aux_arrays = (
map { AI::MXNet::NDArray->zeros($_->[0]->shape, dtype => $_->[0]->dtype) }
@{ $self->_p->_exec_group->_p->aux_arrays }
);
my %aux_params;
@aux_params{ @{ $self->_p->_aux_names } } = @aux_arrays;
$self->_p->_aux_params(\%aux_params);
}
if($shared_module and $shared_module->optimizer_initialized)
{
$self->borrow_optimizer($shared_module)
}
}
=head2 reshape
Reshape the module for new input shapes.
Parameters
----------
:$data_shapes : ArrayRef[AI::MXNet::DataDesc]
Typically is $data_iter->provide_data.
:$label_shapes= : Maybe[ArrayRef[AI::MXNet::DataDesc]]
Typically is $data_iter->provide_label.
=cut
method reshape(
ArrayRef[AI::MXNet::DataDesc|NameShape] :$data_shapes,
Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$label_shapes=
)
{
assert($self->binded);
($data_shapes, $label_shapes) = $self->_parse_data_desc(
$self->data_names, $self->label_names, $data_shapes, $label_shapes
);
$self->_p->_data_shapes($data_shapes);
$self->_p->_label_shapes($label_shapes);
$self->_p->_exec_group->reshape($self->_p->_data_shapes, $self->_p->_label_shapes);
}
method init_optimizer(
Str|AI::MXNet::KVStore :$kvstore='local',
Optimizer :$optimizer='sgd',
HashRef :$optimizer_params={ learning_rate => 0.01 },
Bool :$force_init=0
)
{
assert($self->binded and $self->params_initialized);
if($self->optimizer_initialized and not $force_init)
{
$self->logger->warning('optimizer already initialized, ignoring...');
return;
}
if($self->_p->_params_dirty)
{
$self->_sync_params_from_devices;
}
my ($kvstore, $update_on_kvstore) = _create_kvstore(
$kvstore,
scalar(@{$self->_p->_context}),
$self->_p->_arg_params
);
my $batch_size = $self->_p->_exec_group->_p->batch_size;
if($kvstore and $kvstore->type =~ /dist/ and $kvstore->type =~ /_sync/)
{
$batch_size *= $kvstore->num_workers;
}
my $rescale_grad = 1/$batch_size;
if(not blessed $optimizer)
{
my %idx2name;
if($update_on_kvstore)
{
@idx2name{ 0..@{$self->_p->_exec_group->param_names}-1 } = @{$self->_p->_exec_group->param_names};
}
else
{
for my $k (0..@{$self->_p->_context}-1)
{
@idx2name{ map { $_ + $k } 0..@{$self->_p->_exec_group->param_names}-1 } = @{$self->_p->_exec_group->param_names};
}
}
if(not exists $optimizer_params->{rescale_grad})
{
$optimizer_params->{rescale_grad} = $rescale_grad;
}
$optimizer = AI::MXNet::Optimizer->create(
$optimizer,
sym => $self->symbol,
param_idx2name => \%idx2name,
%{ $optimizer_params }
);
if($optimizer->rescale_grad != $rescale_grad)
{
AI::MXNet::Logging->warning(
"Optimizer created manually outside Module but rescale_grad "
."is not normalized to 1.0/batch_size/num_workers (%s vs. %s). "
."Is this intended?",
$optimizer->rescale_grad, $rescale_grad
);
}
}
$self->_p->_optimizer($optimizer);
$self->_p->_kvstore($kvstore);
$self->_p->_update_on_kvstore($update_on_kvstore);
$self->_p->_updater(undef);
if($kvstore)
{
# copy initialized local parameters to kvstore
_initialize_kvstore(
kvstore => $kvstore,
param_arrays => $self->_p->_exec_group->_p->param_arrays,
arg_params => $self->_p->_arg_params,
param_names => $self->_p->_param_names,
update_on_kvstore => $update_on_kvstore
);
}
if($update_on_kvstore)
{
$kvstore->set_optimizer($self->_p->_optimizer);
}
else
{
$self->_p->_updater(AI::MXNet::Optimizer->get_updater($optimizer));
}
$self->optimizer_initialized(1);
if($self->_p->_preload_opt_states)
{
$self->load_optimizer_states($self->_p->_preload_opt_states);
$self->_p->_preload_opt_states(undef);
}
}
=head2 borrow_optimizer
Borrow optimizer from a shared module. Used in bucketing, where exactly the same
optimizer (esp. kvstore) is used.
Parameters
----------
shared_module : AI::MXNet::Module
=cut
method borrow_optimizer(AI::MXNet::Module $shared_module)
{
assert($shared_module->optimizer_initialized);
$self->_p->_optimizer($shared_module->_p->_optimizer);
$self->_p->_kvstore($shared_module->_p->_kvstore);
$self->_p->_update_on_kvstore($shared_module->_p->_update_on_kvstore);
$self->_p->_updater($shared_module->_p->_updater);
$self->optimizer_initialized(1);
}
method forward(
AI::MXNet::DataBatch $data_batch,
Maybe[Bool] :$is_train=
)
{
assert($self->binded and $self->params_initialized);
# If starting to do the inference, force rebind the module.
if($self->label_shapes and not $data_batch->label)
{
confess(
"If you are trying to do inference, rebind module ".
"with 'force_rebind=True' and 'for_training=False'"
);
}
my @curr_data_shapes = map { $_->shape } @{ $self->data_shapes };
my @new_data_shapes = map { $_->shape } @{ $data_batch->data };
if(Data::Dumper->Dump(\@curr_data_shapes) ne Data::Dumper->Dump(\@new_data_shapes))
{
my $new_dshape;
if($data_batch->can('provide_data') and $data_batch->provide_data)
{
$new_dshape = $data_batch->provide_data;
}
else
{
$new_dshape = [];
zip(sub {
my ($i, $shape) = @_;
push @{ $new_dshape }, AI::MXNet::DataDesc->new(
$i->name, $shape, $i->dtype, $i->layout
);
}, $self->data_shapes, \@new_data_shapes);
}
my $new_lshape;
if($data_batch->can('provide_label') and $data_batch->provide_label)
{
$new_lshape = $data_batch->provide_label;
}
elsif($data_batch->can('label') and $data_batch->label)
{
$new_lshape = [];
zip(sub {
my ($i, $j) = @_;
push @{ $new_lshape }, AI::MXNet::DataDesc->new(
$i->name, $j->shape, $i->dtype, $i->layout
);
}, $self->label_shapes, $data_batch->label);
}
$self->reshape(data_shapes => $new_dshape, label_shapes => $new_lshape);
}
$self->_p->_exec_group->forward($data_batch, $is_train);
}
method backward(Maybe[AI::MXNet::NDArray|ArrayRef[AI::MXNet::NDArray]] $out_grads=)
{
assert($self->binded and $self->params_initialized);
$self->_p->_exec_group->backward($out_grads);
}
method update()
{
assert($self->binded and $self->params_initialized and $self->optimizer_initialized);
$self->_p->_params_dirty(1);
if($self->_p->_update_on_kvstore)
{
_update_params_on_kvstore(
$self->_p->_exec_group->_p->param_arrays,
$self->_p->_exec_group->_p->grad_arrays,
$self->_p->_kvstore,
$self->_p->_exec_group->param_names
);
}
else
{
_update_params(
$self->_p->_exec_group->_p->param_arrays,
$self->_p->_exec_group->_p->grad_arrays,
$self->_p->_updater,
scalar(@{ $self->_p->_context}),
$self->_p->_kvstore,
$self->_p->_exec_group->param_names
);
}
}
method get_outputs(Bool $merge_multi_context=1)
{
assert($self->binded and $self->params_initialized);
return $self->_p->_exec_group->get_outputs($merge_multi_context);
}
method get_input_grads(Bool $merge_multi_context=1)
{
assert($self->binded and $self->params_initialized and $self->inputs_need_grad);
return $self->_p->_exec_group->get_input_grads($merge_multi_context);
}
method get_states(Bool $merge_multi_context=1)
{
assert($self->binded and $self->params_initialized);
return $self->_p->_exec_group->get_states($merge_multi_context);
}
method set_states(:$states=, :$value=)
{
assert($self->binded and $self->params_initialized);
return $self->_p->_exec_group->set_states($states, $value);
}
method update_metric(
AI::MXNet::EvalMetric $eval_metric,
ArrayRef[AI::MXNet::NDArray] $labels
)
{
$self->_p->_exec_group->update_metric($eval_metric, $labels);
}
=head2 _sync_params_from_devices
Synchronize parameters from devices to CPU. This function should be called after
calling 'update' that updates the parameters on the devices, before one can read the
latest parameters from $self->_arg_params and $self->_aux_params.
=cut
method _sync_params_from_devices()
{
$self->_p->_exec_group->get_params($self->_p->_arg_params, $self->_p->_aux_params);
$self->_p->_params_dirty(0);
}
method save_optimizer_states(Str $fname)
{
assert($self->optimizer_initialized);
if($self->_p->_update_on_kvstore)
{
$self->_p->_kvstore->save_optimizer_states($fname);
}
else
{
open(F, ">:raw", "$fname") or confess("can't open $fname for writing: $!");
print F $self->_p->_updater->get_states();
close(F);
}
}
method load_optimizer_states(Str $fname)
{
assert($self->optimizer_initialized);
if($self->_p->_update_on_kvstore)
{
$self->_p->_kvstore->load_optimizer_states($fname);
}
else
{
open(F, "<:raw", "$fname") or confess("can't open $fname for reading: $!");
my $data;
{ local($/) = undef; $data = <F>; }
close(F);
$self->_p->_updater->set_states($data);
}
}
method install_monitor(AI::MXNet::Monitor $mon)
{
assert($self->binded);
$self->_p->_exec_group->install_monitor($mon);
}
method _updater()
{
$self->_p->_updater;
}
method _kvstore()
{
$self->_p->_kvstore;
}
1;