in sagemaker_notebook_instance/containers/relationship_extraction/package/training.py
import torch
from pathlib import Path
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Package-local helpers; the exact module paths below are assumptions based on
# the package layout, so adjust them to match the actual files.
from .tokenizer import RelationshipTokenizer
from .dataset import RelationStatementDataset, label_set
from .label_encoder import LabelEncoder
from .model import RelationshipEncoderLightningModule

def train_fn(args):
    # Log the hyperparameters and paths passed in for this training job.
    print(args)
    # Load the tokenizer and save it with the model artifacts so inference
    # code can reproduce the exact same preprocessing.
    tokenizer = RelationshipTokenizer.from_pretrained(
        pretrained_model_name_or_path='bert-base-uncased',
        contains_entity_tokens=False
    )
    tokenizer.save(file_path=Path(args.model_dir, 'tokenizer.json'), pretty=True)
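    # Assumption: contains_entity_tokens=False tells the package-specific
    # RelationshipTokenizer that the pretrained 'bert-base-uncased' vocabulary
    # does not yet include its special entity-marker tokens.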
    # Resolve the train and test files from the input data directories.
    train_file_path = Path(args.train_data_dir, 'train.txt')
    test_file_path = Path(args.test_data_dir, 'test.txt')
    # Construct the label encoder from the labels observed in the training set
    # and persist it so predictions can be decoded back into label strings.
    labels = list(label_set(train_file_path))
    label_encoder = LabelEncoder.from_str_list(sorted(labels))
    print('Using the following label encoder mappings:\n\n', label_encoder)
    label_encoder.save(file_path=str(Path(args.model_dir, 'label_encoder.json')))
    # Prepare the datasets: truncate and pad every example to BERT's maximum
    # sequence length of 512 tokens.
    max_sequence_length = 512
    tokenizer.set_truncation(max_sequence_length)
    tokenizer.set_padding(max_sequence_length)
    train_dataset = RelationStatementDataset(
        file_path=train_file_path,
        tokenizer=tokenizer,
        label_encoder=label_encoder
    )
    test_dataset = RelationStatementDataset(
        file_path=test_file_path,
        tokenizer=tokenizer,
        label_encoder=label_encoder
    )
    batch_size = 16
    # Note: shuffle defaults to False, so the training loader reads examples in
    # file order; pass shuffle=True here if train.txt is not already shuffled.
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        num_workers=4
    )
    test_dataloader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=batch_size,
        num_workers=4
    )
    # Create the model. 'valid_loss' must match the metric name logged by the
    # LightningModule; filepath is the legacy ModelCheckpoint argument (later
    # pytorch-lightning releases split it into dirpath and filename).
    relationship_encoder = RelationshipEncoderLightningModule(
        tokenizer,
        label_encoder,
        learning_rate=float(args.learning_rate)
    )
    checkpoint_callback = ModelCheckpoint(
        monitor='valid_loss',
        filepath=str(Path(args.model_dir, 'model'))
    )
    # Train the model. With accumulate_grad_batches=2 and batch_size=16 the
    # effective batch size is 32. fast_dev_run=True limits the run to a single
    # batch of training and validation as a smoke test; remove it (and raise
    # max_epochs) for a full training run.
    trainer = Trainer(
        default_root_dir=args.output_dir,
        accumulate_grad_batches=2,
        gradient_clip_val=1.0,
        max_epochs=1,
        weights_summary='full',
        gpus=args.gpus,
        checkpoint_callback=checkpoint_callback,
        fast_dev_run=True
    )
    # The test set doubles as the validation set here.
    trainer.fit(relationship_encoder, train_dataloader, test_dataloader)
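
# A minimal sketch of the script entry point that SageMaker script mode would
# use to invoke train_fn. The SM_* environment variables are the standard ones
# SageMaker sets inside training containers; the argument names are assumptions
# chosen to match the attribute access in train_fn above.
if __name__ == '__main__':
    import argparse
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--learning-rate', type=float, default=5e-5)
    parser.add_argument('--gpus', type=int,
                        default=int(os.environ.get('SM_NUM_GPUS', 0)))
    parser.add_argument('--model-dir', type=str,
                        default=os.environ.get('SM_MODEL_DIR'))
    parser.add_argument('--output-dir', type=str,
                        default=os.environ.get('SM_OUTPUT_DATA_DIR'))
    parser.add_argument('--train-data-dir', type=str,
                        default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--test-data-dir', type=str,
                        default=os.environ.get('SM_CHANNEL_TEST'))
    train_fn(parser.parse_args())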