perl-package/AI-MXNet/lib/AI/MXNet/Module/Bucketing.pm (335 lines of code) (raw):

package AI::MXNet::Module::Bucketing;
use Mouse;
use AI::MXNet::Function::Parameters;
use AI::MXNet::Base;

=encoding UTF-8

=head1 NAME

    AI::MXNet::Module::Bucketing

=head1 SYNOPSIS

    my $buckets = [10, 20, 30, 40, 50, 60];
    my $start_label   = 1;
    my $invalid_label = 0;

    my ($train_sentences, $vocabulary) = tokenize_text(
        './data/ptb.train.txt', start_label => $start_label,
        invalid_label => $invalid_label
    );
    my ($validation_sentences) = tokenize_text(
        './data/ptb.test.txt', vocab => $vocabulary,
        start_label => $start_label, invalid_label => $invalid_label
    );
    my $data_train = mx->rnn->BucketSentenceIter(
        $train_sentences, $batch_size, buckets => $buckets,
        invalid_label => $invalid_label
    );
    my $data_val = mx->rnn->BucketSentenceIter(
        $validation_sentences, $batch_size, buckets => $buckets,
        invalid_label => $invalid_label
    );

    my $stack = mx->rnn->SequentialRNNCell();
    for my $i (0..$num_layers-1)
    {
        $stack->add(mx->rnn->LSTMCell(num_hidden => $num_hidden, prefix => "lstm_l${i}_"));
    }

    my $sym_gen = sub {
        my $seq_len = shift;
        my $data  = mx->sym->Variable('data');
        my $label = mx->sym->Variable('softmax_label');
        my $embed = mx->sym->Embedding(
            data => $data, input_dim => scalar(keys %$vocabulary),
            output_dim => $num_embed, name => 'embed'
        );
        $stack->reset;
        my ($outputs, $states) = $stack->unroll($seq_len, inputs => $embed, merge_outputs => 1);
        my $pred = mx->sym->Reshape($outputs, shape => [-1, $num_hidden]);
        $pred    = mx->sym->FullyConnected(data => $pred, num_hidden => scalar(keys %$vocabulary), name => 'pred');
        $label   = mx->sym->Reshape($label, shape => [-1]);
        $pred    = mx->sym->SoftmaxOutput(data => $pred, label => $label, name => 'softmax');
        return ($pred, ['data'], ['softmax_label']);
    };

    my $contexts;
    if(defined $gpus)
    {
        $contexts = [map { mx->gpu($_) } split(/,/, $gpus)];
    }
    else
    {
        $contexts = mx->cpu(0);
    }

    my $model = mx->mod->BucketingModule(
        sym_gen             => $sym_gen,
        default_bucket_key  => $data_train->default_bucket_key,
        context             => $contexts
    );

    $model->fit(
        $data_train,
        eval_data           => $data_val,
        eval_metric         => mx->metric->Perplexity($invalid_label),
        kvstore             => $kv_store,
        optimizer           => $optimizer,
        optimizer_params    => {
                                    learning_rate => $lr,
                                    momentum      => $mom,
                                    wd            => $wd,
        },
        initializer         => mx->init->Xavier(factor_type => "in", magnitude => 2.34),
        num_epoch           => $num_epoch,
        batch_end_callback  => mx->callback->Speedometer($batch_size, $disp_batches),
        ($chkp_epoch ? (epoch_end_callback => mx->rnn->do_rnn_checkpoint($stack, $chkp_prefix, $chkp_epoch)) : ())
    );

=head1 DESCRIPTION

    Implements the AI::MXNet::Module::Base API, and allows multiple
    symbols to be used depending on the `bucket_key` provided by each different
    mini-batch of data

=cut

=head2 new

    Parameters
    ----------
    $sym_gen : subref or any perl object that overloads &{} op
        A sub when called with a bucket key, returns a list with triple
        of ($symbol, $data_names, $label_names).
    $default_bucket_key : str or anything else
        The key for the default bucket.
    $logger : Logger
    $context : AI::MXNet::Context or array ref of AI::MXNet::Context objects
        Default is cpu(0)
    $work_load_list : array ref of Num
        Default is undef, indicating uniform workload.
    $fixed_param_names: arrayref of str
        Default is undef, indicating no network parameters are fixed.
    $state_names : arrayref of str
        states are similar to data and label, but not provided by data iterator.
        Instead they are initialized to 0 and can be set by set_states()

