model_ctx.py (108 lines of code) (raw):

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data
import catalyst.dl as dl
import sentencepiece as spm
from sklearn.model_selection import train_test_split
import shutil
import os
from preprocessing import clean_text
from submission_code.tools import MtDataset, Transformer
import argparse
import sys

# NOTE: the search path is extended *before* ``config_ctx`` is imported below;
# this ordering is kept exactly as in the original module.
sys.path.append('../clai/utils')

# Fixed seed applied to NumPy and torch (CPU and all CUDA devices).
SEED = 1337
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Special token ids handed to ``MtDataset`` / ``Transformer`` / the loss.
bos_id = 1
eos_id = 2
pad_id = 0

from submission_code import config_ctx

config = config_ctx.Config()

# Number of 'original' rows held out for validation.
_VALID_SIZE = 500


def _build_tokenizers(dev_dir):
    """Train the SentencePiece BPE model(s) in *dev_dir* and load them.

    Returns a ``(text_tokenizer, cmd_tokenizer)`` pair. With
    ``config.joined_vocab`` set, a single model trained on ``{dev_dir}/all``
    is used for both sides; otherwise separate models are trained on
    ``{dev_dir}/text`` and ``{dev_dir}/cmd``.
    """
    if not config.joined_vocab:
        spm.SentencePieceTrainer.train(
            input=f'{dev_dir}/text',
            model_prefix=f'{dev_dir}/txt_bpe_ctx',
            model_type='bpe',
            vocab_size=config.src_vocab_size,
        )
        spm.SentencePieceTrainer.train(
            input=f'{dev_dir}/cmd',
            model_prefix=f'{dev_dir}/cmd_bpe_ctx',
            model_type='bpe',
            vocab_size=config.tgt_vocab_size,
        )
        text_tokenizer = spm.SentencePieceProcessor(f'{dev_dir}/txt_bpe_ctx.model')
        cmd_tokenizer = spm.SentencePieceProcessor(f'{dev_dir}/cmd_bpe_ctx.model')
        return text_tokenizer, cmd_tokenizer

    spm.SentencePieceTrainer.train(
        input=f'{dev_dir}/all',
        model_prefix=f'{dev_dir}/all_bpe_ctx',
        model_type='bpe',
        vocab_size=config.src_vocab_size,
    )
    shared_tokenizer = spm.SentencePieceProcessor(f'{dev_dir}/all_bpe_ctx.model')
    return shared_tokenizer, shared_tokenizer


def _first_token(cmd):
    """Return the first whitespace-separated token of *cmd*, or ``''`` if none.

    Leading/trailing spaces, ``$``, ``(`` and ``)`` are stripped first. The
    empty-string fallback avoids an ``IndexError`` on commands that contain
    nothing but those characters.
    """
    tokens = cmd.strip(' $()').split()
    return tokens[0] if tokens else ''


def _apply_with_progress(series, func):
    """Apply *func* element-wise to *series*, with a progress bar if available."""
    # ``progress_apply`` only exists on pandas objects after ``tqdm.pandas()``
    # has been called. This module never calls it (another import may), so
    # fall back to plain ``apply`` instead of raising AttributeError.
    apply = getattr(series, 'progress_apply', series.apply)
    return apply(func)


def _load_training_frame(dev_dir, text_tokenizer, cmd_tokenizer):
    """Load ``train.csv``, attach man-page context and tokenize both sides.

    The returned frame has the extra columns ``util`` (first token of the
    command), ``ctx`` (context string from ``man.csv`` for that utility),
    ``text_enc`` and ``cmd_enc`` (tokenizer output).
    """
    frame = pd.read_csv(f'{dev_dir}/train.csv', index_col=0)
    frame = frame.dropna()
    # Put a space in front of every pipe character in the command.
    frame['cmd_cleaned'] = frame['cmd_cleaned'].apply(lambda cmd: cmd.replace('|', ' |'))
    frame['util'] = frame['cmd_cleaned'].apply(_first_token)
    # Rows whose first token is a bare ']' are dropped.
    frame = frame[frame['util'] != ']']
    frame = frame.reset_index(drop=True)

    man = pd.read_csv(f'{dev_dir}/man.csv', index_col=0)
    man['ctx'] = man.apply(make_ctx, axis=1)
    man = man.drop_duplicates(subset='cmd')
    man = man.set_index('cmd')

    # Utilities without a man.csv entry map to NaN and contribute an empty context.
    frame['ctx'] = frame['util'].map(man['ctx'])
    frame['text_cleaned'] = frame['text_cleaned'] + ' ' + frame['ctx'].fillna('')

    frame['text_enc'] = _apply_with_progress(frame['text_cleaned'], text_tokenizer.encode)
    frame['cmd_enc'] = _apply_with_progress(frame['cmd_cleaned'], cmd_tokenizer.encode)
    return frame


