in python/dglke/dataloader/sampler.py [0:0]
def SoftRelationPartition(edges, n, has_importance=False, threshold=0.05):
    """Partition a list of edges into n parts according to their relation types.

    Any relation whose edge count exceeds the (capacity-capped) threshold is
    treated as LARGE and its edges are spread evenly across all partitions.
    Any other relation is SMALL and all of its edges are placed into the
    single partition that currently holds the fewest edges.

    Algo:
        For r in relations (processed in decreasing order of edge count):
            if r.size() > threshold
                Evenly divide edges of r into n parts and put one chunk in
                each partition.
            else
                Find the partition with fewest edges, and put all edges of r
                into that partition.

    Note: the arrays inside ``edges`` are permuted IN PLACE so that the edges
    of each partition become contiguous; the returned index arrays refer to
    the permuted order.

    Parameters
    ----------
    edges : (heads, rels, tails) triple of np.array, or
            (heads, rels, tails, e_impts) when has_importance is True
        Edge list to partition
    n : int
        Number of partitions
    has_importance : bool
        Whether ``edges`` carries a fourth array of per-edge importance
        weights that must be permuted along with the triples.
        Default: False
    threshold : float
        The threshold of whether a relation is LARGE or SMALL
        Default: 5%

    Returns
    -------
    List of np.array
        Edge index ranges of each partition (into the permuted arrays)
    List of np.array
        Relation types assigned to each partition
    bool
        Whether there exist relations belonging to multiple partitions
    np.array
        The relation types that span multiple partitions
    """
    if has_importance:
        heads, rels, tails, e_impts = edges
    else:
        heads, rels, tails = edges
    print('relation partition {} edges into {} parts'.format(len(heads), n))
    uniq, cnts = np.unique(rels, return_counts=True)
    idx = np.flip(np.argsort(cnts))
    cnts = cnts[idx]
    uniq = uniq[idx]
    # Counts are now sorted in non-increasing order. BUGFIX: the original
    # strict check (cnts[0] > cnts[-1]) spuriously failed whenever every
    # relation had the same number of edges; non-strict is the correct
    # sanity condition.
    assert cnts[0] >= cnts[-1]
    edge_cnts = np.zeros(shape=(n,), dtype=np.int64)  # edges assigned per partition
    rel_cnts = np.zeros(shape=(n,), dtype=np.int64)   # relations assigned per partition
    rel_dict = {}        # rel type -> list of [partition_idx, remaining edge budget]
    rel_parts = []       # rel types placed in each partition
    cross_rel_part = []  # rel types that span multiple partitions
    for _ in range(n):
        rel_parts.append([])
    large_threshold = int(len(rels) * threshold)
    capacity_per_partition = int(len(rels) / n)
    # ensure any relation larger than the partition capacity will be split
    large_threshold = capacity_per_partition if capacity_per_partition < large_threshold \
                        else large_threshold
    num_cross_part = 0
    for i in range(len(cnts)):
        cnt = cnts[i]
        r = uniq[i]
        r_parts = []
        if cnt > large_threshold:
            # LARGE relation: spread its edges evenly over all partitions.
            avg_part_cnt = (cnt // n) + 1
            num_cross_part += 1
            for j in range(n):
                part_cnt = avg_part_cnt if cnt > avg_part_cnt else cnt
                # NOTE(review): if cnt is exhausted before the last
                # partition, part_cnt is 0 here and a zero-budget entry is
                # still recorded; it is never consumed by the distribution
                # loop below but does inflate rel_parts/rel_cnts for that
                # partition — confirm this is intended before changing.
                r_parts.append([j, part_cnt])
                rel_parts[j].append(r)
                edge_cnts[j] += part_cnt
                rel_cnts[j] += 1
                cnt -= part_cnt
            cross_rel_part.append(r)
        else:
            # SMALL relation: put all of its edges in the emptiest partition.
            idx = np.argmin(edge_cnts)
            r_parts.append([idx, cnt])
            rel_parts[idx].append(r)
            edge_cnts[idx] += cnt
            rel_cnts[idx] += 1
        rel_dict[r] = r_parts
    for i, edge_cnt in enumerate(edge_cnts):
        print('part {} has {} edges and {} relations'.format(i, edge_cnt, rel_cnts[i]))
    print('{}/{} duplicated relation across partitions'.format(num_cross_part, len(cnts)))
    parts = []
    for i in range(n):
        parts.append([])
        rel_parts[i] = np.array(rel_parts[i])
    # Walk the original edge list and assign each edge to the next partition
    # with remaining budget for its relation type.
    for i, r in enumerate(rels):
        r_part = rel_dict[r][0]
        part_idx = r_part[0]
        cnt = r_part[1]
        parts[part_idx].append(i)
        cnt -= 1
        if cnt == 0:
            rel_dict[r].pop(0)
        else:
            rel_dict[r][0][1] = cnt
    for i, part in enumerate(parts):
        parts[i] = np.array(part, dtype=np.int64)
    # Permute the input arrays in place so each partition's edges are
    # contiguous, then replace each partition's index list with the
    # corresponding contiguous range.
    shuffle_idx = np.concatenate(parts)
    heads[:] = heads[shuffle_idx]
    rels[:] = rels[shuffle_idx]
    tails[:] = tails[shuffle_idx]
    if has_importance:
        e_impts[:] = e_impts[shuffle_idx]
    off = 0
    for i, part in enumerate(parts):
        parts[i] = np.arange(off, off + len(part))
        off += len(part)
    cross_rel_part = np.array(cross_rel_part)
    return parts, rel_parts, num_cross_part > 0, cross_rel_part