# perl-package/AI-MXNet/lib/AI/MXNet/RNN/Cell.pm

package AI::MXNet::RNN::Params;
use Mouse;
use AI::MXNet::Function::Parameters;

=head1 NAME

    AI::MXNet::RNN::Params
=cut

=head1 DESCRIPTION

    A container for holding variables.
    Used by RNN cells for parameter sharing between cells.

    Parameters
    ----------
    prefix : str
        All variables name created by this container will be prepended with the prefix
=cut

has '_prefix' => (is => 'ro', init_arg => 'prefix', isa => 'Str', default => '');
has '_params' => (is => 'rw', init_arg => undef);

# Allow the positional shorthand AI::MXNet::RNN::Params->new($prefix).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(prefix => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    $self->_params({});
}

=head2 get

    Get a variable with the name or create a new one if does not exist.

    Parameters
    ----------
    $name : str
        name of the variable
    @kwargs: more arguments that are passed to mx->sym->Variable call
=cut

method get(Str $name, @kwargs)
{
    $name = $self->_prefix . $name;
    if(not exists $self->_params->{$name})
    {
        $self->_params->{$name} = AI::MXNet::Symbol->Variable($name, @kwargs);
    }
    return $self->_params->{$name};
}

package AI::MXNet::RNN::Cell::Base;

=head1 NAME

    AI::MXNet::RNNCell::Base
=cut

=head1 DESCRIPTION

    Abstract base class for RNN cells

    Parameters
    ----------
    prefix : str
        prefix for name of layers (and name of weight if params is undef)
    params : AI::MXNet::RNN::Params or undef
        container for weight sharing between cells.
        created if undef.
=cut

use AI::MXNet::Base;
use Mouse;
# Code-deref overload so a cell object can be stepped as &{$cell}($inputs, $states).
use overload "&{}" => sub { my $self = shift; sub { $self->call(@_) } };
has '_prefix' => (is => 'rw', init_arg => 'prefix', isa => 'Str', default => '');
has '_params' => (is => 'rw', init_arg => 'params', isa => 'Maybe[AI::MXNet::RNN::Params]');
has [qw/_own_params _modified _init_counter _counter/] => (is => 'rw', init_arg => undef);

# Allow the positional shorthand Cell->new($prefix).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(prefix => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    if(not defined $self->_params)
    {
        $self->_own_params(1);
        $self->_params(AI::MXNet::RNN::Params->new($self->_prefix));
    }
    else
    {
        $self->_own_params(0);
    }
    $self->_modified(0);
    $self->reset;
}

=head2 reset

    Reset before re-using the cell for another graph
=cut

method reset()
{
    $self->_init_counter(-1);
    $self->_counter(-1);
}

=head2 call

    Construct symbol for one step of RNN.

    Parameters
    ----------
    $inputs : mx->sym->Variable
        input symbol, 2D, batch * num_units
    $states : mx->sym->Variable or ArrayRef[AI::MXNet::Symbol]
        state from previous step or begin_state().

    Returns
    -------
    $output : AI::MXNet::Symbol
        output symbol
    $states : ArrayRef[AI::MXNet::Symbol]
        state to next step of RNN.
    Can be called via overloaded &{}: &{$cell}($inputs, $states);
=cut

method call(AI::MXNet::Symbol $inputs, AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol] $states)
{
    confess("Not Implemented");
}

method _gate_names()
{
    [''];
}

=head2 params

    Parameters of this cell
=cut

method params()
{
    $self->_own_params(0);
    return $self->_params;
}

=head2 state_shape

    shape(s) of states
=cut

method state_shape()
{
    return [map { $_->{shape} } @{ $self->state_info }];
}

=head2 state_info

    shape and layout information of states
=cut

method state_info()
{
    confess("Not Implemented");
}

=head2 begin_state

    Initial state for this cell.

    Parameters
    ----------
    :$func : sub ref, default is AI::MXNet::Symbol->can('zeros')
        Function for creating initial state.
        Can be AI::MXNet::Symbol->can('zeros'),
        AI::MXNet::Symbol->can('uniform'), AI::MXNet::Symbol->can('Variable') etc.
        Use AI::MXNet::Symbol->can('Variable') if you want to directly
        feed the input as states.
    @kwargs :
        more keyword arguments passed to func. For example mean, std, dtype, etc.

    Returns
    -------
    $states : ArrayRef[AI::MXNet::Symbol]
        starting states for first RNN step
=cut

method begin_state(CodeRef :$func=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    my @states;
    # Variable() takes the name positionally; every other factory takes name => ...
    my $func_needs_named_name = $func ne AI::MXNet::Symbol->can('Variable');
    for my $info (@{ $self->state_info })
    {
        $self->_init_counter($self->_init_counter + 1);
        my @name = (sprintf("%sbegin_state_%d", $self->_prefix, $self->_init_counter));
        my %info = %{ $info//{} };
        if($func_needs_named_name)
        {
            unshift(@name, 'name');
        }
        else
        {
            if(exists $info{__layout__})
            {
                $info{kwargs} = { __layout__ => delete $info{__layout__} };
            }
        }
        my %kwargs = (@kwargs, %info);
        my $state = &{$func}(
            'AI::MXNet::Symbol',
            @name,
            %kwargs
        );
        push @states, $state;
    }
    return \@states;
}

=head2 unpack_weights

    Unpack fused weight matrices into separate weight matrices

    Parameters
    ----------
    $args : HashRef[AI::MXNet::NDArray]
        hash ref containing packed weights. usually from AI::MXNet::Module->get_output()

    Returns
    -------
    $args : HashRef[AI::MXNet::NDArray]
        hash ref with weights associated with this cell, unpacked.
=cut

method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    my $h = $self->_num_hidden;
    for my $group_name ('i2h', 'h2h')
    {
        my $weight = delete $args{ sprintf('%s%s_weight', $self->_prefix, $group_name) };
        my $bias   = delete $args{ sprintf('%s%s_bias', $self->_prefix, $group_name) };
        enumerate(sub {
            my ($j, $name) = @_;
            my $wname = sprintf('%s%s%s_weight', $self->_prefix, $group_name, $name);
            # FIX: store the unpacked slices into the local copy (%args), not into
            # the caller's hashref ($args->{...} in the original): writing through
            # $args mutated the input and left the returned hash without any of
            # the per-gate matrices (pack_weights below correctly uses %args).
            $args{$wname} = $weight->slice([$j*$h,($j+1)*$h-1])->copy;
            my $bname = sprintf('%s%s%s_bias', $self->_prefix, $group_name, $name);
            $args{$bname} = $bias->slice([$j*$h,($j+1)*$h-1])->copy;
        }, $self->_gate_names);
    }
    return \%args;
}

