in hype_kg/codes/run.py [0:0]
def main(args):
set_global_seed(args.seed)
args.test_batch_size = 1
assert args.bn in ['no', 'before', 'after']
assert args.n_att >= 1 and args.n_att <= 3
assert args.max_steps == args.stepsforpath
if args.geo == 'box':
assert 'Box' in args.model
elif args.geo == 'vec':
assert 'Box' not in args.model
if args.train_onehop_only:
assert '1c' in args.task
args.center_deepsets = 'mean'
if args.geo == 'box':
args.offset_deepsets = 'min'
if (not args.do_train) and (not args.do_valid) and (not args.do_test) and (not args.evaluate_train):
raise ValueError('one of train/val/test mode must be choosed.')
if args.init_checkpoint:
override_config(args)
elif args.data_path is None:
raise ValueError('one of init_checkpoint/data_path must be choosed.')
# if args.do_train and args.save_path is None:
# raise ValueError('Where do you want to save your trained model?')
cur_time = parse_time()
print ("overide save string.")
if args.task == '1c':
args.stepsforpath = 0
else:
assert args.stepsforpath <= args.max_steps
# logs_newfb237_inter
args.save_path = 'logs/%s/%s/%s/'%(args.data_path.split('/')[-1], args.geo, args.manifold)
writer = SummaryWriter(args.save_path)
if args.save_path and not os.path.exists(args.save_path):
os.makedirs(args.save_path)
set_logger(args)
with open('%s/stats.txt'%args.data_path) as f:
entrel = f.readlines()
nentity = int(entrel[0].split(' ')[-1])
nrelation = int(entrel[1].split(' ')[-1])
args.nentity = nentity
args.nrelation = nrelation
logging.info('Geo: %s' % args.geo)
logging.info('Model: %s' % args.model)
logging.info('Data Path: %s' % args.data_path)
logging.info('#entity: %d' % nentity)
logging.info('#relation: %d' % nrelation)
logging.info('#max steps: %d' % args.max_steps)
logging.info('#stepsforpath: %d' % args.stepsforpath)
logging.info('#manifold: %s' % args.manifold)
logging.info('#curvature: %d' % args.curvature)
logging.info('#trainable curvature: %s' % args.trainable_curvature)
logging.info('#use_semantics: %s', args.use_semantics)
tasks = args.task.split('.')
train_ans = dict()
valid_ans = dict()
valid_ans_hard = dict()
test_ans = dict()
test_ans_hard = dict()
if '1c' in tasks:
with open('%s/train_triples_1c.pkl'%args.data_path, 'rb') as handle:
train_triples = pickle.load(handle)
with open('%s/train_ans_1c.pkl'%args.data_path, 'rb') as handle:
train_ans_1 = pickle.load(handle)
with open('%s/valid_triples_1c.pkl'%args.data_path, 'rb') as handle:
valid_triples = pickle.load(handle)
with open('%s/valid_ans_1c.pkl'%args.data_path, 'rb') as handle:
valid_ans_1 = pickle.load(handle)
with open('%s/valid_ans_1c_hard.pkl'%args.data_path, 'rb') as handle:
valid_ans_1_hard = pickle.load(handle)
with open('%s/test_triples_1c.pkl'%args.data_path, 'rb') as handle:
test_triples = pickle.load(handle)
with open('%s/test_ans_1c.pkl'%args.data_path, 'rb') as handle:
test_ans_1 = pickle.load(handle)
with open('%s/test_ans_1c_hard.pkl'%args.data_path, 'rb') as handle:
test_ans_1_hard = pickle.load(handle)
train_ans.update(train_ans_1)
valid_ans.update(valid_ans_1)
valid_ans_hard.update(valid_ans_1_hard)
test_ans.update(test_ans_1)
test_ans_hard.update(test_ans_1_hard)
if '2c' in tasks:
with open('%s/train_triples_2c.pkl'%args.data_path, 'rb') as handle:
train_triples_2 = pickle.load(handle)
with open('%s/train_ans_2c.pkl'%args.data_path, 'rb') as handle:
train_ans_2 = pickle.load(handle)
with open('%s/valid_triples_2c.pkl'%args.data_path, 'rb') as handle:
valid_triples_2 = pickle.load(handle)
with open('%s/valid_ans_2c.pkl'%args.data_path, 'rb') as handle:
valid_ans_2 = pickle.load(handle)
with open('%s/valid_ans_2c_hard.pkl'%args.data_path, 'rb') as handle:
valid_ans_2_hard = pickle.load(handle)
with open('%s/test_triples_2c.pkl'%args.data_path, 'rb') as handle:
test_triples_2 = pickle.load(handle)
with open('%s/test_ans_2c.pkl'%args.data_path, 'rb') as handle:
test_ans_2 = pickle.load(handle)
with open('%s/test_ans_2c_hard.pkl'%args.data_path, 'rb') as handle:
test_ans_2_hard = pickle.load(handle)
train_ans.update(train_ans_2)
valid_ans.update(valid_ans_2)
valid_ans_hard.update(valid_ans_2_hard)
test_ans.update(test_ans_2)
test_ans_hard.update(test_ans_2_hard)
if '3c' in tasks:
with open('%s/train_triples_3c.pkl'%args.data_path, 'rb') as handle:
train_triples_3 = pickle.load(handle)
with open('%s/train_ans_3c.pkl'%args.data_path, 'rb') as handle:
train_ans_3 = pickle.load(handle)
with open('%s/valid_triples_3c.pkl'%args.data_path, 'rb') as handle:
valid_triples_3 = pickle.load(handle)
with open('%s/valid_ans_3c.pkl'%args.data_path, 'rb') as handle:
valid_ans_3 = pickle.load(handle)
with open('%s/valid_ans_3c_hard.pkl'%args.data_path, 'rb') as handle:
valid_ans_3_hard = pickle.load(handle)
with open('%s/test_triples_3c.pkl'%args.data_path, 'rb') as handle:
test_triples_3 = pickle.load(handle)
with open('%s/test_ans_3c.pkl'%args.data_path, 'rb') as handle:
test_ans_3 = pickle.load(handle)
with open('%s/test_ans_3c_hard.pkl'%args.data_path, 'rb') as handle:
test_ans_3_hard = pickle.load(handle)
train_ans.update(train_ans_3)
valid_ans.update(valid_ans_3)
valid_ans_hard.update(valid_ans_3_hard)
test_ans.update(test_ans_3)
test_ans_hard.update(test_ans_3_hard)
if '2i' in tasks:
with open('%s/train_triples_2i.pkl'%args.data_path, 'rb') as handle:
train_triples_2i = pickle.load(handle)
with open('%s/train_ans_2i.pkl'%args.data_path, 'rb') as handle:
train_ans_2i = pickle.load(handle)
with open('%s/valid_triples_2i.pkl'%args.data_path, 'rb') as handle:
valid_triples_2i = pickle.load(handle)
with open('%s/valid_ans_2i.pkl'%args.data_path, 'rb') as handle:
valid_ans_2i = pickle.load(handle)
with open('%s/valid_ans_2i_hard.pkl'%args.data_path, 'rb') as handle:
valid_ans_2i_hard = pickle.load(handle)
with open('%s/test_triples_2i.pkl'%args.data_path, 'rb') as handle:
test_triples_2i = pickle.load(handle)
with open('%s/test_ans_2i.pkl'%args.data_path, 'rb') as handle:
test_ans_2i = pickle.load(handle)
with open('%s/test_ans_2i_hard.pkl'%args.data_path, 'rb') as handle:
test_ans_2i_hard = pickle.load(handle)
train_ans.update(train_ans_2i)
valid_ans.update(valid_ans_2i)
valid_ans_hard.update(valid_ans_2i_hard)
test_ans.update(test_ans_2i)
test_ans_hard.update(test_ans_2i_hard)
if '3i' in tasks:
with open('%s/train_triples_3i.pkl'%args.data_path, 'rb') as handle:
train_triples_3i = pickle.load(handle)
with open('%s/train_ans_3i.pkl'%args.data_path, 'rb') as handle:
train_ans_3i = pickle.load(handle)
with open('%s/valid_triples_3i.pkl'%args.data_path, 'rb') as handle:
valid_triples_3i = pickle.load(handle)
with open('%s/valid_ans_3i.pkl'%args.data_path, 'rb') as handle:
valid_ans_3i = pickle.load(handle)
with open('%s/valid_ans_3i_hard.pkl'%args.data_path, 'rb') as handle:
valid_ans_3i_hard = pickle.load(handle)
with open('%s/test_triples_3i.pkl'%args.data_path, 'rb') as handle:
test_triples_3i = pickle.load(handle)
with open('%s/test_ans_3i.pkl'%args.data_path, 'rb') as handle:
test_ans_3i = pickle.load(handle)
with open('%s/test_ans_3i_hard.pkl'%args.data_path, 'rb') as handle:
test_ans_3i_hard = pickle.load(handle)
train_ans.update(train_ans_3i)
valid_ans.update(valid_ans_3i)
valid_ans_hard.update(valid_ans_3i_hard)
test_ans.update(test_ans_3i)
test_ans_hard.update(test_ans_3i_hard)
if 'ci' in tasks:
with open('%s/valid_triples_ci.pkl'%args.data_path, 'rb') as handle:
valid_triples_ci = pickle.load(handle)
with open('%s/valid_ans_ci.pkl'%args.data_path, 'rb') as handle:
valid_ans_ci = pickle.load(handle)
with open('%s/valid_ans_ci_hard.pkl'%args.data_path, 'rb') as handle:
valid_ans_ci_hard = pickle.load(handle)
with open('%s/test_triples_ci.pkl'%args.data_path, 'rb') as handle:
test_triples_ci = pickle.load(handle)
with open('%s/test_ans_ci.pkl'%args.data_path, 'rb') as handle:
test_ans_ci = pickle.load(handle)
with open('%s/test_ans_ci_hard.pkl'%args.data_path, 'rb') as handle:
test_ans_ci_hard = pickle.load(handle)
valid_ans.update(valid_ans_ci)
valid_ans_hard.update(valid_ans_ci_hard)
test_ans.update(test_ans_ci)
test_ans_hard.update(test_ans_ci_hard)
if 'ic' in tasks:
with open('%s/valid_triples_ic.pkl'%args.data_path, 'rb') as handle:
valid_triples_ic = pickle.load(handle)
with open('%s/valid_ans_ic.pkl'%args.data_path, 'rb') as handle:
valid_ans_ic = pickle.load(handle)
with open('%s/valid_ans_ic_hard.pkl'%args.data_path, 'rb') as handle:
valid_ans_ic_hard = pickle.load(handle)
with open('%s/test_triples_ic.pkl'%args.data_path, 'rb') as handle:
test_triples_ic = pickle.load(handle)
with open('%s/test_ans_ic.pkl'%args.data_path, 'rb') as handle:
test_ans_ic = pickle.load(handle)
with open('%s/test_ans_ic_hard.pkl'%args.data_path, 'rb') as handle:
test_ans_ic_hard = pickle.load(handle)
valid_ans.update(valid_ans_ic)
valid_ans_hard.update(valid_ans_ic_hard)
test_ans.update(test_ans_ic)
test_ans_hard.update(test_ans_ic_hard)
if 'uc' in tasks:
with open('%s/valid_triples_uc.pkl'%args.data_path, 'rb') as handle:
valid_triples_uc = pickle.load(handle)
with open('%s/valid_ans_uc.pkl'%args.data_path, 'rb') as handle:
valid_ans_uc = pickle.load(handle)
with open('%s/valid_ans_uc_hard.pkl'%args.data_path, 'rb') as handle:
valid_ans_uc_hard = pickle.load(handle)
with open('%s/test_triples_uc.pkl'%args.data_path, 'rb') as handle:
test_triples_uc = pickle.load(handle)
with open('%s/test_ans_uc.pkl'%args.data_path, 'rb') as handle:
test_ans_uc = pickle.load(handle)
with open('%s/test_ans_uc_hard.pkl'%args.data_path, 'rb') as handle:
test_ans_uc_hard = pickle.load(handle)
valid_ans.update(valid_ans_uc)
valid_ans_hard.update(valid_ans_uc_hard)
test_ans.update(test_ans_uc)
test_ans_hard.update(test_ans_uc_hard)
if '2u' in tasks:
with open('%s/valid_triples_2u.pkl'%args.data_path, 'rb') as handle:
valid_triples_2u = pickle.load(handle)
with open('%s/valid_ans_2u.pkl'%args.data_path, 'rb') as handle:
valid_ans_2u = pickle.load(handle)
with open('%s/valid_ans_2u_hard.pkl'%args.data_path, 'rb') as handle:
valid_ans_2u_hard = pickle.load(handle)
with open('%s/test_triples_2u.pkl'%args.data_path, 'rb') as handle:
test_triples_2u = pickle.load(handle)
with open('%s/test_ans_2u.pkl'%args.data_path, 'rb') as handle:
test_ans_2u = pickle.load(handle)
with open('%s/test_ans_2u_hard.pkl'%args.data_path, 'rb') as handle:
test_ans_2u_hard = pickle.load(handle)
valid_ans.update(valid_ans_2u)
valid_ans_hard.update(valid_ans_2u_hard)
test_ans.update(test_ans_2u)
test_ans_hard.update(test_ans_2u_hard)
if '1c' in tasks:
logging.info('#train: %d' % len(train_triples))
logging.info('#valid: %d' % len(valid_triples))
logging.info('#test: %d' % len(test_triples))
if '2c' in tasks:
logging.info('#train_2c: %d' % len(train_triples_2))
logging.info('#valid_2c: %d' % len(valid_triples_2))
logging.info('#test_2c: %d' % len(test_triples_2))
if '3c' in tasks:
logging.info('#train_3c: %d' % len(train_triples_3))
logging.info('#valid_3c: %d' % len(valid_triples_3))
logging.info('#test_3c: %d' % len(test_triples_3))
if '2i' in tasks:
logging.info('#train_2i: %d' % len(train_triples_2i))
logging.info('#valid_2i: %d' % len(valid_triples_2i))
logging.info('#test_2i: %d' % len(test_triples_2i))
if '3i' in tasks:
logging.info('#train_3i: %d' % len(train_triples_3i))
logging.info('#valid_3i: %d' % len(valid_triples_3i))
logging.info('#test_3i: %d' % len(test_triples_3i))
if 'ci' in tasks:
logging.info('#valid_ci: %d' % len(valid_triples_ci))
logging.info('#test_ci: %d' % len(test_triples_ci))
if 'ic' in tasks:
logging.info('#valid_ic: %d' % len(valid_triples_ic))
logging.info('#test_ic: %d' % len(test_triples_ic))
if '2u' in tasks:
logging.info('#valid_2u: %d' % len(valid_triples_2u))
logging.info('#test_2u: %d' % len(test_triples_2u))
if 'uc' in tasks:
logging.info('#valid_uc: %d' % len(valid_triples_uc))
logging.info('#test_uc: %d' % len(test_triples_uc))
query2manifold = Query2Manifold(
model_name=args.model,
nentity=nentity,
nrelation=nrelation,
hidden_dim=args.hidden_dim,
gamma=args.gamma,
writer=writer,
geo=args.geo,
cen=args.center_reg,
offset_deepsets = args.offset_deepsets,
center_deepsets = args.center_deepsets,
offset_use_center = args.offset_use_center,
center_use_offset = args.center_use_offset,
att_reg = args.att_reg,
off_reg = args.off_reg,
att_tem = args.att_tem,
euo = args.entity_use_offset,
gamma2 = args.gamma2,
bn = args.bn,
nat = args.n_att,
activation = args.activation,
manifold = args.manifold,
curvature = args.curvature,
trainable_curvature = args.trainable_curvature,
use_semantics = args.use_semantics
)
logging.info('Model Parameter Configuration:')
num_params = 0
for name, param in query2manifold.named_parameters():
logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))
if param.requires_grad:
num_params += np.prod(param.size())
logging.info('Parameter Number: %d' % num_params)
if args.cuda:
query2manifold = query2manifold.cuda()
if args.do_train:
# Set training dataloader iterator
if '1c' in tasks:
train_dataloader_tail = DataLoader(
TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, train_ans, 'tail-batch'),
batch_size=args.batch_size,
shuffle=True,
num_workers=max(1, args.cpu_num),
collate_fn=TrainDataset.collate_fn
)
train_iterator = SingledirectionalOneShotIterator(train_dataloader_tail, train_triples[0][-1])
if '2c' in tasks:
train_dataloader_2_tail = DataLoader(
TrainDataset(train_triples_2, nentity, nrelation, args.negative_sample_size, train_ans, 'tail-batch'),
batch_size=args.batch_size,
shuffle=True,
num_workers=max(1, args.cpu_num),
collate_fn=TrainDataset.collate_fn
)
train_iterator_2 = SingledirectionalOneShotIterator(train_dataloader_2_tail, train_triples_2[0][-1])
if '3c' in tasks:
train_dataloader_3_tail = DataLoader(
TrainDataset(train_triples_3, nentity, nrelation, args.negative_sample_size, train_ans, 'tail-batch'),
batch_size=args.batch_size,
shuffle=True,
num_workers=max(1, args.cpu_num),
collate_fn=TrainDataset.collate_fn
)
train_iterator_3 = SingledirectionalOneShotIterator(train_dataloader_3_tail, train_triples_3[0][-1])
if '2i' in tasks:
train_dataloader_2i_tail = DataLoader(
TrainInterDataset(train_triples_2i, nentity, nrelation, args.negative_sample_size, train_ans, 'tail-batch'),
batch_size=args.batch_size,
shuffle=True,
num_workers=max(1, args.cpu_num),
collate_fn=TrainInterDataset.collate_fn
)
train_iterator_2i = SingledirectionalOneShotIterator(train_dataloader_2i_tail, train_triples_2i[0][-1])
if '3i' in tasks:
train_dataloader_3i_tail = DataLoader(
TrainInterDataset(train_triples_3i, nentity, nrelation, args.negative_sample_size, train_ans, 'tail-batch'),
batch_size=args.batch_size,
shuffle=True,
num_workers=max(1, args.cpu_num),
collate_fn=TrainInterDataset.collate_fn
)
train_iterator_3i = SingledirectionalOneShotIterator(train_dataloader_3i_tail, train_triples_3i[0][-1])
# Set training configuration
current_learning_rate = args.learning_rate
optimizer = radam.RiemannianAdam(
filter(lambda p: p.requires_grad, query2manifold.parameters()),
lr=current_learning_rate
)
# torch.optim.Adam(
# filter(lambda p: p.requires_grad, query2manifold.parameters()),
# lr=current_learning_rate
# )
if args.warm_up_steps:
warm_up_steps = args.warm_up_steps
else:
warm_up_steps = args.max_steps // 2
if args.init_checkpoint:
# Restore model from checkpoint directory
logging.info('Loading checkpoint %s...' % args.init_checkpoint)
checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
init_step = checkpoint['step']
query2manifold.load_state_dict(checkpoint['model_state_dict'])
if args.do_train:
current_learning_rate = checkpoint['current_learning_rate']
warm_up_steps = checkpoint['warm_up_steps']
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
else:
logging.info('Ramdomly Initializing %s Model...' % args.model)
init_step = 0
step = init_step
logging.info('task = %s' % args.task)
logging.info('init_step = %d' % init_step)
if args.do_train:
logging.info('Start Training...')
logging.info('learning_rate = %d' % current_learning_rate)
logging.info('batch_size = %d' % args.batch_size)
logging.info('negative_adversarial_sampling = %d' % args.negative_adversarial_sampling)
logging.info('hidden_dim = %d' % args.hidden_dim)
logging.info('gamma = %f' % args.gamma)
logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
if args.negative_adversarial_sampling:
logging.info('adversarial_temperature = %f' % args.adversarial_temperature)
# Set valid dataloader as it would be evaluated during training
def evaluate_test():
average_metrics = collections.defaultdict(list)
average_c_metrics = collections.defaultdict(list)
average_c2_metrics = collections.defaultdict(list)
average_i_metrics = collections.defaultdict(list)
average_ex_metrics = collections.defaultdict(list)
average_u_metrics = collections.defaultdict(list)
if '2i' in tasks:
metrics = query2manifold.test_step(query2manifold, test_triples_2i, test_ans, test_ans_hard, args)
log_metrics('Test 2i', step, metrics)
for metric in metrics:
writer.add_scalar('Test_2i_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_i_metrics[metric].append(metrics[metric])
if '3i' in tasks:
metrics = query2manifold.test_step(query2manifold, test_triples_3i, test_ans, test_ans_hard, args)
log_metrics('Test 3i', step, metrics)
for metric in metrics:
writer.add_scalar('Test_3i_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_i_metrics[metric].append(metrics[metric])
if '2c' in tasks:
metrics = query2manifold.test_step(query2manifold, test_triples_2, test_ans, test_ans_hard, args)
log_metrics('Test 2c', step, metrics)
for metric in metrics:
writer.add_scalar('Test_2c_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_c_metrics[metric].append(metrics[metric])
average_c2_metrics[metric].append(metrics[metric])
if '3c' in tasks:
metrics = query2manifold.test_step(query2manifold, test_triples_3, test_ans, test_ans_hard, args)
log_metrics('Test 3c', step, metrics)
for metric in metrics:
writer.add_scalar('Test_3c_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_c_metrics[metric].append(metrics[metric])
average_c2_metrics[metric].append(metrics[metric])
if '1c' in tasks:
metrics = query2manifold.test_step(query2manifold, test_triples, test_ans, test_ans_hard, args)
log_metrics('Test 1c', step, metrics)
for metric in metrics:
writer.add_scalar('Test_1c_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_c_metrics[metric].append(metrics[metric])
if 'ci' in tasks:
metrics = query2manifold.test_step(query2manifold, test_triples_ci, test_ans, test_ans_hard, args)
log_metrics('Test ci', step, metrics)
for metric in metrics:
writer.add_scalar('Test_ci_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_ex_metrics[metric].append(metrics[metric])
if 'ic' in tasks:
metrics = query2manifold.test_step(query2manifold, test_triples_ic, test_ans, test_ans_hard, args)
log_metrics('Test ic', step, metrics)
for metric in metrics:
writer.add_scalar('Test_ic_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_ex_metrics[metric].append(metrics[metric])
if '2u' in tasks:
metrics = query2manifold.test_step(query2manifold, test_triples_2u, test_ans, test_ans_hard, args)
log_metrics('Test 2u', step, metrics)
for metric in metrics:
writer.add_scalar('Test_2u_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_u_metrics[metric].append(metrics[metric])
if 'uc' in tasks:
metrics = query2manifold.test_step(query2manifold, test_triples_uc, test_ans, test_ans_hard, args)
log_metrics('Test uc', step, metrics)
for metric in metrics:
writer.add_scalar('Test_uc_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_u_metrics[metric].append(metrics[metric])
for metric in average_metrics:
writer.add_scalar('Test_average_'+metric, np.mean(average_metrics[metric]), step)
for metric in average_c_metrics:
writer.add_scalar('Test_average_c_'+metric, np.mean(average_c_metrics[metric]), step)
for metric in average_c2_metrics:
writer.add_scalar('Test_average_c2_'+metric, np.mean(average_c2_metrics[metric]), step)
for metric in average_i_metrics:
writer.add_scalar('Test_average_i_'+metric, np.mean(average_i_metrics[metric]), step)
for metric in average_u_metrics:
writer.add_scalar('Test_average_u_'+metric, np.mean(average_u_metrics[metric]), step)
for metric in average_ex_metrics:
writer.add_scalar('Test_average_ex_'+metric, np.mean(average_ex_metrics[metric]), step)
def evaluate_val():
average_metrics = collections.defaultdict(list)
average_c_metrics = collections.defaultdict(list)
average_c2_metrics = collections.defaultdict(list)
average_i_metrics = collections.defaultdict(list)
average_ex_metrics = collections.defaultdict(list)
average_u_metrics = collections.defaultdict(list)
if '2i' in tasks:
metrics = query2manifold.test_step(query2manifold, valid_triples_2i, valid_ans, valid_ans_hard, args)
log_metrics('Valid 2i', step, metrics)
for metric in metrics:
writer.add_scalar('Valid_2i_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_i_metrics[metric].append(metrics[metric])
if '3i' in tasks:
metrics = query2manifold.test_step(query2manifold, valid_triples_3i, valid_ans, valid_ans_hard, args)
log_metrics('Valid 3i', step, metrics)
for metric in metrics:
writer.add_scalar('Valid_3i_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_i_metrics[metric].append(metrics[metric])
if '2c' in tasks:
metrics = query2manifold.test_step(query2manifold, valid_triples_2, valid_ans, valid_ans_hard, args)
log_metrics('Valid 2c', step, metrics)
for metric in metrics:
writer.add_scalar('Valid_2c_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_c_metrics[metric].append(metrics[metric])
average_c2_metrics[metric].append(metrics[metric])
if '3c' in tasks:
metrics = query2manifold.test_step(query2manifold, valid_triples_3, valid_ans, valid_ans_hard, args)
log_metrics('Valid 3c', step, metrics)
for metric in metrics:
writer.add_scalar('Valid_3c_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_c_metrics[metric].append(metrics[metric])
average_c2_metrics[metric].append(metrics[metric])
if '1c' in tasks:
metrics = query2manifold.test_step(query2manifold, valid_triples, valid_ans, valid_ans_hard, args)
log_metrics('Valid 1c', step, metrics)
for metric in metrics:
writer.add_scalar('Valid_1c_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_c_metrics[metric].append(metrics[metric])
if 'ci' in tasks:
metrics = query2manifold.test_step(query2manifold, valid_triples_ci, valid_ans, valid_ans_hard, args)
log_metrics('Valid ci', step, metrics)
for metric in metrics:
writer.add_scalar('Valid_ci_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_ex_metrics[metric].append(metrics[metric])
if 'ic' in tasks:
metrics = query2manifold.test_step(query2manifold, valid_triples_ic, valid_ans, valid_ans_hard, args)
log_metrics('Valid ic', step, metrics)
for metric in metrics:
writer.add_scalar('Valid_ic_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_ex_metrics[metric].append(metrics[metric])
if '2u' in tasks:
metrics = query2manifold.test_step(query2manifold, valid_triples_2u, valid_ans, valid_ans_hard, args)
log_metrics('Valid 2u', step, metrics)
for metric in metrics:
writer.add_scalar('Valid_2u_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_u_metrics[metric].append(metrics[metric])
if 'uc' in tasks:
metrics = query2manifold.test_step(query2manifold, valid_triples_uc, valid_ans, valid_ans_hard, args)
log_metrics('Valid uc', step, metrics)
for metric in metrics:
writer.add_scalar('Valid_uc_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_u_metrics[metric].append(metrics[metric])
for metric in average_metrics:
writer.add_scalar('Valid_average_'+metric, np.mean(average_metrics[metric]), step)
for metric in average_c_metrics:
writer.add_scalar('Valid_average_c_'+metric, np.mean(average_c_metrics[metric]), step)
for metric in average_c2_metrics:
writer.add_scalar('Valid_average_c2_'+metric, np.mean(average_c2_metrics[metric]), step)
for metric in average_i_metrics:
writer.add_scalar('Valid_average_i_'+metric, np.mean(average_i_metrics[metric]), step)
for metric in average_u_metrics:
writer.add_scalar('Valid_average_u_'+metric, np.mean(average_u_metrics[metric]), step)
for metric in average_ex_metrics:
writer.add_scalar('Valid_average_ex_'+metric, np.mean(average_ex_metrics[metric]), step)
def evaluate_train():
average_metrics = collections.defaultdict(list)
average_c_metrics = collections.defaultdict(list)
average_c2_metrics = collections.defaultdict(list)
average_i_metrics = collections.defaultdict(list)
if '2i' in tasks:
metrics = query2manifold.test_step(query2manifold, train_triples_2i, train_ans, train_ans, args)
log_metrics('train 2i', step, metrics)
for metric in metrics:
writer.add_scalar('train_2i_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_i_metrics[metric].append(metrics[metric])
if '3i' in tasks:
metrics = query2manifold.test_step(query2manifold, train_triples_3i, train_ans, train_ans, args)
log_metrics('train 3i', step, metrics)
for metric in metrics:
writer.add_scalar('train_3i_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_i_metrics[metric].append(metrics[metric])
if '2c' in tasks:
metrics = query2manifold.test_step(query2manifold, train_triples_2, train_ans, train_ans, args)
log_metrics('train 2c', step, metrics)
for metric in metrics:
writer.add_scalar('train_2c_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_c_metrics[metric].append(metrics[metric])
average_c2_metrics[metric].append(metrics[metric])
if '3c' in tasks:
metrics = query2manifold.test_step(query2manifold, train_triples_3, train_ans, train_ans, args)
log_metrics('train 3c', step, metrics)
for metric in metrics:
writer.add_scalar('train_3c_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_c_metrics[metric].append(metrics[metric])
average_c2_metrics[metric].append(metrics[metric])
if '1c' in tasks:
metrics = query2manifold.test_step(query2manifold, train_triples, train_ans, train_ans, args)
log_metrics('train 1c', step, metrics)
for metric in metrics:
writer.add_scalar('train_1c_'+metric, metrics[metric], step)
average_metrics[metric].append(metrics[metric])
average_c_metrics[metric].append(metrics[metric])
for metric in average_metrics:
writer.add_scalar('train_average_'+metric, np.mean(average_metrics[metric]), step)
for metric in average_c_metrics:
writer.add_scalar('train_average_c_'+metric, np.mean(average_c_metrics[metric]), step)
for metric in average_c2_metrics:
writer.add_scalar('train_average_c2_'+metric, np.mean(average_c2_metrics[metric]), step)
for metric in average_i_metrics:
writer.add_scalar('train_average_i_'+metric, np.mean(average_i_metrics[metric]), step)
if args.do_train:
training_logs = []
if args.task == '1c':
begin_pq_step = args.max_steps
else:
begin_pq_step = args.max_steps - args.stepsforpath
#Training Loop
for step in range(init_step, args.max_steps):
# print ("begining training step", step)
# if step == 100:
# exit(-1)
if step == 2*args.max_steps//3:
args.valid_steps *= 4
if step >= begin_pq_step and not args.train_onehop_only:
if '2i' in tasks:
log = query2manifold.train_step(query2manifold, optimizer, train_iterator_2i, args, step)
for metric in log:
writer.add_scalar('2i_'+metric, log[metric], step)
training_logs.append(log)
if '3i' in tasks:
log = query2manifold.train_step(query2manifold, optimizer, train_iterator_3i, args, step)
for metric in log:
writer.add_scalar('3i_'+metric, log[metric], step)
training_logs.append(log)
if '2c' in tasks:
log = query2manifold.train_step(query2manifold, optimizer, train_iterator_2, args, step)
for metric in log:
writer.add_scalar('2c_'+metric, log[metric], step)
training_logs.append(log)
if '3c' in tasks:
log = query2manifold.train_step(query2manifold, optimizer, train_iterator_3, args, step)
for metric in log:
writer.add_scalar('3c_'+metric, log[metric], step)
training_logs.append(log)
if '1c' in tasks:
log = query2manifold.train_step(query2manifold, optimizer, train_iterator, args, step)
for metric in log:
writer.add_scalar('1c_'+metric, log[metric], step)
training_logs.append(log)
if training_logs == []:
raise Exception("No tasks are trained!!")
if step >= warm_up_steps:
current_learning_rate = current_learning_rate / 10
logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
optimizer = radam.RiemannianAdam(
filter(lambda p: p.requires_grad, query2manifold.parameters()),
lr=current_learning_rate,
)
#optimizer = torch.optim.Adam(
# filter(lambda p: p.requires_grad, query2manifold.parameters()),
# lr=current_learning_rate
#)
warm_up_steps = warm_up_steps * 3
if step % args.save_checkpoint_steps == 0:
save_variable_list = {
'step': step,
'current_learning_rate': current_learning_rate,
'warm_up_steps': warm_up_steps
}
save_model(query2manifold, optimizer, save_variable_list, args)
if step % args.log_steps == 0:
metrics = {}
for metric in training_logs[0].keys():
if metric == 'inter_loss':
continue
metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
inter_loss_sum = 0.
inter_loss_num = 0.
for log in training_logs:
if 'inter_loss' in log:
inter_loss_sum += log['inter_loss']
inter_loss_num += 1
if inter_loss_num != 0:
metrics['inter_loss'] = inter_loss_sum / inter_loss_num
log_metrics('Training average', step, metrics)
training_logs = []
if args.do_valid and step % args.valid_steps == 0:
logging.info('Evaluating on Valid Dataset...')
evaluate_val()
save_variable_list = {
'step': step,
'current_learning_rate': current_learning_rate,
'warm_up_steps': warm_up_steps
}
save_model(query2manifold, optimizer, save_variable_list, args)
try:
print (step)
except:
step = 0
if args.do_valid:
logging.info('Evaluating on Valid Dataset...')
evaluate_val()
if args.do_test:
logging.info('Evaluating on Test Dataset...')
evaluate_test()
if args.evaluate_train:
logging.info('Evaluating on Training Dataset...')
evaluate_train()
print ('Training finished!!')
logging.info("training finished!!")