def fairseq_train()

in access/fairseq/base.py


import os
import random
import shutil
from pathlib import Path

from fairseq import options
from fairseq_cli import train  # training entry point; its location (fairseq_cli.train vs. a top-level train.py) depends on the pinned fairseq version

from access.utils.helpers import log_stdout  # project helper that tees stdout to a log file (module path assumed)


def fairseq_train(
        preprocessed_dir,
        exp_dir,
        ngpus=None,
        max_tokens=2000,
        arch='fconv_iwslt_de_en',
        pretrained_emb_path=None,
        embeddings_dim=None,
        # Transformer-specific settings (the decoder mirrors the encoder for now)
        encoder_embed_dim=512,
        encoder_layers=6,
        encoder_attention_heads=8,
        # encoder_decoder_dim_ratio=1,
        # share_embeddings=True,
        max_epoch=50,
        warmup_updates=None,
        lr=0.1,
        min_lr=1e-9,
        dropout=0.2,
        label_smoothing=0.1,
        lr_scheduler='fixed',
        weight_decay=0.0001,
        criterion='label_smoothed_cross_entropy',
        optimizer='nag',
        validations_before_sari_early_stopping=10,
        fp16=False):
    exp_dir = Path(exp_dir)
    with log_stdout(exp_dir / 'fairseq_train.stdout'):
        preprocessed_dir = Path(preprocessed_dir)
        exp_dir.mkdir(exist_ok=True, parents=True)
        # Copy dictionaries to exp_dir for generation
        shutil.copy(preprocessed_dir / 'dict.complex.txt', exp_dir)
        shutil.copy(preprocessed_dir / 'dict.simple.txt', exp_dir)
        train_parser = options.get_training_parser()
        # if share_embeddings:
        #     assert encoder_decoder_dim_ratio == 1
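        # Build a fairseq CLI-style argument list; every value is converted to str before parsing.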
        args = [
            '--task',
            'translation',
            preprocessed_dir,
            '--raw-text',
            '--source-lang',
            'complex',
            '--target-lang',
            'simple',
            '--save-dir',
            os.path.join(exp_dir, 'checkpoints'),
            '--clip-norm',
            0.1,
            '--criterion',
            criterion,
            '--no-epoch-checkpoints',
            '--save-interval-updates',
            5000,  # Validate every n updates
            '--validations-before-sari-early-stopping',
            validations_before_sari_early_stopping,
            '--arch',
            arch,

            # '--decoder-out-embed-dim', int(embeddings_dim * encoder_decoder_dim_ratio),  # Output dim of decoder
            '--max-tokens',
            max_tokens,
            '--max-epoch',
            max_epoch,
            '--lr-scheduler',
            lr_scheduler,
            '--dropout',
            dropout,
            '--lr',
            lr,
            '--lr-shrink',
            0.5,  # Used by the reduce_lr_on_plateau scheduler
            '--min-lr',
            min_lr,
            '--weight-decay',
            weight_decay,
            '--optimizer',
            optimizer,
            '--label-smoothing',
            label_smoothing,
            '--seed',
            random.randint(1, 1000),
            # '--force-anneal', '200',
            # '--distributed-world-size', '1',
        ]
        if arch == 'transformer':
            args.extend([
                '--encoder-embed-dim',
                encoder_embed_dim,
                '--encoder-ffn-embed-dim',
                4 * encoder_embed_dim,
                '--encoder-layers',
                encoder_layers,
                '--encoder-attention-heads',
                encoder_attention_heads,
                '--decoder-layers',
                encoder_layers,
                '--decoder-attention-heads',
                encoder_attention_heads,
            ])
        if pretrained_emb_path is not None:
            args.extend(['--encoder-embed-path', pretrained_emb_path])
            args.extend(['--decoder-embed-path', pretrained_emb_path])
        if embeddings_dim is not None:
            args.extend(['--encoder-embed-dim', embeddings_dim])  # Input and output dim of encoder
            args.extend(['--decoder-embed-dim', embeddings_dim])  # Input dim of decoder
        if ngpus is not None:
            args.extend(['--distributed-world-size', ngpus])
        # if share_embeddings:
        #     args.append('--share-input-output-embed')
        if fp16:
            args.append('--fp16')
        if warmup_updates is not None:
            args.extend(['--warmup-updates', warmup_updates])
        args = [str(arg) for arg in args]
        train_args = options.parse_args_and_arch(train_parser, args)
        train.main(train_args)
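

Example call (a minimal usage sketch, not part of the original file; the paths and values below are illustrative assumptions, and preprocessed_dir must contain the output of the project's fairseq preprocessing step, including dict.complex.txt and dict.simple.txt):

if __name__ == '__main__':
    fairseq_train(
        preprocessed_dir='resources/fairseq_preprocessed',  # hypothetical path
        exp_dir='experiments/fconv_baseline',               # hypothetical path
        ngpus=1,        # single-GPU training
        max_epoch=50,   # same as the default above
    )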