def train()

in code/scripts/bert_mt.py [0:0]

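The function body relies on module-level imports and globals defined elsewhere in bert_mt.py (mx, gluon, nlp, time, os, log, ctx, model_dir, PAD, label2idx, intent2idx) as well as the script's own helpers (BERTForICSL, ICSLLoss, icsl_transform, split_and_load, process_seq_labels). A minimal sketch of that assumed setup, with placeholder values where the script's actual definitions are not shown here:

import logging
import os
import time

import mxnet as mx
from mxnet import gluon
import gluonnlp as nlp

log = logging.getLogger(__name__)

# Assumed device and output settings; the script's actual values may differ.
ctx = [mx.gpu(0)] if mx.context.num_gpus() > 0 else [mx.cpu()]
model_dir = 'output'

# Assumed padding token (BERT vocabulary convention) and label maps; label2idx
# and intent2idx are built from the data elsewhere in the script.
PAD = '[PAD]'
label2idx = {PAD: 0}   # placeholder
intent2idx = {}        # placeholder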

def train(model_name, train_input):
    """Training function."""
    ## Hyperparameters
    log_interval = 100   # batches between progress logs
    batch_size = 32
    lr = 1e-5            # learning rate for fine-tuning
    optimizer = 'adam'
    accumulate = None    # gradient accumulation steps (None = update after every batch)
    epochs = 20

    ## Load BERT model and vocabulary
    bert, vocabulary = nlp.model.get_model('bert_12_768_12',
                                           dataset_name='wiki_multilingual_uncased',
                                           pretrained=True,
                                           ctx=ctx,
                                           use_pooler=False,
                                           use_decoder=False,
                                           use_classifier=False)

    model = BERTForICSL(bert, num_slot_labels=len(label2idx), num_intents=len(intent2idx))
    model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
    model.hybridize(static_alloc=True)

    icsl_loss_function = ICSLLoss()
    icsl_loss_function.hybridize(static_alloc=True)

    ic_metric = mx.metric.Accuracy()
    sl_metric = mx.metric.Accuracy()

    ## Load labeled data
    field_separator = nlp.data.Splitter('\t')
    # fields to select from the file: utterance, slot labels, intent, uid
    field_indices = [1, 3, 4, 0]
    train_data = nlp.data.TSVDataset(filename=train_input,
                                     field_separator=field_separator,
                                     num_discard_samples=1,
                                     field_indices=field_indices)

    # use the vocabulary from the pre-trained model for tokenization
    bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=True)
    train_data_transform = train_data.transform(fn=lambda x: icsl_transform(x, vocabulary, label2idx, intent2idx, bert_tokenizer)[0])
    # create data loader
    pad_token_id = vocabulary[PAD]
    pad_label_id = label2idx[PAD]
    batchify_fn = nlp.data.batchify.Tuple(
        nlp.data.batchify.Stack(),
        nlp.data.batchify.Pad(axis=0, pad_val=pad_token_id),
        nlp.data.batchify.Pad(axis=0, pad_val=pad_label_id),
        nlp.data.batchify.Stack('float32'),
        nlp.data.batchify.Stack('float32'))
    train_sampler = nlp.data.FixedBucketSampler(lengths=[len(item[1]) for item in train_data_transform],
                                                batch_size=batch_size,
                                                shuffle=True)
    train_dataloader = mx.gluon.data.DataLoader(train_data_transform,
                                                batchify_fn=batchify_fn,
                                                batch_sampler=train_sampler)
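    # Each batch follows the transformed example layout: a leading field that the
    # training loop ignores, padded token ids, padded slot-label ids, the intent
    # label, and the valid length.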

    optimizer_params = {'learning_rate': lr}
    trainer = gluon.Trainer(model.collect_params(), optimizer,
                            optimizer_params, update_on_kvstore=False)

    # Collect differentiable parameters
    params = [p for p in model.collect_params().values() if p.grad_req != 'null']
    # Set grad_req if gradient accumulation is required
    if accumulate:
        for p in params:
            p.grad_req = 'add'
    # Freeze the BERT embedding parameters
    for name, param in model.collect_params().items():
        if 'embed' in name:
            param.grad_req = 'null'

    epoch_tic = time.time()
    total_num = 0
    log_num = 0
    for epoch_id in range(epochs):
        step_loss = 0
        tic = time.time()
        # train on labeled data
        for batch_id, data in enumerate(train_dataloader):
            # forward and backward
            with mx.autograd.record():
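                # if the (last) batch has fewer samples than devices, keep it on a single context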
                if data[0].shape[0] < len(ctx):
                    data = split_and_load(data, [ctx[0]])
                else:
                    data = split_and_load(data, ctx)
                for chunk in data:
                    _, token_ids, slot_label, intent_label, valid_length = chunk

                    log_num += len(token_ids)
                    total_num += len(token_ids)

                    # forward computation
                    intent_pred, slot_pred = model(token_ids, valid_length)
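                    # valid_length - 2 presumably discounts the [CLS]/[SEP] tokens added by the BERT transform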
                    ls = icsl_loss_function(intent_pred, slot_pred, intent_label, slot_label, valid_length - 2).mean()

                    if accumulate:
                        ls = ls / accumulate
                    ls.backward()
                    step_loss += ls.asscalar()

            # update
            if not accumulate or (batch_id + 1) % accumulate == 0:
                trainer.allreduce_grads()
                nlp.utils.clip_grad_global_norm(params, 1)
                trainer.update(1, ignore_stale_grad=True)

            if (batch_id + 1) % log_interval == 0:
                toc = time.time()
                # update metrics
                ic_metric.update([intent_label], [intent_pred])
                sl_metric.update(*process_seq_labels(slot_label, slot_pred, ignore_id=pad_label_id))
                log.info('Epoch: {}, Batch: {}/{}, speed: {:.2f} samples/s, lr={:.7f}, loss={:.4f}, intent acc={:.3f}, slot acc={:.3f}'
                         .format(epoch_id,
                                 batch_id,
                                 len(train_dataloader),
                                 log_num / (toc - tic),
                                 trainer.learning_rate,
                                 step_loss / log_interval,
                                 ic_metric.get()[1],
                                 sl_metric.get()[1]))
                tic = time.time()
                step_loss = 0
                log_num = 0

        mx.nd.waitall()
        epoch_toc = time.time()
        log.info('Time cost: {:.2f} s, Speed: {:.2f} samples/s'
                 .format(epoch_toc - epoch_tic, total_num/(epoch_toc - epoch_tic)))
        model.save_parameters(os.path.join(model_dir, model_name + '.params'))
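
For reference, a hypothetical invocation (the actual script presumably wires model_name and train_input through its command-line handling; the values below are made up):

if __name__ == '__main__':
    # Hypothetical call; argument handling in bert_mt.py may differ.
    train(model_name='bert_icsl', train_input='train.tsv')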