=head2 pack_weights

    Pack fused weight matrices into common weight matrices

    Parameters
    ----------
    args : HashRef[AI::MXNet::NDArray]
        hash ref containing unpacked weights.

    Returns
    -------
    $args : HashRef[AI::MXNet::NDArray]
        hash ref with weights associated with this cell, packed.
=cut

method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    my $h = $self->_num_hidden;
    for my $group_name ('i2h', 'h2h')
    {
        my @weight;
        my @bias;
        for my $name (@{ $self->_gate_names })
        {
            my $wname = sprintf('%s%s%s_weight', $self->_prefix, $group_name, $name);
            push @weight, delete $args{$wname};
            my $bname = sprintf('%s%s%s_bias', $self->_prefix, $group_name, $name);
            push @bias, delete $args{$bname};
        }
        $args{ sprintf('%s%s_weight', $self->_prefix, $group_name) } = AI::MXNet::NDArray->concatenate(
            \@weight
        );
        $args{ sprintf('%s%s_bias', $self->_prefix, $group_name) } = AI::MXNet::NDArray->concatenate(
            \@bias
        );
    }
    return \%args;
}

=head2 unroll

    Unroll an RNN cell across time steps.

    Parameters
    ----------
    :$length : Int
        number of steps to unroll
    :$inputs : AI::MXNet::Symbol, array ref of Symbols, or undef
        if inputs is a single Symbol (usually the output of Embedding symbol),
        it should have shape of [$batch_size, $length, ...] if layout == 'NTC' (batch, time series)
        or ($length, $batch_size, ...) if layout == 'TNC' (time series, batch).
        If inputs is a array ref of symbols (usually output of previous unroll),
        they should all have shape ($batch_size, ...).
        If inputs is undef, a placeholder variables are automatically created.
    :$begin_state : array ref of Symbol
        input states. Created by begin_state() or output state of another cell.
        Created from begin_state() if undef.
    :$input_prefix : str
        prefix for automatically created input placehodlers.
    :$layout : str
        layout of input symbol. Only used if the input is a single Symbol.
    :$merge_outputs : Bool
        If 0, returns outputs as an array ref of Symbols.
        If 1, concatenates the output across the time steps and returns a
        single symbol with the shape [$batch_size, $length, ...) if the layout
        equal to 'NTC', or [$length, $batch_size, ...) if the layout equal tp 'TNC'.
        If undef, output whatever is faster

    Returns
    -------
    $outputs : array ref of Symbol or Symbol
        output symbols.
    $states : Symbol or nested list of Symbol
        has the same structure as begin_state()
=cut

method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    $self->reset;
    my $axis = index($layout, 'T');
    if(not defined $inputs)
    {
        $inputs = [
            map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1)
        ];
    }
    elsif(blessed($inputs))
    {
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        $inputs = AI::MXNet::Symbol->SliceChannel(
            $inputs,
            axis         => $axis,
            num_outputs  => $length,
            squeeze_axis => 1
        );
    }
    else
    {
        assert(@$inputs == $length);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my $outputs;
    my @inputs = @{ $inputs };
    for my $i (0..$length-1)
    {
        my $output;
        ($output, $states) = &{$self}(
            $inputs[$i],
            $states
        );
        push @$outputs, $output;
    }
    if($merge_outputs)
    {
        # add a time axis back to each step output, then join along it
        @$outputs = map { AI::MXNet::Symbol->expand_dims($_, axis => $axis) } @$outputs;
        $outputs = AI::MXNet::Symbol->Concat(@$outputs, dim => $axis);
    }
    return($outputs, $states);
}

# Returns the activation symbol: a string is handed to the Activation op,
# a code ref is invoked directly.
method _get_activation($inputs, $activation, @kwargs)
{
    if(not ref $activation)
    {
        return AI::MXNet::Symbol->Activation($inputs, act_type => $activation, @kwargs);
    }
    else
    {
        return &{$activation}($inputs, @kwargs);
    }
}

# Helpers shared by container cells (Sequential/Bidirectional): aggregate
# state shapes/infos/initial states and thread (un)pack_weights through children.
method _cells_state_shape($cells)
{
    return [map { @{ $_->state_shape } } @$cells];
}

method _cells_state_info($cells)
{
    return [map { @{ $_->state_info } } @$cells];
}

method _cells_begin_state($cells, @kwargs)
{
    return [map { @{ $_->begin_state(@kwargs) } } @$cells];
}

method _cells_unpack_weights($cells, $args)
{
    $args = $_->unpack_weights($args) for @$cells;
    return $args;
}

method _cells_pack_weights($cells, $args)
{
    $args = $_->pack_weights($args) for @$cells;
    return $args;
}

package AI::MXNet::RNN::Cell;
use Mouse;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::Cell
=cut

=head1 DESCRIPTION

    Simple recurrent neural network cell

    Parameters
    ----------
    num_hidden : int
        number of units in output symbol
    activation : str or Symbol, default 'tanh'
        type of activation function
    prefix : str, default 'rnn_'
        prefix for name of layers (and name of weight if params is undef)
    params : AI::MXNet::RNN::Params or undef
        container for weight sharing between cells.
        created if undef.
=cut

has '_num_hidden'  => (is => 'ro', init_arg => 'num_hidden', isa => 'Int', required => 1);
has 'forget_bias'  => (is => 'ro', isa => 'Num');
has '_activation'  => (
    is       => 'ro',
    init_arg => 'activation',
    isa      => 'Activation',
    default  => 'tanh'
);
has '+_prefix'     => (default => 'rnn_');
has [qw/_iW _iB _hW _hB/] => (is => 'rw', init_arg => undef);

# Positional shorthand: Cell->new($num_hidden).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(num_hidden => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    # Fetch (or create) the shared weight/bias variables from the param container.
    # i2h_bias optionally carries an LSTMBias initializer when forget_bias is set
    # (only meaningful for the LSTM subclass, which defaults forget_bias to 1).
    $self->_iW($self->params->get('i2h_weight'));
    $self->_iB(
        $self->params->get(
            'i2h_bias',
            (defined($self->forget_bias)
                ? (init => AI::MXNet::LSTMBias->new(forget_bias => $self->forget_bias))
                : ()
            )
        )
    );
    $self->_hW($self->params->get('h2h_weight'));
    $self->_hB($self->params->get('h2h_bias'));
}

