in clutrr/main.py [0:0]
def generate(self, choice, args, num_rows=0, data_type='train', multi=False, split=None):
    """
    Choose the task and the relation length, then generate rows for it.

    A choice string has the form ``"<task>.<relation_length>"``. ``<task>``
    selects the method ``self.task_<task>``, which receives the (copied)
    args and returns the args to use.

    :param choice: a single ``"<task>.<relation_length>"`` string, or, when
        ``multi`` is True, an iterable of such strings.
    :param args: argument namespace; it is deep-copied first, so the
        caller's object is not mutated here.
    :param num_rows: stored on the copied args as ``args.num_rows``.
    :param data_type: stored on the copied args as ``args.data_type`` and
        logged (default ``'train'``).
    :param multi: when True, generate rows for every entry of ``choice``,
        concatenating the rows and merging the puzzles.
    :param split: forwarded to ``generate_rows`` in single-choice mode only.
    :return: ``((columns, rows, puzzles), args)``, where ``args`` are the
        arguments actually used (in multi mode, those after the last choice).
    :raises NotImplementedError: if no ``task_<task>`` method exists.
    """
    args = copy.deepcopy(args)
    args.num_rows = num_rows
    args.data_type = data_type

    if not multi:
        task, relation_length = choice.split('.')
        task_name = 'task_{}'.format(task)
        logger.info("mode : {}, task : {}, rel_length : {}".format(data_type, task_name, relation_length))
        args = self._resolve_task(task_name, choice)(args)
        rel_len = int(relation_length)
        args.relation_length = rel_len
        store = Store(args)
        columns, rows, all_puzzles, train_patterns, test_patterns = generate_rows(
            args, store, task_name + '.{}'.format(relation_length),
            split=split, prev_patterns=self.unique_patterns)
        # self.unique_patterns is handed to generate_rows as prev_patterns on
        # later single-choice calls; record this relation length's patterns.
        self.unique_patterns[rel_len] = {
            'train': train_patterns,
            'test': test_patterns
        }
        return (columns, rows, all_puzzles), args

    rows = []
    columns = []
    puzzles = {}
    for ch in choice:
        task, relation_length = ch.split('.')
        task_name = 'task_{}'.format(task)
        logger.info("task : {}, rel_length : {}".format(task_name, relation_length))
        # args is threaded through: each task method receives the args
        # returned by the previous one.
        args = self._resolve_task(task_name, ch)(args)
        args.relation_length = int(relation_length)
        store = Store(args)
        # The single-choice branch shows generate_rows returning five values
        # (columns, rows, puzzles, train_patterns, test_patterns); the old
        # three-name unpack would fail on that. Star-unpacking keeps the
        # first three regardless of how many values are returned.
        columns, task_rows, task_puzzles, *_ = generate_rows(
            args, store, task_name + '.{}'.format(relation_length))
        rows.extend(task_rows)
        # Later choices overwrite earlier puzzles that share a key.
        puzzles.update(task_puzzles)
    return (columns, rows, puzzles), args


def _resolve_task(self, task_name, choice):
    """
    Return the bound ``task_*`` method called ``task_name``.

    :param task_name: attribute name such as ``'task_1'``.
    :param choice: the choice string, used only in the error message.
    :raises NotImplementedError: if this object has no such attribute.
    """
    task_method = getattr(self, task_name, None)
    if task_method is None:
        raise NotImplementedError("Task {} not implemented".format(choice))
    return task_method