in src/engines.py [0:0]
def __init__(self, opt):
    """Build the training engine from the option dict ``opt``.

    Seeds the RNG, resolves the experiment alias and cache locations,
    constructs the dataset, model, optimizer, loss and regularizer through
    the ``setup_*`` helpers, copies run options onto the instance, and
    starts a Weights & Biases run that records ``opt`` as its config.

    Args:
        opt: option dict. Keys read directly here: 'seed', 'cache_eval',
            'model_cache_path', 'dataset', 'optimizer', 'learning_rate',
            'decay1', 'decay2', 'batch_size', 'device', 'max_epochs',
            'world', 'num_neg', 'score_rel', 'score_rhs', 'score_lhs',
            'w_rel', 'w_lhs', 'experiment_id', 'run_tags', 'run_notes'.
            The helpers that receive ``opt`` may read further keys.

    Note:
        ``opt`` is mutated in place. 'cache_eval' is replaced by the
        resolved cache path, 'size' is set to the dataset shape, and 'loss'
        is set to the constructed loss object. Each mutation happens before
        later helpers receive ``opt`` and before ``opt`` is sent to
        ``wandb.config``, so the statement order below is significant.
    """
    # Seed before anything that might draw random numbers.
    self.seed = opt['seed']
    set_seed(int(self.seed))

    # Experiment naming and per-experiment cache locations.
    self.alias = _set_exp_alias(opt)
    self.cache_eval = _set_cache_path(
        opt['cache_eval'], opt['dataset'], self.alias)
    self.model_cache_path = _set_cache_path(
        opt['model_cache_path'], opt['dataset'], self.alias)
    # Later readers of ``opt`` get the resolved eval-cache path.
    # NOTE(review): 'model_cache_path' is not written back the same way -- confirm that is intended.
    opt['cache_eval'] = self.cache_eval

    # Dataset. Its shape is stored in opt['size'] before the model is built
    # (presumably consumed by setup_model -- confirm).
    self.dataset = setup_ds(opt)
    opt['size'] = self.dataset.get_shape()

    # Model, optimizer and loss.
    self.model = setup_model(opt)
    optimizer_args = (
        opt['optimizer'],
        opt['learning_rate'],
        opt['decay1'],
        opt['decay2'],
    )
    self.optimizer = setup_optimizer(self.model, *optimizer_args)
    self.loss = setup_loss(opt)
    # NOTE(review): overwrites opt['loss'] with the loss object. Any config
    # value previously stored there is lost, and the object is what gets
    # logged by wandb.config.update below.
    opt['loss'] = self.loss
    self.batch_size = opt['batch_size']

    # Regularizer (sees the opt dict including the 'loss' entry set above).
    self.regularizer = setup_regularizer(opt)

    # Remaining run options, copied verbatim from opt.
    self.device, self.max_epochs = opt['device'], opt['max_epochs']
    self.world, self.num_neg = opt['world'], opt['num_neg']
    self.score_rel, self.score_rhs, self.score_lhs = (
        opt['score_rel'], opt['score_rhs'], opt['score_lhs'])
    self.w_rel, self.w_lhs = opt['w_rel'], opt['w_lhs']
    self.opt = opt
    self._epoch_id = 0

    # Experiment tracking: start the run, log the (already mutated) options,
    # watch the model, and mark the run as not yet finished.
    wandb.init(
        project="ssl-relation-prediction",
        group=opt['experiment_id'],
        tags=opt['run_tags'],
        notes=opt['run_notes'],
    )
    wandb.config.update(opt)
    wandb.watch(self.model, log='all', log_freq=10000)
    wandb.run.summary['is_done'] = False
    print(f'Git commit ID: {get_git_revision_hash()}')