method state_info()
{
    return [{ shape => [0, $self->_num_hidden], __layout__ => 'NC' }];
}

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    $self->_counter($self->_counter + 1);
    my $tname = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data       => $inputs,
        weight     => $self->_iW,
        bias       => $self->_iB,
        num_hidden => $self->_num_hidden,
        name       => "${tname}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data       => @{$states}[0],
        weight     => $self->_hW,
        bias       => $self->_hB,
        num_hidden => $self->_num_hidden,
        name       => "${tname}h2h"
    );
    my $output = $self->_get_activation(
        $i2h + $h2h,
        $self->_activation,
        name => "${tname}out"
    );
    # A vanilla RNN's sole state is its own output.
    return ($output, [$output]);
}

package AI::MXNet::RNN::LSTMCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell';

=head1 NAME

    AI::MXNet::RNN::LSTMCell
=cut

=head1 DESCRIPTION

    Long-Short Term Memory (LSTM) network cell.

    Parameters
    ----------
    num_hidden : int
        number of units in output symbol
    prefix : str, default 'lstm_'
        prefix for name of layers (and name of weight if params is undef)
    params : AI::MXNet::RNN::Params or None
        container for weight sharing between cells.
        created if undef.
    forget_bias : bias added to forget gate, default 1.0.
        Jozefowicz et al. 2015 recommends setting this to 1.0
=cut

has '+_prefix'     => (default => 'lstm_');
has '+_activation' => (init_arg => undef);
has '+forget_bias' => (is => 'ro', isa => 'Num', default => 1);

method state_info()
{
    # hidden state h and cell state c
    return [{ shape => [0, $self->_num_hidden], __layout__ => 'NC' },
            { shape => [0, $self->_num_hidden], __layout__ => 'NC' }];
}

method _gate_names()
{
    [qw/_i _f _c _o/];
}

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    $self->_counter($self->_counter + 1);
    my $tname  = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my @states = @{ $states };
    # All four gates share one FullyConnected per projection (hence *4),
    # sliced apart afterwards.
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data       => $inputs,
        weight     => $self->_iW,
        bias       => $self->_iB,
        num_hidden => $self->_num_hidden*4,
        name       => "${tname}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data       => $states[0],
        weight     => $self->_hW,
        bias       => $self->_hB,
        num_hidden => $self->_num_hidden*4,
        name       => "${tname}h2h"
    );
    my $gates = $i2h + $h2h;
    my @sliced = @{ AI::MXNet::Symbol->SliceChannel(
        $gates, num_outputs => 4, name => "${tname}slice"
    ) };
    my $in_gate = AI::MXNet::Symbol->Activation(
        $sliced[0], act_type => "sigmoid", name => "${tname}i"
    );
    my $forget_gate = AI::MXNet::Symbol->Activation(
        $sliced[1], act_type => "sigmoid", name => "${tname}f"
    );
    my $in_transform = AI::MXNet::Symbol->Activation(
        $sliced[2], act_type => "tanh", name => "${tname}c"
    );
    my $out_gate = AI::MXNet::Symbol->Activation(
        $sliced[3], act_type => "sigmoid", name => "${tname}o"
    );
    my $next_c = AI::MXNet::Symbol->_plus(
        $forget_gate * $states[1],
        $in_gate * $in_transform,
        name => "${tname}state"
    );
    my $next_h = AI::MXNet::Symbol->_mul(
        $out_gate,
        AI::MXNet::Symbol->Activation(
            $next_c, act_type => "tanh"
        ),
        name => "${tname}out"
    );
    return ($next_h, [$next_h, $next_c]);
}

package AI::MXNet::RNN::GRUCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell';

=head1 NAME

    AI::MXNet::RNN::GRUCell
=cut

=head1 DESCRIPTION

    Gated Rectified Unit (GRU) network cell.
    Note: this is an implementation of the cuDNN version of GRUs
    (slight modification compared to Cho et al. 2014).

    Parameters
    ----------
    num_hidden : int
        number of units in output symbol
    prefix : str, default 'gru_'
        prefix for name of layers (and name of weight if params is undef)
    params : AI::MXNet::RNN::Params or undef
        container for weight sharing between cells.
        created if undef.
=cut

has '+_prefix' => (default => 'gru_');

method _gate_names()
{
    [qw/_r _z _o/];
}

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    $self->_counter($self->_counter + 1);
    my $tname = sprintf('%st%d_', $self->_prefix, $self->_counter);
    my $prev_state_h = @{ $states }[0];
    # Reset/update/output projections fused into one FullyConnected (hence *3).
    my $i2h = AI::MXNet::Symbol->FullyConnected(
        data       => $inputs,
        weight     => $self->_iW,
        bias       => $self->_iB,
        num_hidden => $self->_num_hidden*3,
        name       => "${tname}i2h"
    );
    my $h2h = AI::MXNet::Symbol->FullyConnected(
        data       => $prev_state_h,
        weight     => $self->_hW,
        bias       => $self->_hB,
        num_hidden => $self->_num_hidden*3,
        name       => "${tname}h2h"
    );
    my ($i2h_r, $i2h_z);
    ($i2h_r, $i2h_z, $i2h) = @{ AI::MXNet::Symbol->SliceChannel(
        $i2h, num_outputs => 3, name => "${tname}_i2h_slice"
    ) };
    my ($h2h_r, $h2h_z);
    ($h2h_r, $h2h_z, $h2h) = @{ AI::MXNet::Symbol->SliceChannel(
        $h2h, num_outputs => 3, name => "${tname}_h2h_slice"
    ) };
    my $reset_gate = AI::MXNet::Symbol->Activation(
        $i2h_r + $h2h_r, act_type => "sigmoid", name => "${tname}_r_act"
    );
    my $update_gate = AI::MXNet::Symbol->Activation(
        $i2h_z + $h2h_z, act_type => "sigmoid", name => "${tname}_z_act"
    );
    # cuDNN variant: reset gate is applied to the h2h *projection*,
    # not to the previous state itself.
    my $next_h_tmp = AI::MXNet::Symbol->Activation(
        $i2h + $reset_gate * $h2h, act_type => "tanh", name => "${tname}_h_act"
    );
    my $next_h = AI::MXNet::Symbol->_plus(
        (1 - $update_gate) * $next_h_tmp,
        $update_gate * $prev_state_h,
        name => "${tname}out"
    );
    return ($next_h, [$next_h]);
}

