example/torch/torch_module.py (26 lines of code) (raw):

# pylint: skip-file from data import mnist_iterator import mxnet as mx import numpy as np import logging # define mlp use_torch_criterion = False data = mx.symbol.Variable('data') fc1 = mx.symbol.TorchModule(data_0=data, lua_string='nn.Linear(784, 128)', num_data=1, num_params=2, num_outputs=1, name='fc1') act1 = mx.symbol.TorchModule(data_0=fc1, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu1') fc2 = mx.symbol.TorchModule(data_0=act1, lua_string='nn.Linear(128, 64)', num_data=1, num_params=2, num_outputs=1, name='fc2') act2 = mx.symbol.TorchModule(data_0=fc2, lua_string='nn.ReLU(false)', num_data=1, num_params=0, num_outputs=1, name='relu2') fc3 = mx.symbol.TorchModule(data_0=act2, lua_string='nn.Linear(64, 10)', num_data=1, num_params=2, num_outputs=1, name='fc3') if use_torch_criterion: logsoftmax = mx.symbol.TorchModule(data_0=fc3, lua_string='nn.LogSoftMax()', num_data=1, num_params=0, num_outputs=1, name='logsoftmax') # Torch's label starts from 1 label = mx.symbol.Variable('softmax_label') + 1 mlp = mx.symbol.TorchCriterion(data=logsoftmax, label=label, lua_string='nn.ClassNLLCriterion()', name='softmax') else: mlp = mx.symbol.SoftmaxOutput(data=fc3, name='softmax') # data train, val = mnist_iterator(batch_size=100, input_shape = (784,)) # train logging.basicConfig(level=logging.DEBUG) model = mx.model.FeedForward( ctx = mx.cpu(0), symbol = mlp, num_epoch = 20, learning_rate = 0.1, momentum = 0.9, wd = 0.00001) if use_torch_criterion: model.fit(X=train, eval_data=val, eval_metric=mx.metric.Torch()) else: model.fit(X=train, eval_data=val)