in lib/misc.py [0:0]
def _infer_part(part, concrete_dim, known, index, full_shape):
if type(part) is int:
return part
assert isinstance(part, list), part
lits = []
syms = []
for term in part:
if type(term) is int:
lits.append(term)
elif type(term) is str:
syms.append(term)
else:
raise TypeError(f"got {type(term)} but expected int or str")
int_part = 1
for x in lits:
int_part *= x
if len(syms) == 0:
return int_part
elif len(syms) == 1 and concrete_dim is not None:
assert concrete_dim % int_part == 0, f"{concrete_dim} % {int_part} != 0 (at index {index}, full shape is {full_shape})"
v = concrete_dim // int_part
if syms[0] in known:
assert (
known[syms[0]] == v
), f"known value for {syms[0]} is {known[syms[0]]} but found value {v} at index {index} (full shape is {full_shape})"
else:
known[syms[0]] = v
return concrete_dim
else:
for i in range(len(syms)):
if syms[i] in known:
syms[i] = known[syms[i]]
else:
try:
syms[i] = int(syms[i])
except ValueError:
pass
return lits + syms