in python/dglke/dataloader/sampler.py [0:0]
def BalancedRelationPartition(edges, n, has_importance=False):
    """Partition edges by relation so that each partition has roughly the
    same number of edges and relations.

    Algo:
    For r in relations (most frequent first):
      Find partition with fewest edges
      if r.size() > num_of empty_slot
         put edges of r into this partition to fill the partition,
         find next partition with fewest edges to put r in.
      else
         put edges of r into this partition.

    The edge arrays are reordered **in place** so that the edges of each
    partition are stored contiguously (partition 0 first, then partition 1,
    ...). Within a partition, edges keep their original relative order.

    Parameters
    ----------
    edges : (heads, rels, tails) or (heads, rels, tails, e_impts) tuple
        Edge list to partition. Each element must support fancy indexing and
        slice assignment (e.g. np.ndarray); they are modified in place.
    n : int
        number of partitions
    has_importance : bool
        If True, ``edges`` carries a fourth array of edge importances that is
        reordered together with the edges.

    Returns
    -------
    List of np.array
        For each partition, the positions of its edges in the reordered
        edge arrays (a contiguous ``np.arange`` range).
    List of np.array
        Relation ids present in each partition.
    bool
        Whether some relation was split across multiple partitions.
    """
    if has_importance:
        heads, rels, tails, e_impts = edges
    else:
        heads, rels, tails = edges
    print('relation partition {} edges into {} parts'.format(len(heads), n))
    uniq, cnts = np.unique(rels, return_counts=True)
    # A stable sort groups edge indices by relation: the edges of uniq[j] are
    # edge_order[starts[j]:starts[j] + cnts[j]], in their original order.
    edge_order = np.argsort(rels, kind='stable')
    starts = np.cumsum(cnts) - cnts
    # Visit relations from most to least frequent (greedy bin packing).
    # No assertion on the counts: equal frequencies (e.g. a single relation)
    # are perfectly valid input.
    rel_order = np.flip(np.argsort(cnts))
    edge_cnts = np.zeros(shape=(n,), dtype=np.int64)
    rel_cnts = np.zeros(shape=(n,), dtype=np.int64)
    rel_parts = [[] for _ in range(n)]
    # Partition id of every edge, indexed by original edge position.
    edge_part = np.zeros(shape=(len(rels),), dtype=np.int64)
    max_edges = (len(rels) // n) + 1
    num_cross_part = 0
    for j in rel_order:
        r = uniq[j]
        cnt = cnts[j]
        pos = starts[j]
        while cnt > 0:
            part_idx = np.argmin(edge_cnts)
            # Take as many edges as fit. Because n * max_edges > len(rels),
            # the emptiest partition always has at least one free slot.
            take = min(cnt, max_edges - edge_cnts[part_idx])
            if take < cnt:
                # The relation overflows this partition and will be split.
                num_cross_part += 1
            edge_part[edge_order[pos:pos + take]] = part_idx
            rel_parts[part_idx].append(r)
            edge_cnts[part_idx] += take
            rel_cnts[part_idx] += 1
            pos += take
            cnt -= take
    for i, edge_cnt in enumerate(edge_cnts):
        print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))
    print('{}/{} duplicated relation across partitions'.format(num_cross_part, len(cnts)))
    rel_parts = [np.array(rel_part) for rel_part in rel_parts]
    # A stable sort by partition id groups the edges by partition while
    # preserving their original order inside each partition.
    shuffle_idx = np.argsort(edge_part, kind='stable')
    heads[:] = heads[shuffle_idx]
    rels[:] = rels[shuffle_idx]
    tails[:] = tails[shuffle_idx]
    if has_importance:
        e_impts[:] = e_impts[shuffle_idx]
    # After reordering, partition i occupies [offsets[i], offsets[i + 1]).
    offsets = np.concatenate(([0], np.cumsum(edge_cnts)))
    parts = [np.arange(offsets[i], offsets[i + 1]) for i in range(n)]
    return parts, rel_parts, num_cross_part > 0