def train(dev_dir, logdir, device):
    """Train tokenizers and the Transformer model on the data in *dev_dir*.

    Args:
        dev_dir: Directory containing the tokenizer training text (``text`` and
            ``cmd``, or ``all`` when ``config.joined_vocab`` is set) plus
            ``train.csv`` and ``man.csv``. Tokenizer model files are written
            here as well.
        logdir: Directory passed to the catalyst runner. It is DELETED and
            recreated before training starts.
        device: Device argument forwarded to ``dl.SupervisedRunner``.
    """
    text_tokenizer, cmd_tokenizer = _build_tokenizers(dev_dir)
    frame = _load_training_frame(dev_dir, text_tokenizer, cmd_tokenizer)

    # Validation rows are drawn only from 'original' data; everything else is
    # used for training in full.
    original_rows = frame[frame['origin'] == 'original']
    other_rows = frame[frame['origin'] != 'original']
    train_rows, valid_rows = train_test_split(
        original_rows, test_size=_VALID_SIZE, random_state=SEED)
    train_rows = pd.concat([train_rows, other_rows]).reset_index(drop=True)

    train_ds = MtDataset(train_rows.text_enc, train_rows.cmd_enc, config, bos_id, eos_id, pad_id)
    valid_ds = MtDataset(valid_rows.text_enc, valid_rows.cmd_enc, config, bos_id, eos_id, pad_id)

    model = Transformer(config, pad_id)
    print('# params', sum(p.numel() for p in model.parameters() if p.requires_grad))

    loaders = {
        'train': data.DataLoader(train_ds, batch_size=config.batch_size, shuffle=True),
        'valid': data.DataLoader(valid_ds, batch_size=config.batch_size),
    }

    criterion = nn.CrossEntropyLoss(ignore_index=pad_id)
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config.optimizer_lr,
        weight_decay=config.weight_decay,
        amsgrad=True,
    )

    callbacks = [
        dl.CheckpointCallback(config.num_epochs),
    ]

    # Only register the scheduler callback when a scheduler is actually handed
    # to the runner.
    # NOTE(review): catalyst's SchedulerCallback appears to require a non-None
    # scheduler on the runner - confirm against the installed catalyst version.
    scheduler = None
    if config.schedule:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            factor=config.plateau_factor,
            patience=3,
            cooldown=2,
            threshold=1e-3,
            min_lr=1e-6,
        )
        callbacks.append(dl.SchedulerCallback(mode="epoch"))

    # Start from a clean log directory.
    shutil.rmtree(logdir, ignore_errors=True)
    os.makedirs(logdir, exist_ok=True)

    runner = dl.SupervisedRunner(device=device)
    runner.train(
        model=model,
        loaders=loaders,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=config.num_epochs,
        verbose=True,
        logdir=logdir,
        callbacks=callbacks,
        # check=True
    )


def make_ctx(rec):
    """Build the context string for one ``man.csv`` record.

    The result is ``|<cmd> <cleaned synopsis>`` followed, for every option, by
    ``|<first short flag> <first five words of the cleaned option text>``.

    Args:
        rec: Mapping-like row with the keys ``options`` (string holding a
            Python expression that evaluates to an iterable of dicts with
            ``short`` and ``text`` entries), ``cmd`` and ``synopsis``.

    Returns:
        The assembled context string.
    """
    # NOTE(review): eval() executes whatever expression is stored in the CSV.
    # If man.csv can come from an untrusted source, switch to
    # ast.literal_eval after confirming the column only holds literals.
    options = eval(rec['options'])
    cmd = rec['cmd']
    synopsis = clean_text(rec['synopsis'])

    parts = [f'|{cmd} {synopsis}']
    for opt in options:
        short_flag = opt['short'][0] if len(opt['short']) > 0 else ''
        text = clean_text(opt['text'])
        # Keep only the first five words of each option description.
        parts.append(f'|{short_flag} ' + ' '.join(text.split()[:5]))
    return ''.join(parts)


def main():
    """Parse command-line arguments and run training."""
    parser = argparse.ArgumentParser()
    parser.add_argument("dev_dir", type=str)
    parser.add_argument("logdir", type=str)
    parser.add_argument('-d', '--device', type=str, default='cpu', required=False)
    args = parser.parse_args()
    train(args.dev_dir, args.logdir, args.device)


if __name__ == '__main__':
    main()