in utils/utils.py [0:0]
def decompose(index, shape, stride=None):
'''
This function solve the math problem below:
There is an equation:
index = sum(idx[i] * stride[i])
And given the value of index, stride.
Return the idx.
This function will used to get the pp/dp/pp_rank
from group_index and rank_in_group.
'''
if stride is None:
stride = prefix_product(shape)
idx = [(index // d) % s for s, d in zip(shape, stride)]
# stride is a prefix_product result. And the value of stride[-1]
# is not used.
assert (
sum([x * y for x, y in zip(idx, stride[:-1])]) == index
), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
return idx