package AI::MXNet::RNN::FusedCell;
use Mouse;
use AI::MXNet::Types;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1
NAME

    AI::MXNet::RNN::FusedCell
=cut

=head1 DESCRIPTION

    Fusing RNN layers across time step into one kernel.
    Improves speed but is less flexible. Currently only
    supported if using cuDNN on GPU.
=cut

has '_num_hidden'     => (is => 'ro', isa => 'Int',  init_arg => 'num_hidden',     required => 1);
has '_num_layers'     => (is => 'ro', isa => 'Int',  init_arg => 'num_layers',     default => 1);
has '_dropout'        => (is => 'ro', isa => 'Num',  init_arg => 'dropout',        default => 0);
has '_get_next_state' => (is => 'ro', isa => 'Bool', init_arg => 'get_next_state', default => 0);
has '_bidirectional'  => (is => 'ro', isa => 'Bool', init_arg => 'bidirectional',  default => 0);
has 'forget_bias'     => (is => 'ro', isa => 'Num',  default => 1);
has 'initializer'     => (is => 'rw', isa => 'Maybe[AI::MXNet::Initializer]');
has '_mode'           => (
    is       => 'ro',
    isa      => enum([qw/rnn_relu rnn_tanh lstm gru/]),
    init_arg => 'mode',
    default  => 'lstm'
);
has [qw/_parameter _directions/] => (is => 'rw', init_arg => undef);

# Positional shorthand: FusedCell->new($num_hidden).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    return $class->$orig(num_hidden => $_[0]) if @_ == 1;
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    # Default the prefix to the RNN mode, e.g. 'lstm_'.
    if(not $self->_prefix)
    {
        $self->_prefix($self->_mode.'_');
    }
    if(not defined $self->initializer)
    {
        $self->initializer(
            AI::MXNet::Xavier->new(
                factor_type => 'in',
                magnitude   => 2.34
            )
        );
    }
    # Wrap any plain initializer so it understands the fused parameter layout.
    if(not $self->initializer->isa('AI::MXNet::FusedRNN'))
    {
        $self->initializer(
            AI::MXNet::FusedRNN->new(
                init          => $self->initializer,
                num_hidden    => $self->_num_hidden,
                num_layers    => $self->_num_layers,
                mode          => $self->_mode,
                bidirectional => $self->_bidirectional,
                forget_bias   => $self->forget_bias
            )
        );
    }
    $self->_parameter($self->params->get('parameters', init => $self->initializer));
    $self->_directions($self->_bidirectional ? [qw/l r/] : ['l']);
}

method state_info()
{
    my $b = @{ $self->_directions };
    # LSTM carries (h, c); every other mode just h.
    my $n = $self->_mode eq 'lstm' ? 2 : 1;
    return [map { +{ shape => [$b*$self->_num_layers, 0, $self->_num_hidden], __layout__ => 'LNC' } } 0..$n-1];
}

method _gate_names()
{
    return {
        rnn_relu => [''],
        rnn_tanh => [''],
        lstm     => [qw/_i _f _c _o/],
        gru      => [qw/_r _z _o/]
    }->{ $self->_mode };
}

method _num_gates()
{
    return scalar(@{ $self->_gate_names })
}

# Map the flat fused parameter vector $arr onto named per-layer, per-direction,
# per-gate weight/bias views. $li is the input size, $lh the hidden size.
# Layout (matching cuDNN): all weights first, then all biases.
method _slice_weights($arr, $li, $lh)
{
    my %args;
    my @gate_names = @{ $self->_gate_names };
    my @directions = @{ $self->_directions };
    my $b = @directions;
    my $p = 0;
    for my $layer (0..$self->_num_layers-1)
    {
        for my $direction (@directions)
        {
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_i2h%s_weight', $self->_prefix, $direction, $layer, $gate);
                my $size;
                if($layer > 0)
                {
                    # deeper layers take the (possibly bidirectional) previous
                    # layer's output as input: $b*$lh columns
                    $size = $b*$lh*$lh;
                    $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $b*$lh]);
                }
                else
                {
                    $size = $li*$lh;
                    $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $li]);
                }
                $p += $size;
            }
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_h2h%s_weight', $self->_prefix, $direction, $layer, $gate);
                my $size = $lh**2;
                $args{$name} = $arr->slice([$p,$p+$size-1])->reshape([$lh, $lh]);
                $p += $size;
            }
        }
    }
    for my $layer (0..$self->_num_layers-1)
    {
        for my $direction (@directions)
        {
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_i2h%s_bias', $self->_prefix, $direction, $layer, $gate);
                $args{$name} = $arr->slice([$p,$p+$lh-1]);
                $p += $lh;
            }
            for my $gate (@gate_names)
            {
                my $name = sprintf('%s%s%d_h2h%s_bias', $self->_prefix, $direction, $layer, $gate);
                $args{$name} = $arr->slice([$p,$p+$lh-1]);
                $p += $lh;
            }
        }
    }
    assert($p == $arr->size, "Invalid parameters size for FusedRNNCell");
    return %args;
}

method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    my $arr = delete $args{ $self->_parameter->name };
    my $b = @{ $self->_directions };
    my $m = $self->_num_gates;
    my $h = $self->_num_hidden;
    # Invert the size formula from pack_weights to recover the input size.
    my $num_input = int(int(int($arr->size/$b)/$h)/$m)
                    - ($self->_num_layers - 1)*($h+$b*$h+2)
                    - $h - 2;
    my %nargs = $self->_slice_weights($arr, $num_input, $self->_num_hidden);
    %args = (%args, map { $_ => $nargs{$_}->copy } keys %nargs);
    return \%args
}

method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    my %args = %{ $args };
    my $b = @{ $self->_directions };
    my $m = $self->_num_gates;
    my @c = @{ $self->_gate_names };
    my $h = $self->_num_hidden;
    my $w0 = $args{ sprintf('%sl0_i2h%s_weight', $self->_prefix, $c[0]) };
    my $num_input = $w0->shape->[1];
    my $total = ($num_input+$h+2)*$h*$m*$b
                + ($self->_num_layers-1)*$m*$h*($h+$b*$h+2)*$b;
    my $arr = AI::MXNet::NDArray->zeros([$total], ctx => $w0->context, dtype => $w0->dtype);
    # Copy every named matrix into its view of the fused vector ('.=' is
    # the overloaded in-place NDArray assignment).
    my %nargs = $self->_slice_weights($arr, $num_input, $h);
    while(my ($name, $nd) = each %nargs)
    {
        $nd .= delete $args{ $name };
    }
    $args{ $self->_parameter->name } = $arr;
    return \%args;
}

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    confess("AI::MXNet::RNN::FusedCell cannot be stepped. Please use unroll");
}

