def dim_change_forward()

in tinynn/graph/modifier.py [0:0]


    def dim_change_forward(self, center, tensor, dim_changes_i: typing.List, dim_transform, tensor_constraint):
        """Propagate a dimension change through this node, treating it as an identity op.

        The input dimension info is recorded via ``self.dim_changes_info.update_i``,
        ``tensor`` is copied into every output tensor, the output dimension info is
        recorded via ``update_o``, and the call is forwarded to every modifier that
        consumes each output.

        Args:
            center: Forwarded unchanged to ``update_i`` / ``update_o`` and to the
                downstream ``dim_change_forward`` calls.
            tensor: The input tensor whose dimension change is being propagated.
            dim_changes_i: Indices of the changed dimensions of ``tensor``.
            dim_transform: Forwarded unchanged to ``update_i`` and downstream.
            tensor_constraint: Passed to ``update_i``; the value returned by
                ``update_i`` is then indexed by dimension and forwarded downstream.
                Presumably a dict keyed by dimension index -- TODO confirm.
                It is modified in place when an output has an extra leading
                dimension of size 1.

        Returns:
            ``True`` when propagation stops at this node (leaf node, ``"data"``
            node, or constant node); otherwise ``None`` after forwarding.
        """
        # Leaf nodes require no additional computation
        if len(self.next_tensors()) == 0:
            pre_tensors = self.pre_tensors()
            if len(pre_tensors) > 1:
                for dim_change_i in dim_changes_i:
                    # All inputs must agree on the size of every changed dimension.
                    dims = {t.shape[dim_change_i] for t in pre_tensors}
                    if len(dims) > 1:
                        log.warning(f"Skip the {self.unique_name()} because the input shape is inconsistent")
                        return True
            self.dim_changes_info.update_i(center, tensor, dim_changes_i, dim_transform)
            return True

        if self.node.kind() == "data":
            return True

        # Skip constant node
        if self.constant_node:
            return True

        # The default implementation is regarded as Identity()
        # Directly inheriting the dim_constraint of the previous layer, reducing the amount of calculation
        tensor_constraint = self.dim_changes_info.update_i(
            center, tensor, dim_changes_i, dim_transform, tensor_constraint=tensor_constraint
        )

        for tensor_o in self.next_tensors():
            # An output that is the very same object as the input needs no copy.
            if id(tensor) == id(tensor_o):
                continue

            try:
                tensor_o.copy_(tensor.clone())
            except Exception:
                log.error(
                    f"error modifier = {self.unique_name()}, type = {type(self.module())}, kind = {self.node.kind()}"
                )
                raise

        for tensor_o in self.next_tensors():
            # Case [1, c0, c1] + [c0, c1](center_node) -> [1, c0, c1], to keep dim_change_o keep consistent.
            if len(tensor_o.shape) > len(tensor.shape) and tensor_o.shape[0] == 1:
                omitted_dim_len = 1
                # Remove every affected entry before re-inserting it under its shifted key.
                # Shifting one key at a time would overwrite the entry of an adjacent
                # dimension (e.g. dims [0, 1]) before that entry has been moved.
                moved = [(dim_ + omitted_dim_len, tensor_constraint.pop(dim_)) for dim_ in dim_changes_i]
                for new_dim, constraint in moved:
                    tensor_constraint[new_dim] = constraint
                # NOTE(review): the shifted dim_changes_i / tensor_constraint carry over to
                # later iterations of this loop; confirm that is intended for nodes with
                # several outputs.
                dim_changes_i = [i + omitted_dim_len for i in dim_changes_i]

            self.dim_changes_info.update_o(center, tensor_o, dim_changes_i)

            for m in self.next_modifiers(tensor_o):
                # The identity() operator does not change the constraint, so it can directly pass its own
                # constraints to reduce the calculation of the next layer
                m.dim_change_forward(center, tensor_o, dim_changes_i, dim_transform, tensor_constraint)