def train_fn(args)

in sagemaker_notebook_instance/containers/relationship_extraction/package/training.py

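As extracted, the listing omits the module-level imports it relies on. A plausible set is sketched below; the import paths for the package-local helpers (RelationshipTokenizer, LabelEncoder, RelationStatementDataset, label_set, RelationshipEncoderLightningModule) are assumptions about the package layout, not taken from the source.

from pathlib import Path

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# package-local helpers; the module names below are assumed, not confirmed by the excerpt
from package.tokenization import RelationshipTokenizer
from package.labels import LabelEncoder
from package.data import RelationStatementDataset, label_set
from package.model import RelationshipEncoderLightningModule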

def train_fn(args):
    print(args)
    
    # load tokenizer
    tokenizer = RelationshipTokenizer.from_pretrained(
        pretrained_model_name_or_path='bert-base-uncased',
        contains_entity_tokens=False
    )
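    # persist the tokenizer alongside the model artifacts so it can be reloaded at inference time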
    tokenizer.save(file_path=Path(args.model_dir, 'tokenizer.json'), pretty=True)
    
    # load data
    train_file_path = Path(args.train_data_dir, 'train.txt')
    test_file_path = Path(args.test_data_dir, 'test.txt')
    
    # construct label encoder
    labels = list(label_set(train_file_path))
    label_encoder = LabelEncoder.from_str_list(sorted(labels))
    print('Using the following label encoder mappings:\n\n', label_encoder)
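    # persist the label mapping so predictions can be decoded back to relationship labels at inference time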
    label_encoder.save(file_path=str(Path(args.model_dir, 'label_encoder.json')))
    
    # prepare datasets
    model_size = 512  # BERT's maximum input length in tokens; statements are padded/truncated to this
    tokenizer.set_truncation(model_size)
    tokenizer.set_padding(model_size)
    train_dataset = RelationStatementDataset(
        file_path=train_file_path,
        tokenizer=tokenizer,
        label_encoder=label_encoder
    )
    test_dataset = RelationStatementDataset(
        file_path=test_file_path,
        tokenizer=tokenizer,
        label_encoder=label_encoder
    )

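    # wrap the datasets in DataLoaders; shuffling is left at its default (disabled) for both splits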
    batch_size = 16
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=4
    )
    test_dataloader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=4
    )
    
    # create model
    relationship_encoder = RelationshipEncoderLightningModule(
        tokenizer,
        label_encoder,
        learning_rate=float(args.learning_rate)
    )

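    # note: ModelCheckpoint(filepath=...) and Trainer(checkpoint_callback=<callback>, weights_summary=..., gpus=...)
    # follow an older pytorch-lightning API; newer releases use dirpath/filename, callbacks=[...],
    # and accelerator/devices instead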
    checkpoint_callback = ModelCheckpoint(
        monitor='valid_loss',
        filepath=str(Path(args.model_dir, 'model'))
    )
    
    # train model
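    # accumulate_grad_batches=2 combined with batch_size=16 gives an effective batch size of 32 per optimizer step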
    trainer = Trainer(
        default_root_dir=args.output_dir,
        accumulate_grad_batches=2,
        gradient_clip_val=1.0,
        max_epochs=1,
        weights_summary='full',
        gpus=args.gpus,
        checkpoint_callback=checkpoint_callback,
        fast_dev_run=True  # runs a single train/val batch as a smoke test (checkpointing is skipped); set to False for a full run
    )
    
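    # the test split doubles as the validation set here, which produces the 'valid_loss' monitored by the checkpoint callback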
    trainer.fit(relationship_encoder, train_dataloader, test_dataloader)
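
For context, train_fn is typically driven from the script's __main__ block, which maps the SageMaker environment variables onto the argument names used above. The sketch below is illustrative only: the CLI flag names and the learning-rate default are assumptions, while the SM_* environment variables are the standard ones SageMaker injects into a training job.

if __name__ == '__main__':
    import argparse
    import os

    parser = argparse.ArgumentParser()
    # directories injected by the SageMaker training job
    parser.add_argument('--model-dir', dest='model_dir', default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--train-data-dir', dest='train_data_dir', default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--test-data-dir', dest='test_data_dir', default=os.environ.get('SM_CHANNEL_TEST'))
    parser.add_argument('--output-dir', dest='output_dir', default=os.environ.get('SM_OUTPUT_DATA_DIR'))
    # hyperparameters forwarded by the SageMaker Estimator (illustrative defaults)
    parser.add_argument('--learning-rate', dest='learning_rate', default='0.00003')
    parser.add_argument('--gpus', type=int, default=int(os.environ.get('SM_NUM_GPUS', 0)))

    train_fn(parser.parse_args())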