method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    $self->reset;
    my $axis = index($layout, 'T');
    $inputs //= AI::MXNet::Symbol->Variable("${input_prefix}data");
    if(blessed($inputs))
    {
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        if($axis == 1)
        {
            AI::MXNet::Logging->warning(
                "NTC layout detected. Consider using "
                ."TNC for RNN::FusedCell for faster speed"
            );
            $inputs = AI::MXNet::Symbol->SwapAxis($inputs, dim1 => 0, dim2 => 1);
        }
        else
        {
            assert($axis == 0, "Unsupported layout $layout");
        }
    }
    else
    {
        # list of per-step symbols: stack them along a fresh time axis
        assert(@$inputs == $length);
        $inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis => 0) } @{ $inputs }];
        $inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim => 0);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my @states = @{ $states };
    my %states;
    if($self->_mode eq 'lstm')
    {
        %states = (state => $states[0], state_cell => $states[1]);
    }
    else
    {
        %states = (state => $states[0]);
    }
    my $rnn = AI::MXNet::Symbol->RNN(
        data          => $inputs,
        parameters    => $self->_parameter,
        state_size    => $self->_num_hidden,
        num_layers    => $self->_num_layers,
        bidirectional => $self->_bidirectional,
        p             => $self->_dropout,
        state_outputs => $self->_get_next_state,
        mode          => $self->_mode,
        name          => $self->_prefix.'rnn',
        %states
    );
    my $outputs;
    my %attr = (__layout__ => 'LNC');
    if(not $self->_get_next_state)
    {
        ($outputs, $states) = ($rnn, []);
    }
    elsif($self->_mode eq 'lstm')
    {
        my @rnn = @{ $rnn };
        $rnn[1]->_set_attr(%attr);
        $rnn[2]->_set_attr(%attr);
        ($outputs, $states) = ($rnn[0], [$rnn[1], $rnn[2]]);
    }
    else
    {
        my @rnn = @{ $rnn };
        $rnn[1]->_set_attr(%attr);
        ($outputs, $states) = ($rnn[0], [$rnn[1]]);
    }
    if(defined $merge_outputs and not $merge_outputs)
    {
        AI::MXNet::Logging->warning(
            "Call RNN::FusedCell->unroll with merge_outputs=1 "
            ."for faster speed"
        );
        $outputs = [@{ AI::MXNet::Symbol->SliceChannel(
            $outputs,
            axis         => 0,
            num_outputs  => $length,
            squeeze_axis => 1
        ) }];
    }
    elsif($axis == 1)
    {
        # caller asked for NTC: swap the time axis back to position 1
        $outputs = AI::MXNet::Symbol->SwapAxis($outputs, dim1 => 0, dim2 => 1);
    }
    return ($outputs, $states);
}

=head2 unfuse

    Unfuse the fused RNN

    Returns
    -------
    $cell : AI::MXNet::RNN::SequentialCell
        unfused cell that can be used for stepping, and can run on CPU.
=cut

method unfuse()
{
    my $stack = AI::MXNet::RNN::SequentialCell->new;
    # factory for one unfused layer cell, keyed by mode; prefix is passed in
    my $get_cell = {
        rnn_relu => sub {
            AI::MXNet::RNN::Cell->new(
                num_hidden => $self->_num_hidden,
                activation => 'relu',
                prefix     => shift
            )
        },
        rnn_tanh => sub {
            AI::MXNet::RNN::Cell->new(
                num_hidden => $self->_num_hidden,
                activation => 'tanh',
                prefix     => shift
            )
        },
        lstm => sub {
            AI::MXNet::RNN::LSTMCell->new(
                num_hidden => $self->_num_hidden,
                prefix     => shift
            )
        },
        gru => sub {
            AI::MXNet::RNN::GRUCell->new(
                num_hidden => $self->_num_hidden,
                prefix     => shift
            )
        },
    }->{ $self->_mode };
    for my $i (0..$self->_num_layers-1)
    {
        if($self->_bidirectional)
        {
            $stack->add(
                AI::MXNet::RNN::BidirectionalCell->new(
                    $get_cell->(sprintf('%sl%d_', $self->_prefix, $i)),
                    $get_cell->(sprintf('%sr%d_', $self->_prefix, $i)),
                    output_prefix => sprintf('%sbi_%s_%d', $self->_prefix, $self->_mode, $i)
                )
            );
        }
        else
        {
            $stack->add($get_cell->(sprintf('%sl%d_', $self->_prefix, $i)));
        }
    }
    return $stack;
}

package AI::MXNet::RNN::SequentialCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::SequentialCell
=cut

=head1 DESCRIPTION

    Sequentially stacking multiple RNN cells

    Parameters
    ----------
    params : AI::MXNet::RNN::Params or undef
        container for weight sharing between cells.
        created if undef.
=cut

has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);

sub BUILD
{
    my ($self, $original_arguments) = @_;
    $self->_override_cell_params(defined $original_arguments->{params});
    $self->_cells([]);
}

=head2 add

    Append a cell to the stack.

    Parameters
    ----------
    $cell : AI::MXNet::RNN::Cell::Base
=cut

method add(AI::MXNet::RNN::Cell::Base $cell)
{
    push @{ $self->_cells }, $cell;
    if($self->_override_cell_params)
    {
        assert(
            $cell->_own_params,
            "Either specify params for SequentialRNNCell "
            ."or child cells, not both."
        );
        # push the container's shared variables down into the child cell
        %{ $cell->params->_params } = (%{ $cell->params->_params }, %{ $self->params->_params });
    }
    # and mirror the child's variables back up so they are shared file-wide
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $cell->params->_params });
}

method state_info()
{
    return $self->_cells_state_info($self->_cells);
}

method begin_state(@kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    return $self->_cells_begin_state($self->_cells, @kwargs);
}

method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->_cells_unpack_weights($self->_cells, $args)
}

method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->_cells_pack_weights($self->_cells, $args);
}