=cut

extends 'AI::MXNet::Module::Base';

# Constructor-supplied configuration (public init_arg, private accessor).
has '_sym_gen'            => (is => 'ro', init_arg => 'sym_gen', required => 1);
has '_default_bucket_key' => (is => 'rw', init_arg => 'default_bucket_key', required => 1);
has '_context'            => (
    is       => 'ro',
    isa      => 'AI::MXNet::Context|ArrayRef[AI::MXNet::Context]',
    lazy     => 1,
    default  => sub { AI::MXNet::Context->cpu },
    init_arg => 'context'
);
has '_work_load_list'     => (is => 'rw', init_arg => 'work_load_list', isa => 'ArrayRef[Num]');

# Internal state; cannot be set through the constructor (init_arg => undef).
has '_curr_module'        => (is => 'rw', init_arg => undef);
has '_curr_bucket_key'    => (is => 'rw', init_arg => undef);
# Maps bucket key => bound AI::MXNet::Module for that key.
has '_buckets'            => (is => 'rw', init_arg => undef, default => sub { +{} });
has '_fixed_param_names'  => (is => 'rw', isa => 'ArrayRef[Str]', init_arg => 'fixed_param_names');
has '_state_names'        => (is => 'rw', isa => 'ArrayRef[Str]', init_arg => 'state_names');
has '_params_dirty'       => (is => 'rw', init_arg => undef);

