# src/run_fusion_in_decoder.py
import argparse
import glob
import logging
import os

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import T5Tokenizer

# Local modules: the T5 LightningModule wrapper and the dataloader factory.
# Their import paths are assumed here; adjust to match the repo layout.
from model import T5
from data import generate_dataloader

logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='')
    args, _ = parser.parse_known_args()
    cfg = OmegaConf.load(f'/opt/ml/code/model_artifacts/fusion_config{args.config}.yaml')
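
    # SageMaker wiring: read training data from the train channel and write
    # outputs/checkpoints under SM_MODEL_DIR so they are uploaded to S3 when
    # the job completes.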
    cfg.data.data_dir = os.environ['SM_CHANNEL_TRAIN']
    cfg.data.output_dir = os.path.join(os.environ['SM_MODEL_DIR'], 'output')
    os.makedirs(cfg.data.output_dir, exist_ok=True)
    cfg.model.checkpoint_dir = os.path.join(os.environ['SM_MODEL_DIR'], 'ckpt')
    os.makedirs(cfg.model.checkpoint_dir, exist_ok=True)

    # Seed all RNGs for reproducibility.
    seed_everything(cfg.optim.seed)

    # Per-model checkpoint directory.
    checkpoint_dir = os.path.join(cfg.model.checkpoint_dir, cfg.model.model_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    # Keep the two best checkpoints by validation loss; the filename embeds
    # the epoch and val_loss, which the evaluation branch below parses back out.
    checkpoint_callback = ModelCheckpoint(
        monitor='avg_val_loss',
        filepath=os.path.join(checkpoint_dir, '{epoch}-{val_loss:.4f}'),
        mode='min',
        save_last=False,
        save_top_k=2,
    )
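
    # Fall back to the model name when no dedicated tokenizer is configured.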
    tokenizer = T5Tokenizer.from_pretrained(
        cfg.model.tokenizer_name if cfg.model.tokenizer_name else cfg.model.model_name,
        cache_dir=cfg.model.cache_dir,
        use_fast=cfg.model.use_fast,
    )
    model_t5 = T5(cfg, tokenizer)

    if cfg.model.model_checkpoint:
        logger.info(f"Loading checkpoint {cfg.model.model_checkpoint} to continue training")
        # Load onto CPU; the Trainer moves the weights to the right device.
        model_checkpoint = torch.load(cfg.model.model_checkpoint, map_location=lambda storage, loc: storage)
        model_t5.load_state_dict(model_checkpoint['state_dict'])

    # Training and testing
    if cfg.do_train:
        train_dataloader = generate_dataloader(
            data_dir=cfg.data.data_dir,
            tokenizer=tokenizer,
            max_source_length=cfg.data.max_source_length,
            max_target_length=cfg.data.max_target_length,
            overwrite_cache=cfg.data.overwrite_cache,
            mode="train",
            batch_size=cfg.optim.train_batch_size,
            question_type=cfg.data.question_type,
            passage_type=cfg.data.passage_type,
            enable_sql_supervision=cfg.data.enable_sql_supervision,
            cand_for_each_source=cfg.data.cand_for_each_source,
        )
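        # The dev loader mirrors the train loader, differing only in split
        # and batch size.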
        dev_dataloader = generate_dataloader(
            data_dir=cfg.data.data_dir,
            tokenizer=tokenizer,
            max_source_length=cfg.data.max_source_length,
            max_target_length=cfg.data.max_target_length,
            overwrite_cache=cfg.data.overwrite_cache,
            mode="dev",
            batch_size=cfg.optim.dev_batch_size,
            question_type=cfg.data.question_type,
            passage_type=cfg.data.passage_type,
            enable_sql_supervision=cfg.data.enable_sql_supervision,
            cand_for_each_source=cfg.data.cand_for_each_source,
        )
        logger.info("Training starts")
        # tb_logger = loggers.WandbLogger(save_dir=cfg.optim.logging_dir, project='fusion in decoder')
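        # Build the Trainer from the cfg.trainer section; checkpointing is
        # wired in via the callback defined above.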
        trainer = pl.Trainer(
            # logger=tb_logger,
            checkpoint_callback=checkpoint_callback,
            **OmegaConf.to_container(cfg.trainer, resolve=True),
        )
        trainer.fit(model_t5, train_dataloader, dev_dataloader)
        # trainer.test(model_t5)

    if cfg.do_eval:
        test_dataloader = generate_dataloader(
            data_dir=cfg.data.data_dir,
            tokenizer=tokenizer,
            max_source_length=cfg.data.max_source_length,
            max_target_length=cfg.data.max_target_length,
            overwrite_cache=cfg.data.overwrite_cache,
            mode="test",
            batch_size=cfg.optim.test_batch_size,
            question_type=cfg.data.question_type,
            passage_type=cfg.data.passage_type,
            enable_sql_supervision=cfg.data.enable_sql_supervision,
            cand_for_each_source=cfg.data.cand_for_each_source,
        )
        logger.info("Evaluation starts")
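
        # Pick the checkpoint to evaluate: either the one given explicitly in
        # the config, or the best (lowest val_loss) checkpoint found on disk.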
        best_checkpoint_file = None
        if cfg.model.model_checkpoint is None:
            best_val_loss = float('inf')
            for checkpoint_file in glob.glob(os.path.join(checkpoint_dir, "*val_loss*.ckpt")):
                # Recover the validation loss from filenames of the form
                # "epoch=N-val_loss=X.ckpt" written by the checkpoint callback.
                try:
                    val_loss = float(checkpoint_file.split('=')[-1].replace(".ckpt", ""))
                except ValueError:
                    continue
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_checkpoint_file = checkpoint_file
            logger.info(f"Loading the checkpoint: {best_checkpoint_file}")
        else:
            best_checkpoint_file = cfg.model.model_checkpoint
        # Load the selected weights.
        if best_checkpoint_file is not None:
            best_checkpoint = torch.load(best_checkpoint_file, map_location=lambda storage, loc: storage)
            model_t5.load_state_dict(best_checkpoint['state_dict'])
        # Test with the Trainer's test loop.
        trainer = pl.Trainer(**OmegaConf.to_container(cfg.trainer, resolve=True))
        trainer.test(model_t5, test_dataloader)
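

# Entry point when the script is executed directly (e.g. by SageMaker).
if __name__ == '__main__':
    main()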