def _infer_part()

in lib/misc.py [0:0]


def _infer_part(part, concrete_dim, known, index, full_shape):
    if type(part) is int:
        return part
    assert isinstance(part, list), part
    lits = []
    syms = []
    for term in part:
        if type(term) is int:
            lits.append(term)
        elif type(term) is str:
            syms.append(term)
        else:
            raise TypeError(f"got {type(term)} but expected int or str")
    int_part = 1
    for x in lits:
        int_part *= x
    if len(syms) == 0:
        return int_part
    elif len(syms) == 1 and concrete_dim is not None:
        assert concrete_dim % int_part == 0, f"{concrete_dim} % {int_part} != 0 (at index {index}, full shape is {full_shape})"
        v = concrete_dim // int_part
        if syms[0] in known:
            assert (
                known[syms[0]] == v
            ), f"known value for {syms[0]} is {known[syms[0]]} but found value {v} at index {index} (full shape is {full_shape})"
        else:
            known[syms[0]] = v
        return concrete_dim
    else:
        for i in range(len(syms)):
            if syms[i] in known:
                syms[i] = known[syms[i]]
            else:
                try:
                    syms[i] = int(syms[i])
                except ValueError:
                    pass
        return lits + syms