# Post-construction hook: normalizes the optional name lists to empty array
# refs, clears the dirty flag, and validates the names produced by sym_gen
# for the default bucket key via _check_input_names.
sub BUILD {
    my ($self, $original_params) = @_;
    $self->_fixed_param_names([]) unless defined $original_params->{fixed_param_names};
    $self->_state_names([]) unless defined $original_params->{state_names};
    $self->_params_dirty(0);
    my ($symbol, $data_names, $label_names) = $self->_sym_gen->($self->_default_bucket_key);
    $self->_check_input_names($symbol, $data_names//[], "data", 1);
    $self->_check_input_names($symbol, $label_names//[], "label", 0);
    $self->_check_input_names($symbol, $self->_state_names, "state", 1);
    $self->_check_input_names($symbol, $self->_fixed_param_names, "fixed_param", 1);
}

# Drops every bound bucket module and marks the module as unbound.
method _reset_bind() {
    $self->binded(0);
    $self->_buckets({});
    $self->_curr_module(undef);
    $self->_curr_bucket_key(undef);
}

# Data names of the current module when bound; otherwise the second element
# returned by sym_gen for the default bucket key.
method data_names() {
    if($self->binded) {
        return $self->_curr_module->data_names;
    }
    else {
        return ($self->_sym_gen->($self->_default_bucket_key))[1];
    }
}

# Output names of the current module when bound; otherwise the outputs listed
# by the default bucket's symbol.
# Fixed: the method names were misspelled ('ouput_names', 'list_ouputs'),
# which made both branches die with "Can't locate object method".
method output_names() {
    if($self->binded) {
        return $self->_curr_module->output_names;
    }
    else {
        my ($symbol) = $self->_sym_gen->($self->_default_bucket_key);
        return $symbol->list_outputs;
    }
}

# The three *_shapes accessors require a bound module and delegate to the
# module of the current bucket.
method data_shapes() {
    assert($self->binded);
    return $self->_curr_module->data_shapes;
}

method label_shapes() {
    assert($self->binded);
    return $self->_curr_module->label_shapes;
}

method output_shapes() {
    assert($self->binded);
    return $self->_curr_module->output_shapes;
}

# Returns ($arg_params, $aux_params) from the current module.
# This object's dirty flag is pushed down into the current module first and
# cleared afterwards.
method get_params() {
    assert($self->binded and $self->params_initialized);
    # presumably lets the inner module decide whether it has to re-sync its
    # parameters before returning them -- confirm against AI::MXNet::Module
    $self->_curr_module->_p->_params_dirty($self->_params_dirty);
    my ($arg_params, $aux_params) = $self->_curr_module->get_params;
    $self->_params_dirty(0);
    return ($arg_params, $aux_params);
}

# Assigns parameter values.
# When $allow_missing is false the call is forwarded to init_params().
# Otherwise the values are set directly on the current module, unless the
# parameters are already initialized and $force_init is false, in which case
# a warning is logged and the call is ignored.
method set_params(
    HashRef[AI::MXNet::NDArray] $arg_params,
    HashRef[AI::MXNet::NDArray] $aux_params,
    Bool                        $allow_missing=0,
    Bool                        $force_init=1,
    Bool                        $allow_extra=0
) {
    if(not $allow_missing) {
        $self->init_params(
            arg_params    => $arg_params,
            aux_params    => $aux_params,
            allow_missing => $allow_missing,
            force_init    => $force_init,
            allow_extra   => $allow_extra
        );
        return;
    }
    if($self->params_initialized and not $force_init) {
        AI::MXNet::Logging->warning(
            "Parameters already initialized and force_init=False. "
            ."set_params call ignored."
        );
        return;
    }
    $self->_curr_module->set_params(
        $arg_params, $aux_params,
        allow_missing => $allow_missing,
        force_init    => $force_init,
        allow_extra   => $allow_extra
    );
    # because we didn't update self._arg_params, they are dirty now.
    $self->_params_dirty(1);
    $self->params_initialized(1);
}

# Initializes the parameters of the current module.
# Does nothing when the parameters are already initialized and $force_init
# is false; requires bind() to have been called.
method init_params(
    AI::MXNet::Initializer              :$initializer=AI::MXNet::Initializer->Uniform(scale => 0.01),
    Maybe[HashRef[AI::MXNet::NDArray]]  :$arg_params=,
    Maybe[HashRef[AI::MXNet::NDArray]]  :$aux_params=,
    Bool                                :$allow_missing=0,
    Bool                                :$force_init=0,
    Bool                                :$allow_extra=0
) {
    return if($self->params_initialized and not $force_init);
    assert($self->binded, 'call bind before initializing the parameters');
    $self->_curr_module->init_params(
        initializer   => $initializer,
        arg_params    => $arg_params,
        aux_params    => $aux_params,
        allow_missing => $allow_missing,
        force_init    => $force_init,
        allow_extra   => $allow_extra
    );
    $self->_params_dirty(0);
    $self->params_initialized(1);
}

# Delegates to the current module's get_states.
method get_states(Bool $merge_multi_context=1) {
    assert($self->binded and $self->params_initialized);
    $self->_curr_module->get_states($merge_multi_context);
}

# Delegates to the current module's set_states.
method set_states(:$states=, :$value=) {
    assert($self->binded and $self->params_initialized);
    $self->_curr_module->set_states(states => $states, value => $value);
}

=head2 bind

    Binding for an AI::MXNet::Module::Bucketing means setting up the buckets and
    binding the executor for the default bucket key. Executors corresponding to
    other keys are bound afterwards with switch_bucket.

    Parameters
    ----------
    :$data_shapes : ArrayRef[AI::MXNet::DataDesc|NameShape]
        This should correspond to the symbol for the default bucket.
    :$label_shapes= : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        This should correspond to the symbol for the default bucket.
    :$for_training : Bool
        Default is 1.
    :$inputs_need_grad : Bool
        Default is 0.
    :$force_rebind : Bool
        Default is 0.
    :$shared_module : AI::MXNet::Module::Bucketing
        Default is undef. This value is currently not used.
    :$grad_req : str, array ref of str, hash ref of str to str
        Requirement for gradient accumulation. Can be 'write', 'add', or 'null'
        (defaults to 'write').
        Can be specified globally (str) or for each argument (array ref, hash ref).
    :$bucket_key : str
        bucket key for binding. by default is to use the ->default_bucket_key

=cut

method bind(
    ArrayRef[AI::MXNet::DataDesc|NameShape]         :$data_shapes,
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]  :$label_shapes=,
    Bool                                            :$for_training=1,
    Bool                                            :$inputs_need_grad=0,
    Bool                                            :$force_rebind=0,
    Maybe[AI::MXNet::BaseModule]                    :$shared_module=,
    Str|ArrayRef[Str]|HashRef[Str]                  :$grad_req='write',
    Maybe[Str]                                      :$bucket_key=
) {
    # in case we already initialized params, keep it
    my ($arg_params, $aux_params);
    if($self->params_initialized) {
        ($arg_params, $aux_params) = $self->get_params;
    }

    # force rebinding is typically used when one wants to switch from
    # training to prediction phase.
    $self->_reset_bind if $force_rebind;

    if($self->binded) {
        $self->logger->warning('Already binded, ignoring bind()');
        return;
    }

    assert((not defined $shared_module), 'shared_module for BucketingModule is not supported');

    $self->for_training($for_training);
    $self->inputs_need_grad($inputs_need_grad);
    $self->binded(1);

    my ($symbol, $data_names, $label_names) = $self->_sym_gen->($bucket_key//$self->_default_bucket_key);
    my $module = AI::MXNet::Module->new(
        symbol            => $symbol,
        data_names        => $data_names,
        label_names       => $label_names,
        logger            => $self->logger,
        context           => $self->_context,
        work_load_list    => $self->_work_load_list,
        state_names       => $self->_state_names,
        fixed_param_names => $self->_fixed_param_names
    );
    $module->bind(
        data_shapes      => $data_shapes,
        label_shapes     => $label_shapes,
        for_training     => $for_training,
        inputs_need_grad => $inputs_need_grad,
        force_rebind     => 0,
        shared_module    => undef,
        grad_req         => $grad_req
    );
    # NOTE(review): the module is registered under the default bucket key even
    # when an explicit $bucket_key was used to generate its symbol -- confirm
    # this is intended.
    $self->_curr_module($module);
    $self->_curr_bucket_key($self->_default_bucket_key);
    $self->_buckets->{ $self->_default_bucket_key } = $module;

    # copy back saved params, if already initialized
    if($self->params_initialized) {
        $self->set_params($arg_params, $aux_params);
    }
}

=head2 switch_bucket

    Switch to a different bucket. This will change $self->_curr_module.

    Parameters
    ----------
    :$bucket_key : str (or any perl object that overloads "" op)
        The key of the target bucket.
    :$data_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically $data_batch->provide_data.
    :$label_shapes : Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]]
        Typically $data_batch->provide_label.

