in src/datasets.py [0:0]
def __init__(self, opt, data_path=None):
    """Load the train/valid/test splits of a dataset and derive its sizes.

    Each split is read from ``<root>/<split>.pickle`` if that file exists,
    otherwise from ``<root>/<split>.npy``. Entity and predicate counts are
    derived from the training split only. In each row, columns 0 and 2 hold
    entity ids and column 1 holds the predicate id.

    Args:
        opt: Options mapping. This method reads the keys ``'dataset'``,
            ``'device'`` and ``'reciprocal'``.
        data_path: Directory containing the split files. When ``None``,
            ``DATA_PATH / opt['dataset']`` is used instead.

    Raises:
        FileNotFoundError: If a split has neither a ``.pickle`` nor a
            ``.npy`` file under the dataset root.

    Side effects:
        Prints dataset statistics to stdout. Relies on ``get_shape``,
        ``get_examples`` and ``get_split`` being defined on this class.
    """
    self.opt = opt
    self.name = opt['dataset']
    self.device = opt['device']
    self.reciprocal = opt['reciprocal']
    # Default to the module-level DATA_PATH directory, keyed by dataset name.
    if data_path is None:
        self.root = DATA_PATH / self.name
    else:
        self.root = Path(data_path)
    self.data = {}
    self.splits = ['train', 'valid', 'test']
    for f in self.splits:
        # Prefer the pickled split; fall back to a NumPy .npy file. If the
        # .npy file is also missing, open() raises FileNotFoundError.
        p = str(self.root / (f + '.pickle'))
        if os.path.isfile(p):
            # NOTE(review): pickle.load can execute arbitrary code; only use
            # this with trusted dataset files.
            with open(p, 'rb') as in_file:
                self.data[f] = pickle.load(in_file)
        else:
            p = str(self.root / (f + '.npy'))
            with open(p, 'rb') as in_file:
                self.data[f] = np.load(in_file)
    # Sizes are "max id + 1", so ids are treated as 0-based indices.
    # Only the training split is scanned. An id that appears only in
    # valid/test would fall outside these counts. TODO: confirm the datasets
    # guarantee train covers every id.
    maxis = np.max(self.data['train'], axis=0)
    self.n_entities = int(max(maxis[0], maxis[2]) + 1)
    self.n_predicates = int(maxis[1] + 1)
    # Decided by dataset name. The trailing comment shows an alternative
    # shape-based test that is not used.
    self.include_type = self.name in ['ogbl-biokg'] # self.data['train'].shape[1] == 5
    # Presumably the batch size for valid/test evaluation (much smaller for
    # ogbl-wikikg2). Confirm against callers.
    self.bsz_vt = 16 if self.name in ['ogbl-wikikg2'] else 1000
    # With reciprocal mode the predicate space is doubled, presumably to add
    # an inverse of each relation. Confirm against the code that builds the
    # reciprocal triples.
    if self.reciprocal:
        self.n_predicates *= 2
    # Optional auxiliary files. self.to_skip and self.meta_info are only
    # assigned when the corresponding file exists, so either attribute may be
    # absent on the instance.
    if os.path.isfile(str(self.root / 'to_skip.pickle')):
        print('Loading to_skip file ...')
        with open(str(self.root / f'to_skip.pickle'), 'rb') as inp_f:
            self.to_skip = pickle.load(inp_f) # {'lhs': {(11, 3): [1, 3, 0, 4, 5, 19]}}
    if os.path.isfile(str(self.root / 'meta_info.pickle')):
        print('Loading meta_info file ...')
        with open(str(self.root / f'meta_info.pickle'), 'rb') as inp_f:
            self.meta_info = pickle.load(inp_f)
    # Report split sizes, both as absolute counts and as fractions of the total.
    print('{} Dataset Stat: {}'.format(self.name, self.get_shape()))
    n_train = len(self.get_examples('train'))
    n_valid = len(self.get_examples('valid'))
    n_test = len(self.get_examples('test'))
    print('Train/Valid/Test {}/{}/{}'.format(n_train, n_valid, n_test))
    tot = 1.0 * (n_train + n_valid + n_test)
    print('Train/Valid/Test {:.3f}/{:.3f}/{:.3f}'.format(n_train / tot,
                                                         n_valid / tot,
                                                         n_test / tot))
    # Keep the train/valid splits as torch tensors. torch.from_numpy shares
    # memory with the NumPy array returned by get_split.
    self.examples_train = torch.from_numpy(self.get_split(split='train'))
    self.examples_valid = torch.from_numpy(self.get_split(split='valid'))