perl-package/AI-MXNet/lib/AI/MXNet/RNN/IO.pm (206 lines of code) (raw):
package AI::MXNet::RNN::IO;
use strict;
use warnings;
use AI::MXNet::Base;
use AI::MXNet::Function::Parameters;
=encoding UTF-8
=head1 NAME
AI::MXNet::RNN::IO - Data preparation helpers for recurrent neural networks.
=cut
=head1 DESCRIPTION
Helpers for preparing text data for recurrent neural networks, such as encoding tokenized sentences as integer indices.
=cut
=head2 encode_sentences
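Encode sentences as integer indices, building the
token-to-index mapping when none is supplied. When
$vocab is undef, a new vocabulary is created and every
unseen token is added to it; when $vocab is supplied,
an unseen token is a fatal error.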
Parameters
----------
$sentences : array ref of array refs of str
An array ref of sentences to encode. Each sentence
should be an array ref of string tokens.
:$vocab : undef or hash ref of str -> int
Optional input vocabulary.
:$invalid_label : int, default -1
Index for invalid token, like <end-of-sentence>
:$invalid_key : str, default '\n'
Key for invalid token. Uses '\n' for end
of sentence by default.
:$start_label : int, default 0
Lowest index assigned to a new token; the index equal
to $invalid_label is skipped.
Returns
-------
$result : array ref of array refs of int
encoded sentences
$vocab : hash ref of str -> int
result vocabulary
=cut
method encode_sentences(
    ArrayRef[ArrayRef] $sentences,
    Maybe[HashRef]     :$vocab=,
    Int                :$invalid_label=-1,
    Str                :$invalid_key="\n",
    Int                :$start_label=0
)
{
    # Only a vocabulary created here may grow; a caller-supplied one is
    # treated as closed, so an unseen token trips the assertion below.
    my $growable = defined $vocab ? 0 : 1;
    $vocab = { $invalid_key => $invalid_label } if $growable;

    my $next_id = $start_label;
    # Map one token to its index, registering it first if it is new.
    my $code_of = sub {
        my ($token) = @_;
        unless(exists $vocab->{ $token })
        {
            assert($growable, "Unknown token: $token");
            # Never hand out the index reserved for the invalid key.
            $next_id += 1 if $next_id == $invalid_label;
            $vocab->{ $token } = $next_id;
            $next_id += 1;
        }
        return $vocab->{ $token };
    };

    my @encoded;
    for my $sentence (@{ $sentences })
    {
        push @encoded, [ map { $code_of->($_) } @{ $sentence } ];
    }
    return (\@encoded, $vocab);
}
package AI::MXNet::BucketSentenceIter;
=encoding UTF-8
=head1 NAME
AI::MXNet::BucketSentenceIter
=cut
=head1 DESCRIPTION
Simple bucketing iterator for language models.
The label for each time step is the data of the
next step; the last step is labelled with invalid_label.
=cut
=head2 new
Parameters
----------
sentences : array ref of array refs of int
encoded sentences
batch_size : int
batch_size of data
invalid_label : int, default -1
label used to pad sentences shorter than their bucket and as
the label of the last step, e.g. <end-of-sentence>
dtype : str, default 'float32'
data type
buckets : array ref of int
sentence lengths of the data buckets. If undef, every sentence
length that occurs at least batch_size times becomes a bucket.
data_name : str, default 'data'
name of data
label_name : str, default 'softmax_label'
name of label
layout : str, default 'NT'
format of data and label. 'NT' means (batch_size, length)
and 'TN' means (length, batch_size).
=cut
use Mouse;
use AI::MXNet::Base;
use List::Util qw(shuffle max);
extends 'AI::MXNet::DataIter';
# Arguments accepted by new().
has 'sentences'     => (is => 'ro', isa => 'ArrayRef[ArrayRef]', required => 1);
has '+batch_size'   => (is => 'ro', isa => 'Int',                required => 1);
has 'invalid_label' => (is => 'ro', isa => 'Int',   default => -1);
has 'data_name'     => (is => 'ro', isa => 'Str',   default => 'data');
has 'label_name'    => (is => 'ro', isa => 'Str',   default => 'softmax_label');
has 'dtype'         => (is => 'ro', isa => 'Dtype', default => 'float32');
has 'layout'        => (is => 'ro', isa => 'Str',   default => 'NT');
has 'buckets'       => (is => 'rw', isa => 'Maybe[ArrayRef[Int]]');
# Derived state filled in by BUILD / reset / next; init_arg => undef
# keeps these out of the constructor's accepted arguments.
has $_ => (is => 'rw', init_arg => undef) for qw(
    data nddata ndlabel
    major_axis default_bucket_key
    provide_data provide_label
    idx curr_idx
);
# Mouse post-construction hook: validates arguments, derives buckets when
# none were given, pads every sentence into its bucket, builds the
# provide_data/provide_label descriptors and the list of batch offsets,
# then calls reset to shuffle and materialise the NDArrays.
sub BUILD
{
    my $self = shift;

    # A non-positive batch size makes the batch-offset computation below
    # meaningless (or non-terminating), so reject it up front.
    confess("batch_size must be a positive integer, got ${\ $self->batch_size }")
        unless $self->batch_size > 0;

    # Position of 'N' in the layout is the batch axis:
    # 0 for 'NT' (batch major), 1 for 'TN' (time major).
    my $major_axis = index($self->layout, 'N');
    if($major_axis != 0 && $major_axis != 1)
    {
        confess("Invalid layout ${\ $self->layout }: Must be NT (batch major) or TN (time major)");
    }
    $self->major_axis($major_axis);

    if(not defined $self->buckets)
    {
        # Default buckets: every sentence length that occurs at least
        # batch_size times.
        my @buckets;
        if(@{ $self->sentences })
        {
            my $p = pdl([map { scalar(@$_) } @{ $self->sentences }]);
            enumerate(sub {
                my ($len, $count) = @_;
                push @buckets, $len if $count >= $self->batch_size;
            }, $p->histogram(1, 0, $p->max + 1)->unpdl);
        }
        $self->buckets(\@buckets);
    }
    # Sort into a fresh array so an array ref supplied by the caller is not
    # modified behind their back.
    $self->buckets([sort { $a <=> $b } @{ $self->buckets }]);
    # Without at least one bucket there is no default_bucket_key and no
    # data to iterate; fail here with a clear message instead of later.
    confess("No buckets: pass 'buckets' explicitly or supply at least batch_size sentences of a common length")
        unless @{ $self->buckets };

    # Pad each sentence with invalid_label up to the smallest bucket that
    # can hold it; sentences longer than the largest bucket are dropped.
    my $ndiscard = 0;
    my @rows_of  = map { [] } @{ $self->buckets };
    for my $sent (@{ $self->sentences })
    {
        my $len  = @{ $sent };
        my $buck = bisect_left($self->buckets, $len);
        if($buck == @{ $self->buckets })
        {
            $ndiscard += 1;
            next;
        }
        my $buff = AI::MXNet::NDArray->full(
            [$self->buckets->[$buck]],
            $self->invalid_label,
            dtype => $self->dtype
        )->aspdl;
        # Skip the copy for an empty sentence: slice([0, -1]) would address
        # the whole buffer rather than nothing.
        $buff->slice([0, $len - 1]) .= pdl($sent) if $len;
        push @{ $rows_of[$buck] }, $buff;
    }
    $self->data([map { pdl(PDL::Type->new(DTYPE_MX_TO_PDL->{$self->dtype}), $_) } @rows_of]);
    AI::MXNet::Logging->warning("discarded $ndiscard sentences longer than the largest bucket.")
        if $ndiscard;

    $self->nddata([]);
    $self->ndlabel([]);
    $self->default_bucket_key(max(@{ $self->buckets }));

    my $shape = $major_axis == 0
        ? [$self->batch_size, $self->default_bucket_key]
        : [$self->default_bucket_key, $self->batch_size];
    $self->provide_data([
        AI::MXNet::DataDesc->new(
            name   => $self->data_name,
            shape  => $shape,
            dtype  => $self->dtype,
            layout => $self->layout
        )
    ]);
    $self->provide_label([
        AI::MXNet::DataDesc->new(
            name   => $self->label_name,
            shape  => $shape,
            dtype  => $self->dtype,
            layout => $self->layout
        )
    ]);

    # One [bucket_index, row_offset] entry per full batch; trailing rows
    # that cannot fill a whole batch are not used.
    $self->idx([]);
    enumerate(sub {
        my ($i, $buck) = @_;
        my $nrows = $buck->shape->at(-1);
        for(my $j = 0; $j + $self->batch_size <= $nrows; $j += $self->batch_size)
        {
            push @{ $self->idx }, [$i, $j];
        }
    }, $self->data);

    $self->curr_idx(0);
    $self->reset;
}
# Start a new epoch: rewind the cursor, shuffle the batch order and the
# rows inside every bucket, and rebuild the data/label NDArrays.
method reset()
{
    $self->curr_idx(0);
    @{ $self->idx } = shuffle(@{ $self->idx });

    my (@data_nd, @label_nd);
    my $buckets = $self->data;
    for my $k (0 .. $#{ $buckets })
    {
        # Keep the shuffled rows so later epochs start from this order.
        my $rows = $buckets->[$k] = pdl_shuffle($buckets->[$k]);
        # Next-step targets: label[t] = data[t + 1]; the final step gets
        # invalid_label.
        my $targets = $rows->zeros;
        $targets->slice([0, -2], 'X') .= $rows->slice([1, -1], 'X');
        $targets->slice([-1, -1], 'X') .= $self->invalid_label;
        push @data_nd,  AI::MXNet::NDArray->array($rows,    dtype => $self->dtype);
        push @label_nd, AI::MXNet::NDArray->array($targets, dtype => $self->dtype);
    }
    $self->nddata(\@data_nd);
    $self->ndlabel(\@label_nd);
}
# Return the next AI::MXNet::DataBatch of the epoch, or undef once every
# [bucket, offset] pair in idx has been consumed.
method next()
{
    my $cursor = $self->curr_idx;
    return undef if($cursor == @{ $self->idx });
    my ($bucket, $offset) = @{ $self->idx->[$cursor] };
    $self->curr_idx($cursor + 1);

    my $first = $offset;
    my $last  = $offset + $self->batch_size - 1;
    my $data  = $self->nddata->[$bucket]->slice([$first, $last]);
    my $label = $self->ndlabel->[$bucket]->slice([$first, $last]);
    # Buckets are stored batch-major; transpose for time-major layout.
    if($self->major_axis == 1)
    {
        $data  = $data->T;
        $label = $label->T;
    }

    # Describe one array of this batch with the iterator's dtype/layout.
    my $describe = sub {
        my ($name, $array) = @_;
        return AI::MXNet::DataDesc->new(
            name   => $name,
            shape  => $array->shape,
            dtype  => $self->dtype,
            layout => $self->layout
        );
    };

    return AI::MXNet::DataBatch->new(
        data          => [$data],
        label         => [$label],
        bucket_key    => $self->buckets->[$bucket],
        pad           => 0,
        provide_data  => [ $describe->($self->data_name,  $data)  ],
        provide_label => [ $describe->($self->label_name, $label) ],
    );
}
1;