=cut

method switch_bucket(
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$data_shapes=,
    Maybe[ArrayRef[AI::MXNet::DataDesc|NameShape]] :$label_shapes=,
                                                   :$bucket_key
) {
    assert($self->binded, 'call bind before switching bucket');
    if(not exists $self->_buckets->{ $bucket_key }) {
        # First use of this key: build a module for it and bind it with the
        # default bucket's module passed as shared_module.
        my ($symbol, $data_names, $label_names) = $self->_sym_gen->($bucket_key);
        my $module = AI::MXNet::Module->new(
            symbol            => $symbol,
            data_names        => $data_names,
            label_names       => $label_names,
            logger            => $self->logger,
            context           => $self->_context,
            work_load_list    => $self->_work_load_list,
            # Fixed: pass the same state/fixed-param configuration that bind()
            # gives the default bucket; previously these were dropped for
            # every lazily created bucket.
            state_names       => $self->_state_names,
            fixed_param_names => $self->_fixed_param_names
        );
        $module->bind(
            data_shapes      => $data_shapes,
            label_shapes     => $label_shapes,
            for_training     => $self->_curr_module->for_training,
            inputs_need_grad => $self->_curr_module->inputs_need_grad,
            force_rebind     => 0,
            shared_module    => $self->_buckets->{ $self->_default_bucket_key },
        );
        $self->_buckets->{ $bucket_key } = $module;
    }
    $self->_curr_module($self->_buckets->{ $bucket_key });
    $self->_curr_bucket_key($bucket_key);
}