method call($inputs, $states)
{
    $self->_counter($self->_counter + 1);
    my @next_states;
    my $p = 0;
    for my $cell (@{ $self->_cells })
    {
        # FIX: the class is AI::MXNet::RNN::BidirectionalCell; the original
        # checked isa('AI::MXNet::BidirectionalCell') (missing 'RNN::'), which
        # matches no class, so the guard could never fire.
        assert(not $cell->isa('AI::MXNet::RNN::BidirectionalCell'));
        my $n = scalar(@{ $cell->state_info });
        # each cell consumes its own slice of the flattened state list
        my $state = [@{ $states }[$p..$p+$n-1]];
        $p += $n;
        ($inputs, $state) = &{$cell}($inputs, $state);
        push @next_states, $state;
    }
    return ($inputs, [map { @$_} @next_states]);
}

method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    my $num_cells = @{ $self->_cells };
    $begin_state //= $self->begin_state;
    my $p = 0;
    my $states;
    my @next_states;
    enumerate(sub {
        my ($i, $cell) = @_;
        my $n = @{ $cell->state_info };
        $states = [@{$begin_state}[$p..$p+$n-1]];
        $p += $n;
        # only the last cell honours the caller's merge_outputs preference;
        # intermediate cells output whatever is fastest
        ($inputs, $states) = $cell->unroll(
            $length,
            inputs        => $inputs,
            input_prefix  => $input_prefix,
            begin_state   => $states,
            layout        => $layout,
            merge_outputs => ($i < $num_cells-1) ? undef : $merge_outputs
        );
        push @next_states, $states;
    }, $self->_cells);
    return ($inputs, [map { @{ $_ } } @next_states]);
}

package AI::MXNet::RNN::BidirectionalCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::BidirectionalCell
=cut

=head1 DESCRIPTION

    Bidirectional RNN cell

    Parameters
    ----------
    l_cell : AI::MXNet::RNN::Cell::Base
        cell for forward unrolling
    r_cell : AI::MXNet::RNN::Cell::Base
        cell for backward unrolling
    output_prefix : str, default 'bi_'
        prefix for name of output
=cut

has 'l_cell'         => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
has 'r_cell'         => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);
has '_output_prefix' => (is => 'ro', init_arg => 'output_prefix', isa => 'Str', default => 'bi_');
has [qw/_override_cell_params _cells/] => (is => 'rw', init_arg => undef);

# Positional shorthand: BidirectionalCell->new($l_cell, $r_cell, %kwargs).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    if(@_ >= 2 and blessed $_[0] and blessed $_[1])
    {
        my $l_cell = shift(@_);
        my $r_cell = shift(@_);
        return $class->$orig(
            l_cell => $l_cell,
            r_cell => $r_cell,
            @_
        );
    }
    return $class->$orig(@_);
};

sub BUILD
{
    my ($self, $original_arguments) = @_;
    $self->_override_cell_params(defined $original_arguments->{params});
    if($self->_override_cell_params)
    {
        assert(
            ($self->l_cell->_own_params and $self->r_cell->_own_params),
            "Either specify params for BidirectionalCell ".
            "or child cells, not both."
);
    # If a shared params container was supplied, push its variables down into
    # both directional sub-cells so that weights are shared between them.
    %{ $self->l_cell->params->_params } = (%{ $self->l_cell->params->_params }, %{ $self->params->_params });
    %{ $self->r_cell->params->_params } = (%{ $self->r_cell->params->_params }, %{ $self->params->_params });
    }
    # Collect each sub-cell's parameters back into this wrapper's container so
    # both directions are reachable via $self->params.
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $self->l_cell->params->_params });
    %{ $self->params->_params } = (%{ $self->params->_params }, %{ $self->r_cell->params->_params });
    # Order matters: index 0 = forward (left) cell, index 1 = backward (right) cell.
    $self->_cells([$self->l_cell, $self->r_cell]);
}

# Delegate NDArray weight unpacking to both directional sub-cells.
method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->_cells_unpack_weights($self->_cells, $args)
}

# Delegate NDArray weight packing to both directional sub-cells.
method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->_cells_pack_weights($self->_cells, $args);
}

# A bidirectional cell needs the whole sequence up front (the backward pass
# consumes it reversed), so single-step calling is deliberately unsupported.
method call($inputs, $states)
{
    confess("Bidirectional cannot be stepped. Please use unroll");
}

# Concatenated state info of the forward cell followed by the backward cell.
method state_info()
{
    return $self->_cells_state_info($self->_cells);
}

