def batch_norm()

in coremltools/converters/mil/frontend/torch/ops.py


def batch_norm(context, node):
    inputs = _get_inputs(context, node, expected=9)
    # inputs skipped:
    #   float momentum (6)
    #   bool cudnn_enabled (8)
    input_rank = inputs[0].rank
    if input_rank < 2 or input_rank > 5:
        raise ValueError(
            f"BatchNorm: invalid input rank {input_rank} encountered during translation in the torch frontend (expected rank 2 to 5)."
        )

    _input = inputs[0]
    weight = inputs[1]
    bias = inputs[2]
    running_mean = inputs[3]
    running_var = inputs[4]
    training = inputs[5].val
    eps = inputs[7]
    name = node.name

    # If training = True, the mean and variance of the current batch of data are used to normalize the input data.
    # If training = False, data statistics running_mean and running_var are used instead.
    # Note that, even in evaluation mode (after calling model.eval()), the training parameter can still be True;
    # it simply selects the batch-statistics computation described above.
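    # (In PyTorch, a BatchNorm layer built with track_running_stats=False keeps no running
    # statistics, so it reaches this op with training=True even under model.eval().)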

    # helper functions for the different types of batch norm
    def _add_batch_norm_dynamic():
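        # decomposes batch norm into elementwise ops:
        #   y = (x - mean) / sqrt(var + eps) * gamma + beta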
        x = _input
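        # build a broadcastable shape (1, C, 1, ...) so the per-channel stats
        # line up with axis 1; use -1 for C when it is symbolic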
        shape = [1] * x.rank
        shape[1] = -1 if any_symbolic(running_mean.shape) else running_mean.shape[0]

        if training:
            axes = [axis for axis in range(x.rank) if axis != 1]
            mean = mb.reduce_mean(x=x, axes=axes, keep_dims=True)
            num = mb.sub(x=x, y=mean)
            square = mb.mul(x=num, y=num)
            variance = mb.reduce_mean(x=square, axes=axes, keep_dims=True)
        else:
            mean = mb.reshape(x=running_mean, shape=shape)
            num = mb.sub(x=x, y=mean)
            variance = mb.reshape(x=running_var, shape=shape)

        variance_add_epsilon = mb.add(x=variance, y=eps)
        sqrt = mb.sqrt(x=variance_add_epsilon)

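        # only the op producing this node's final output may carry node.name;
        # if gamma/beta follow, the division below is an intermediate op and
        # gets a "_div" suffix instead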
        has_weight_bias = weight is not None and bias is not None
        name = node.name + "_div" if has_weight_bias else node.name

        x = mb.real_div(x=num, y=sqrt, name=name)

        if not has_weight_bias:
            context.add(x)
            return

        weight_reshape = mb.reshape(x=weight, shape=shape)
        bias_reshape = mb.reshape(x=bias, shape=shape)

        x = mb.mul(x=x, y=weight_reshape)
        x = mb.add(x=x, y=bias_reshape, name=node.name)

        context.add(x)

    def _add_batch_norm_1d():
        # first expand the rank-2 tensor to rank 3, then call the standard mb.batch_norm
        x = mb.expand_dims(x=_input, axes=[-1], name=node.name + "_rank2_expansion")
        name = node.name + "_batch_norm_1d"
        batch_norm = mb.batch_norm(
            x=x,
            mean=running_mean,
            variance=running_var,
            gamma=weight,
            beta=bias,
            epsilon=eps,
            name=name,
        )
        batch_norm = mb.squeeze(x=batch_norm, name=node.name, axes=[-1])
        context.add(batch_norm)

    def _add_batch_norm_2d():
        batch_norm = mb.batch_norm(
            x=_input,
            mean=running_mean,
            variance=running_var,
            gamma=weight,
            beta=bias,
            epsilon=eps,
            name=name,
        )
        context.add(batch_norm)

    def _add_batch_norm_3d():
        # if more than one of the input dims is symbolic, batch norm is computed
        # by decomposing it into elementwise ops
        # if at most one dim is symbolic, we reshape the tensor to a 4d tensor
        # and call the standard mb.batch_norm
        batch_size, channel, height, width, depth = _input.shape
        assert not is_symbolic(channel), "Channel dimension must be known for batchnorm layer."

        symbolic_num = sum([is_symbolic(x) for x in _input.shape])

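        # a reshape can express at most one unknown dim as -1, so with two or
        # more symbolic dims the 4d round trip below is not possible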
        if symbolic_num > 1:
            weight_expand = mb.expand_dims(x=weight, axes=[0,2,3,4], name=name + "_expand_weight_3d")
            bias_expand = mb.expand_dims(x=bias, axes=[0,2,3,4], name=name + "_expand_bias_3d")
            running_mean_expand = mb.expand_dims(x=running_mean, axes=[0,2,3,4], name=name + "_expand_mean_3d")
            running_var_expand = mb.expand_dims(x=running_var, axes=[0,2,3,4], name=name + "_expand_var_3d")

            # compute batch norm 3d by decomposing it into elementwise operations
            numerator = mb.sub(x=_input, y=running_mean_expand)
            denominator = mb.add(x=running_var_expand, y=eps)
            denominator = mb.sqrt(x=denominator)
            x = mb.real_div(x=numerator, y=denominator)
            x = mb.mul(x=x, y=weight_expand)
            batch_norm = mb.add(x=x, y=bias_expand, name=name)

        else:
            is_batch_symbolic = is_symbolic(batch_size)
            is_height_symbolic = is_symbolic(height)
            is_width_symbolic = is_symbolic(width)
            is_depth_symbolic = is_symbolic(depth)

            if is_batch_symbolic:
                shape1 = [-1, channel, height*width, depth]
                shape2 = [-1, channel, height, width, depth]

            elif is_height_symbolic:
                shape1 = [batch_size, channel, -1, width*depth]
                shape2 = [batch_size, channel, -1, width, depth]

            elif is_width_symbolic:
                shape1 = [batch_size, channel, -1, height*depth]
                shape2 = [batch_size, channel, height, -1, depth]

            elif is_depth_symbolic:
                shape1 = [batch_size, channel, height*width, -1]
                shape2 = [batch_size, channel, height, width, -1]

            else:
                shape1 = [batch_size, channel, height*width, depth]
                shape2 = [batch_size, channel, height, width, depth]

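            # merging two trailing dims (and splitting them back afterwards) is
            # safe because batch_norm only normalizes along axis 1, which both
            # reshapes leave untouched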
            reshape_4d = mb.reshape(x=_input, shape=shape1, name=name + "_reshape_4d")
            batch_norm = mb.batch_norm(
                x=reshape_4d,
                mean=running_mean,
                variance=running_var,
                gamma=weight,
                beta=bias,
                epsilon=eps,
                name=name + "_batch_norm_4d",
            )
            batch_norm = mb.reshape(x=batch_norm, shape=shape2, name=name)

        context.add(batch_norm)

    is_batch_norm_1d = input_rank == 2
    is_batch_norm_2d = (input_rank == 3 or input_rank == 4)
    is_batch_norm_3d = input_rank == 5

    if training or running_mean.val is None or running_var.val is None or weight is None or bias is None:
        _add_batch_norm_dynamic()
    elif is_batch_norm_1d:
        _add_batch_norm_1d()
    elif is_batch_norm_2d:
        _add_batch_norm_2d()
    elif is_batch_norm_3d:
        _add_batch_norm_3d()
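
A minimal usage sketch (not part of ops.py), assuming recent torch and coremltools; the module, input shapes, and range bounds below are illustrative. It traces an eval-mode BatchNorm3d module and converts it with one flexible spatial dimension, which makes exactly one input dim symbolic and therefore routes through the reshape-to-4d branch of _add_batch_norm_3d above.

import torch
import coremltools as ct

model = torch.nn.BatchNorm3d(num_features=3).eval()
example = torch.rand(1, 3, 8, 16, 16)  # (N, C, D, H, W)
traced = torch.jit.trace(model, example)

# One RangeDim makes symbolic_num == 1, so the converter reshapes to 4d
# and emits a single mb.batch_norm instead of the elementwise decomposition.
mlmodel = ct.convert(
    traced,
    inputs=[
        ct.TensorType(
            name="x",
            shape=ct.Shape(shape=(1, 3, ct.RangeDim(1, 64), 16, 16)),
        )
    ],
)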