# Initializes the optimizer on the current module and lets every other
# already-bound bucket module borrow it.
# Logs a warning and returns when the optimizer is already initialized and
# $force_init is false.
method init_optimizer(
    Str        :$kvstore='local',
    Optimizer  :$optimizer='sgd',
    HashRef    :$optimizer_params={ learning_rate => 0.01 },
    Bool       :$force_init=0
) {
    assert($self->binded and $self->params_initialized);
    if($self->optimizer_initialized and not $force_init) {
        $self->logger->warning('optimizer already initialized, ignoring.');
        return;
    }

    $self->_curr_module->init_optimizer(
        kvstore          => $kvstore,
        optimizer        => $optimizer,
        optimizer_params => $optimizer_params,
        force_init       => $force_init
    );
    for my $mod (values %{ $self->_buckets }) {
        # identity check: compares the stringified references
        if($mod ne $self->_curr_module) {
            $mod->borrow_optimizer($self->_curr_module);
        }
    }
    $self->optimizer_initialized(1);
}

# Makes sure a module exists for the batch's bucket key (binding one if
# needed) and then restores the previously current bucket.
method prepare(AI::MXNet::DataBatch $data_batch) {
    assert($self->binded and $self->params_initialized);
    ## perform bind if have not done so yet
    my $original_bucket_key = $self->_curr_bucket_key;
    $self->switch_bucket(
        bucket_key   => $data_batch->bucket_key,
        data_shapes  => $data_batch->provide_data,
        label_shapes => $data_batch->provide_label
    );
    # switch back
    $self->switch_bucket(bucket_key => $original_bucket_key);
}

# Switches to the batch's bucket, then runs forward on that bucket's module.
method forward(
    AI::MXNet::DataBatch  $data_batch,
    Bool                  :$is_train=
) {
    assert($self->binded and $self->params_initialized);
    $self->switch_bucket(
        bucket_key   => $data_batch->bucket_key,
        data_shapes  => $data_batch->provide_data,
        label_shapes => $data_batch->provide_label
    );
    $self->_curr_module->forward($data_batch, is_train => $is_train);
}

# Delegates to the current module's backward.
method backward(Maybe[ArrayRef[AI::MXNet::NDArray]|AI::MXNet::NDArray] $out_grads=) {
    assert($self->binded and $self->params_initialized);
    $self->_curr_module->backward($out_grads);
}

# Marks the parameters as dirty, then delegates to the current module's update.
method update() {
    assert($self->binded and $self->params_initialized and $self->optimizer_initialized);
    $self->_params_dirty(1);
    $self->_curr_module->update;
}

# Delegates to the current module's get_outputs.
method get_outputs(Bool $merge_multi_context=1) {
    assert($self->binded and $self->params_initialized);
    return $self->_curr_module->get_outputs($merge_multi_context);
}

# Delegates to the current module's get_input_grads; additionally requires
# that the module was bound with inputs_need_grad.
method get_input_grads(Bool $merge_multi_context=1) {
    assert($self->binded and $self->params_initialized and $self->inputs_need_grad);
    return $self->_curr_module->get_input_grads($merge_multi_context);
}

# Delegates to the current module's update_metric.
method update_metric(
    AI::MXNet::EvalMetric        $eval_metric,
    ArrayRef[AI::MXNet::NDArray] $labels
) {
    assert($self->binded and $self->params_initialized);
    $self->_curr_module->update_metric($eval_metric, $labels);
}

# Symbol of the module for the current bucket.
method symbol() {
    assert($self->binded);
    return $self->_curr_module->symbol;
}

# Symbol of the module registered under the default bucket key.
method get_symbol() {
    assert($self->binded);
    return $self->_buckets->{ $self->_default_bucket_key }->symbol;
}

# Installs the monitor on every bucket module that exists at call time.
method install_monitor(AI::MXNet::Monitor $mon) {
    assert($self->binded);
    for my $mod (values %{ $self->_buckets }) {
        $mod->install_monitor($mon);
    }
}

1;