# Initial states for both sub-cells; forbidden once a modifier cell owns us.
method begin_state(@kwargs)
{
    assert((not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    return $self->_cells_begin_state($self->_cells, @kwargs);
}

# Unroll both directions over $length steps and concatenate their outputs.
# The backward cell sees the inputs reversed, and its outputs are reversed
# back before concatenation so time steps line up with the forward cell.
method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    # Position of the time axis in the layout string (e.g. 1 for 'NTC').
    my $axis = index($layout, 'T');
    if(not defined $inputs)
    {
        # No inputs given: create one placeholder variable per time step.
        $inputs = [
            map { AI::MXNet::Symbol->Variable("${input_prefix}t${_}_data") } (0..$length-1)
        ];
    }
    elsif(blessed($inputs))
    {
        # A single symbol holds the whole sequence; slice it along the time axis.
        assert(
            (@{ $inputs->list_outputs() } == 1),
            "unroll doesn't allow grouped symbol as input. Please "
            ."convert to list first or let unroll handle slicing"
        );
        $inputs = [ @{ AI::MXNet::Symbol->SliceChannel(
            $inputs,
            axis         => $axis,
            num_outputs  => $length,
            squeeze_axis => 1
        ) }];
    }
    else
    {
        assert(@$inputs == $length);
    }
    $begin_state //= $self->begin_state;
    my $states = $begin_state;
    my ($l_cell, $r_cell) = @{ $self->_cells };
    # The combined begin_state array is split by each sub-cell's state count:
    # the first @{l_cell->state_info} entries belong to the forward cell.
    my ($l_outputs, $l_states) = $l_cell->unroll(
        $length, inputs => $inputs,
        begin_state     => [@{$states}[0..@{$l_cell->state_info}-1]],
        layout          => $layout,
        merge_outputs   => $merge_outputs
    );
    # Backward cell: feed the input sequence reversed in time.
    my ($r_outputs, $r_states) = $r_cell->unroll(
        $length, inputs => [reverse @{$inputs}],
        begin_state     => [@{$states}[@{$l_cell->state_info}..@{$states}-1]],
        layout          => $layout,
        merge_outputs   => $merge_outputs
    );
    if(not defined $merge_outputs)
    {
        # Tri-state merge_outputs: when unspecified, merge only if BOTH
        # sub-cells already produced a single merged symbol; otherwise force
        # both sides into per-step lists so they can be zipped below.
        $merge_outputs = (
            blessed $l_outputs and $l_outputs->isa('AI::MXNet::Symbol')
                and blessed $r_outputs and $r_outputs->isa('AI::MXNet::Symbol')
        );
        if(not $merge_outputs)
        {
            if(blessed $l_outputs and $l_outputs->isa('AI::MXNet::Symbol'))
            {
                $l_outputs = [
                    @{ AI::MXNet::Symbol->SliceChannel(
                        $l_outputs, axis => $axis,
                        num_outputs  => $length,
                        squeeze_axis => 1
                    ) }
                ];
            }
            if(blessed $r_outputs and $r_outputs->isa('AI::MXNet::Symbol'))
            {
                $r_outputs = [
                    @{ AI::MXNet::Symbol->SliceChannel(
                        $r_outputs, axis => $axis,
                        num_outputs  => $length,
                        squeeze_axis => 1
                    ) }
                ];
            }
        }
    }
    if($merge_outputs)
    {
        # Merged path: reverse the backward outputs along the time axis
        # symbolically so both sequences run forward in time.
        $l_outputs = [@{ $l_outputs }];
        $r_outputs = [@{ AI::MXNet::Symbol->reverse(blessed $r_outputs ? $r_outputs : @{ $r_outputs }, axis=>$axis) }];
    }
    else
    {
        # Per-step path: plain Perl list reversal restores time order.
        $r_outputs = [reverse(@{ $r_outputs })];
    }
    my $outputs = [];
    # Concatenate forward/backward outputs feature-wise at each step.
    # dim is 2 when merged (batch, time, feature) and 1 when per-step
    # (batch, feature) — hence the 1+($merge_outputs?1:0) arithmetic.
    zip(sub {
        my ($i, $l_o, $r_o) = @_;
        push @$outputs, AI::MXNet::Symbol->Concat(
            $l_o, $r_o, dim=>(1+($merge_outputs?1:0)),
            name => $merge_outputs
                        ? sprintf('%sout', $self->_output_prefix)
                        : sprintf('%st%d', $self->_output_prefix, $i)
        );
    }, [0..@{ $l_outputs }-1], [@{ $l_outputs }], [@{ $r_outputs }]);
    if($merge_outputs)
    {
        # Merged case produced exactly one Concat symbol; unwrap it.
        $outputs = @{ $outputs }[0];
    }
    $states = [$l_states, $r_states];
    return($outputs, $states);
}

package AI::MXNet::RNN::ModifierCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::Cell::Base';

=head1 NAME

    AI::MXNet::RNN::ModifierCell
=cut

=head1 DESCRIPTION

    Base class for modifier cells. A modifier
    cell takes a base cell, apply modifications
    on it (e.g. Dropout), and returns a new cell.

    After applying modifiers the base cell should
    no longer be called directly. The modifer cell
    should be used instead.
=cut

has 'base_cell' => (is => 'ro', isa => 'AI::MXNet::RNN::Cell::Base', required => 1);

# Allow positional construction: ModifierCell->new($base_cell, %other_args).
around BUILDARGS => sub {
    my $orig  = shift;
    my $class = shift;
    if(@_%2)
    {
        my $base_cell = shift;
        return $class->$orig(base_cell => $base_cell, @_);
    }
    return $class->$orig(@_);
};

sub BUILD
{
    my $self = shift;
    # Mark the wrapped cell so direct calls on it raise an assertion.
    $self->base_cell->_modified(1);
}

# A modifier owns no parameters of its own; expose the base cell's.
method params()
{
    $self->_own_params(0);
    return $self->base_cell->params;
}

method state_info()
{
    return $self->base_cell->state_info;
}

# Temporarily lift the _modified guard so the base cell can build its
# initial state, then restore it.
method begin_state(CodeRef :$init_sym=AI::MXNet::Symbol->can('zeros'), @kwargs)
{
    assert(
        (not $self->_modified),
        "After applying modifier cells (e.g. DropoutCell) the base "
        ."cell cannot be called directly. Call the modifier cell instead."
    );
    $self->base_cell->_modified(0);
    my $begin_state = $self->base_cell->begin_state(func => $init_sym, @kwargs);
    $self->base_cell->_modified(1);
    return $begin_state;
}

method unpack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->base_cell->unpack_weights($args)
}

method pack_weights(HashRef[AI::MXNet::NDArray] $args)
{
    return $self->base_cell->pack_weights($args)
}

# Subclasses must implement the actual per-step modification.
method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    confess("Not Implemented");
}

package AI::MXNet::RNN::DropoutCell;
use Mouse;
extends 'AI::MXNet::RNN::ModifierCell';
# Dropout probabilities for the cell's output and its states respectively;
# 0 disables the corresponding Dropout node entirely.
has [qw/dropout_outputs dropout_states/] => (is => 'ro', isa => 'Num', default => 0);

=head1 NAME

    AI::MXNet::RNN::DropoutCell
=cut

=head1 DESCRIPTION

    Apply the dropout on base cell
=cut

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    # Run the wrapped cell first, then wrap its outputs/states in Dropout.
    my ($output, $states) = &{$self->base_cell}($inputs, $states);
    if($self->dropout_outputs > 0)
    {
        $output = AI::MXNet::Symbol->Dropout(data => $output, p => $self->dropout_outputs);
    }
    if($self->dropout_states > 0)
    {
        $states = [map { AI::MXNet::Symbol->Dropout(data => $_, p => $self->dropout_states) } @{ $states }];
    }
    return ($output, $states);
}

package AI::MXNet::RNN::ZoneoutCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';
# Zoneout probabilities for outputs and states (0 = disabled).
has [qw/zoneout_outputs zoneout_states/] => (is => 'ro', isa => 'Num', default => 0);
# Previous step's output, kept so zoneout can reuse it in place of the new one.
has 'prev_output' => (is => 'rw', init_arg => undef);

=head1 NAME

    AI::MXNet::RNN::ZoneoutCell
=cut

=head1 DESCRIPTION

    Apply Zoneout on base cell.
=cut

