in sagemaker/26_document_ai_donut/scripts/train.py [0:0]
def training_function(args):
    # set seed
    set_seed(args.seed)
    # Set up logging
    logger = logging.getLogger(__name__)
    logging.basicConfig(
        level=logging.getLevelName("INFO"),
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    # load datasets
    train_dataset = load_from_disk(args.dataset_path)
    image_size = list(torch.tensor(train_dataset[0]["pixel_values"][0]).shape)  # height, width
    logger.info(f"loaded train_dataset length is: {len(train_dataset)}")
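    # pixel_values is assumed to be stored channel-first (channels, height, width), so indexing
    # the first channel yields the (height, width) the images were preprocessed to.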
    # Load processor and set up new special tokens
    processor = DonutProcessor.from_pretrained(args.model_id)
    # add new special tokens to tokenizer and resize feature extractor
    special_tokens = args.special_tokens.split(",")
    processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})
    processor.feature_extractor.size = image_size[::-1]  # should be (width, height)
    processor.feature_extractor.do_align_long_axis = False
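    # special_tokens is expected to be a comma-separated string (e.g. "<s_key>,</s_key>"; illustrative values),
    # and do_align_long_axis=False stops the Donut feature extractor from rotating images by 90 degrees
    # to align their long side with the long side of the target size.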
    # Load model from huggingface.co
    config = VisionEncoderDecoderConfig.from_pretrained(args.model_id, use_cache=False if args.gradient_checkpointing else True)
    model = VisionEncoderDecoderModel.from_pretrained(args.model_id, config=config)
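    # Caching of decoder key/value states is turned off when gradient checkpointing is enabled,
    # since cached states are incompatible with checkpointed forward passes during training.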
    # Resize embedding layer to match vocabulary size & adjust our image size and output sequence lengths
    model.decoder.resize_token_embeddings(len(processor.tokenizer))
    model.config.encoder.image_size = image_size
    model.config.decoder.max_length = len(max(train_dataset["labels"], key=len))
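    # Generation length is capped at the length of the longest tokenized target sequence in the training set.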
    # Set the padding token and the task start token the decoder begins generation from
    model.config.pad_token_id = processor.tokenizer.pad_token_id
    model.config.decoder_start_token_id = processor.tokenizer.convert_tokens_to_ids(['<s>'])[0]
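    # convert_tokens_to_ids is given a list and returns a list, hence the [0]; here the decoder
    # starts generating from the "<s>" token.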
    # Arguments for training
    output_dir = "/tmp"
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        per_device_train_batch_size=args.per_device_train_batch_size,
        bf16=True,
        tf32=True,
        gradient_checkpointing=args.gradient_checkpointing,
        logging_steps=10,
        save_total_limit=1,
        evaluation_strategy="no",
        save_strategy="epoch",
    )
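    # bf16/tf32 assume an NVIDIA Ampere-or-newer GPU; a checkpoint is written to /tmp once per epoch,
    # only the most recent one is kept, and no evaluation is run during training.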
    # Create Trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
    )
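    # No eval_dataset or compute_metrics is passed because evaluation is disabled above.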
    # Start training
    trainer.train()
    # save model and processor
    trainer.model.save_pretrained("/opt/ml/model/")
    processor.save_pretrained("/opt/ml/model/")
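    # /opt/ml/model is the directory SageMaker archives into model.tar.gz and uploads to S3
    # when the training job finishes.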
    # copy inference script
    os.makedirs("/opt/ml/model/code", exist_ok=True)
    shutil.copyfile(
        os.path.join(os.path.dirname(__file__), "inference.py"),
        "/opt/ml/model/code/inference.py",
    )
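    # Bundling inference.py under /opt/ml/model/code lets the SageMaker Hugging Face Inference
    # Toolkit pick it up as a custom inference handler when the model is later deployed.
    # The attributes read from `args` above (seed, dataset_path, model_id, special_tokens, epochs,
    # lr, per_device_train_batch_size, gradient_checkpointing) are assumed to be supplied by the
    # script's argument parser from the SageMaker hyperparameters.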