def train()

in model_ctx.py [0:0]


def _train_bpe(corpus_path, model_prefix, vocab_size):
	"""Train a SentencePiece BPE model on `corpus_path` and return it loaded.

	The model is loaded back from `<model_prefix>.model` right after training.
	"""
	spm.SentencePieceTrainer.train(
		input=corpus_path,
		model_prefix=model_prefix,
		model_type='bpe',
		vocab_size=vocab_size,
	)
	return spm.SentencePieceProcessor(f'{model_prefix}.model')


def _build_tokenizers(dev_dir):
	"""Return `(text_tokenizer, cmd_tokenizer)` trained on files in `dev_dir`.

	With `config.joined_vocab` set, a single model trained on `<dev_dir>/all`
	is shared by both sides; otherwise separate models are trained on
	`<dev_dir>/text` and `<dev_dir>/cmd`.
	"""
	if config.joined_vocab:
		shared = _train_bpe(f'{dev_dir}/all', f'{dev_dir}/all_bpe_ctx', config.src_vocab_size)
		return shared, shared

	text_tokenizer = _train_bpe(f'{dev_dir}/text', f'{dev_dir}/txt_bpe_ctx', config.src_vocab_size)
	cmd_tokenizer = _train_bpe(f'{dev_dir}/cmd', f'{dev_dir}/cmd_bpe_ctx', config.tgt_vocab_size)
	return text_tokenizer, cmd_tokenizer


def _first_util(cmd):
	"""Return the first whitespace-separated token of `cmd`, or '' if none.

	Spaces, '$', '(' and ')' are stripped from both ends before splitting.
	"""
	tokens = cmd.strip(' $()').split()
	# A command that is empty after stripping has no utility name; returning ''
	# (instead of indexing into an empty list) keeps the row, without context.
	return tokens[0] if tokens else ''


def _load_train_frame(dev_dir):
	"""Load `<dev_dir>/train.csv` and append per-utility context to the text.

	Rows containing NaN are dropped, a space is inserted before every '|' in
	`cmd_cleaned`, and rows whose first command token is ']' are removed.
	The context comes from `<dev_dir>/man.csv` via `make_ctx`, looked up by
	the first token of the command; rows without a match get no context.
	"""
	df = pd.read_csv(f'{dev_dir}/train.csv', index_col=0).dropna()
	df['cmd_cleaned'] = df['cmd_cleaned'].apply(lambda cmd: cmd.replace('|', ' |'))
	df['util'] = df['cmd_cleaned'].apply(_first_util)
	# NOTE(review): ']' as a leading token presumably marks malformed commands - confirm.
	df = df[df['util'] != ']'].reset_index(drop=True)

	man = pd.read_csv(f'{dev_dir}/man.csv', index_col=0)
	man['ctx'] = man.apply(make_ctx, axis=1)
	# Keep the first entry per command so the index used for the lookup is unique.
	man = man.drop_duplicates(subset='cmd').set_index('cmd')

	df['ctx'] = df['util'].map(man['ctx'])
	df['text_cleaned'] = df['text_cleaned'] + ' ' + df['ctx'].fillna('')
	return df


def train(dev_dir, logdir, device):
	"""Train the context-augmented translation model.

	Trains the BPE tokenizers, prepares and encodes the training data, holds
	out 500 rows with `origin == 'original'` for validation, and runs a
	Catalyst `SupervisedRunner` for `config.num_epochs` epochs.

	Args:
		dev_dir: Directory holding the tokenizer corpora (`text` and `cmd`,
			or `all`), `train.csv` and `man.csv`. Tokenizer models are
			written here as well.
		logdir: Output directory for the run. Any existing content is
			deleted before training starts.
		device: Passed through to `dl.SupervisedRunner`.
	"""
	text_tokenizer, cmd_tokenizer = _build_tokenizers(dev_dir)

	df = _load_train_frame(dev_dir)
	df['text_enc'] = df['text_cleaned'].progress_apply(text_tokenizer.encode)
	df['cmd_enc'] = df['cmd_cleaned'].progress_apply(cmd_tokenizer.encode)

	# Validation rows are drawn only from the 'original' samples; every other
	# row goes to the training set.
	is_original = df['origin'] == 'original'
	train_df, valid_df = train_test_split(df[is_original], test_size=500, random_state=SEED)
	train_df = pd.concat([train_df, df[~is_original]]).reset_index(drop=True)
	# The split keeps the shuffled source labels; reset them so both datasets
	# receive 0..n-1 indexed Series.
	valid_df = valid_df.reset_index(drop=True)

	train_ds = MtDataset(train_df['text_enc'], train_df['cmd_enc'], config, bos_id, eos_id, pad_id)
	valid_ds = MtDataset(valid_df['text_enc'], valid_df['cmd_enc'], config, bos_id, eos_id, pad_id)

	model = Transformer(config, pad_id)
	print('# params', sum(p.numel() for p in model.parameters() if p.requires_grad))

	loaders = {
		'train': data.DataLoader(train_ds, batch_size=config.batch_size, shuffle=True),
		'valid': data.DataLoader(valid_ds, batch_size=config.batch_size),
	}

	criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
	optimizer = torch.optim.Adam(
		model.parameters(),
		lr=config.optimizer_lr,
		weight_decay=config.weight_decay,
		amsgrad=True,
	)

	# NOTE(review): the positional argument is presumably the number of
	# checkpoints to keep - confirm against the installed Catalyst version.
	callbacks = [dl.CheckpointCallback(config.num_epochs)]

	# Create the scheduler and its callback together: a SchedulerCallback
	# attached to a runner without a scheduler has nothing to step.
	scheduler = None
	if config.schedule:
		scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
			optimizer,
			factor=config.plateau_factor,
			patience=3,
			cooldown=2,
			threshold=1e-3,
			min_lr=1e-6,
		)
		callbacks.append(dl.SchedulerCallback(mode="epoch"))

	# Start from a clean log directory so output of earlier runs is not mixed in.
	shutil.rmtree(logdir, ignore_errors=True)
	os.makedirs(logdir, exist_ok=True)

	runner = dl.SupervisedRunner(device=device)
	runner.train(
		model=model,
		loaders=loaders,
		criterion=criterion,
		optimizer=optimizer,
		scheduler=scheduler,
		num_epochs=config.num_epochs,
		verbose=True,
		logdir=logdir,
		callbacks=callbacks,
	)