sub BUILD
{
    my $self = shift;
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::FusedCell')),
        "FusedRNNCell doesn't support zoneout. ".
        "Please unfuse first."
    );
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::BidirectionalCell')),
        "BidirectionalCell doesn't support zoneout since it doesn't support step. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
    # NOTE(review): $self->_bidirectional is called on the ZoneoutCell itself,
    # but no such attribute is declared on ModifierCell in this file — this
    # presumably intends $self->base_cell->_bidirectional; verify, as the
    # second operand would die at runtime whenever base_cell is a
    # SequentialCell and the first operand is false.
    assert(
        (not $self->base_cell->isa('AI::MXNet::RNN::SequentialCell') or not $self->_bidirectional),
        "Bidirectional SequentialCell doesn't support zoneout. ".
        "Please add ZoneoutCell to the cells underneath instead."
    );
}

method reset()
{
    $self->SUPER::reset;
    # Forget the cached output so a new graph starts clean.
    $self->prev_output(undef);
}

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my ($cell, $p_outputs, $p_states) = ($self->base_cell, $self->zoneout_outputs, $self->zoneout_states);
    my ($next_output, $next_states) = &{$cell}($inputs, $states);
    # Zoneout mask: Dropout applied to a ones-like tensor yields a 0/scaled
    # random mask with the same shape as $like.
    my $mask = sub {
        my ($p, $like) = @_;
        AI::MXNet::Symbol->Dropout(
            AI::MXNet::Symbol->ones_like(
                $like
            ),
            p => $p
        );
    };
    # On the first step there is no previous output; use a zeros placeholder
    # (shape [0,0] lets the backend infer the real shape).
    my $prev_output = $self->prev_output || AI::MXNet::Symbol->zeros(shape => [0, 0]);
    # Where the mask fires keep the fresh output, otherwise reuse the old one.
    my $output = $p_outputs != 0
        ? AI::MXNet::Symbol->where(
            &{$mask}($p_outputs, $next_output),
            $next_output,
            $prev_output
        )
        : $next_output;
    my @states;
    if($p_states != 0)
    {
        # Same zoneout treatment element-wise for each state tensor.
        zip(sub {
            my ($new_s, $old_s) = @_;
            push @states, AI::MXNet::Symbol->where(
                &{$mask}($p_states, $new_s),
                $new_s,
                $old_s
            );
        }, $next_states, $states);
    }
    $self->prev_output($output);
    return ($output, @states ? \@states : $next_states);
}

package AI::MXNet::RNN::ResidualCell;
use Mouse;
use AI::MXNet::Base;
extends 'AI::MXNet::RNN::ModifierCell';

=head1 NAME

    AI::MXNet::RNN::ResidualCell
=cut

=head1 DESCRIPTION

    Adds residual connection as described in Wu et al, 2016
    (https://arxiv.org/abs/1609.08144).
    Output of the cell is output of the base cell plus input.
=cut

method call(AI::MXNet::Symbol $inputs, SymbolOrArrayOfSymbols $states)
{
    my $output;
    # Residual connection: output = base_cell(input) + input.
    ($output, $states) = &{$self->base_cell}($inputs, $states);
    $output = AI::MXNet::Symbol->elemwise_add($output, $inputs, name => $output->name.'_plus_residual');
    return ($output, $states)
}

method unroll(
    Int $length,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$inputs=,
    Maybe[AI::MXNet::Symbol|ArrayRef[AI::MXNet::Symbol]] :$begin_state=,
    Str                                                  :$input_prefix='',
    Str                                                  :$layout='NTC',
    Maybe[Bool]                                          :$merge_outputs=
)
{
    $self->reset;
    # Lift the _modified guard so we may unroll the wrapped cell directly.
    $self->base_cell->_modified(0);
    my ($outputs, $states) = $self->base_cell->unroll($length, inputs=>$inputs, begin_state=>$begin_state,
                                                      layout=>$layout, merge_outputs=>$merge_outputs);
    $self->base_cell->_modified(1);
    # Mirror whatever shape (merged symbol vs per-step list) the base cell produced.
    $merge_outputs //= (blessed($outputs) and $outputs->isa('AI::MXNet::Symbol'));
    # Bring inputs into the same merged/per-step shape as the outputs.
    ($inputs) = _normalize_sequence($length, $inputs, $layout, $merge_outputs);
    if($merge_outputs)
    {
        $outputs = AI::MXNet::Symbol->elemwise_add($outputs, $inputs, name => $outputs->name . "_plus_residual");
    }
    else
    {
        # Per-step case: add input to output at each time step.
        my @temp;
        zip(sub {
            my ($output_sym, $input_sym) = @_;
            push @temp, AI::MXNet::Symbol->elemwise_add($output_sym, $input_sym,
                            name=>$output_sym->name."_plus_residual");
        }, [@{ $outputs }], [@{ $inputs }]);
        $outputs = \@temp;
    }
    return ($outputs, $states);
}

# Normalize a sequence input to either a list of per-step symbols
# ($merge false) or a single time-axis-merged symbol ($merge true),
# optionally converting from $in_layout to $layout by swapping axes.
# Returns ($inputs, $axis) where $axis is the time axis of $layout.
func _normalize_sequence($length, $inputs, $layout, $merge, $in_layout=)
{
    assert((defined $inputs),
        "unroll(inputs=>undef) has been deprecated. "
        ."Please create input variables outside unroll."
    );
    my $axis = index($layout, 'T');
    my $in_axis = defined $in_layout ? index($in_layout, 'T') : $axis;
    if(blessed($inputs))
    {
        if(not $merge)
        {
            # Single merged symbol but a per-step list is wanted: split it.
            assert(
                (@{ $inputs->list_outputs() } == 1),
                "unroll doesn't allow grouped symbol as input. Please "
                ."convert to list first or let unroll handle splitting"
            );
            $inputs = [ @{ AI::MXNet::Symbol->split(
                $inputs,
                axis         => $in_axis,
                num_outputs  => $length,
                squeeze_axis => 1
            ) }];
        }
    }
    else
    {
        assert(not defined $length or @$inputs == $length);
        if($merge)
        {
            # Per-step list but a merged symbol is wanted: re-add the time
            # dimension to each step and concatenate along it.
            $inputs = [map { AI::MXNet::Symbol->expand_dims($_, axis=>$axis) } @{ $inputs }];
            $inputs = AI::MXNet::Symbol->Concat(@{ $inputs }, dim=>$axis);
            $in_axis = $axis;
        }
    }
    # Convert between layouts (e.g. TNC <-> NTC) when time axes differ.
    if(blessed($inputs) and $axis != $in_axis)
    {
        $inputs = AI::MXNet::Symbol->swapaxes($inputs, dim0=>$axis, dim1=>$in_axis);
    }
    return ($inputs, $axis);
}

1;