in coremltools/converters/mil/frontend/torch/ops.py [0:0]
def batch_norm(context, node):
inputs = _get_inputs(context, node, expected=9)
    # inputs skipped:
    #   float momentum (6)
    #   bool cudnn_enabled (8)
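    # the remaining indices follow the aten::batch_norm signature:
    #   0: input, 1: weight, 2: bias, 3: running_mean, 4: running_var, 5: training, 7: eps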
input_rank = inputs[0].rank
if input_rank < 2 or input_rank > 5:
raise ValueError("BatchNorm: Encountered invalid input rank during translation in torch frontend.")
_input = inputs[0]
weight = inputs[1]
bias = inputs[2]
running_mean = inputs[3]
running_var = inputs[4]
training = inputs[5].val
eps = inputs[7]
name = node.name
# If training = True, the mean and variance of the current batch of data are used to normalize the input data.
# If training = False, data statistics running_mean and running_var are used instead.
    # Note that even in evaluation mode (i.e. after calling model.eval()), the training parameter
    # can still be True; it simply selects which of the two computations above is performed.
    # helper functions for the different batch norm variants
def _add_batch_norm_dynamic():
x = _input
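        # per-channel shape that broadcasts against x, e.g. (1, C, 1, 1) for a rank-4 input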
shape = [1] * x.rank
shape[1] = -1 if any_symbolic(running_mean.shape) else running_mean.shape[0]
if training:
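            # compute batch statistics over every axis except channel (axis 1); the
            # variance is the biased estimate, matching PyTorch's training-mode normalization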
axes = [axis for axis in range(x.rank) if axis != 1]
mean = mb.reduce_mean(x=x, axes=axes, keep_dims=True)
num = mb.sub(x=x, y=mean)
square = mb.mul(x=num, y=num)
variance = mb.reduce_mean(x=square, axes=axes, keep_dims=True)
else:
mean = mb.reshape(x=running_mean, shape=shape)
num = mb.sub(x=x, y=mean)
variance = mb.reshape(x=running_var, shape=shape)
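        # normalize: x_hat = (x - mean) / sqrt(variance + eps)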
variance_add_epsilon = mb.add(x=variance, y=eps)
sqrt = mb.sqrt(x=variance_add_epsilon)
has_weight_bias = weight is not None and bias is not None
name = node.name + "_div" if has_weight_bias else node.name
x = mb.real_div(x=num, y=sqrt, name=name)
if not has_weight_bias:
context.add(x)
return
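        # apply the affine part: y = x_hat * weight + bias, broadcast along the channel axis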
weight_reshape = mb.reshape(x=weight, shape=shape)
bias_reshape = mb.reshape(x=bias, shape=shape)
x = mb.mul(x=x, y=weight_reshape)
x = mb.add(x=x, y=bias_reshape, name=node.name)
context.add(x)
def _add_batch_norm_1d():
        # first expand the rank-2 tensor to rank 3, then call the standard mb.batch_norm
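        # (mb.batch_norm requires a spatial input, so a trailing length-1 axis is
        # added here and squeezed away after the op)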
x = mb.expand_dims(x=_input, axes=[-1], name=node.name + "_rank2_expansion")
name = node.name + "_batch_norm_1d"
batch_norm = mb.batch_norm(
x=x,
mean=running_mean,
variance=running_var,
gamma=weight,
beta=bias,
epsilon=eps,
name=name,
)
batch_norm = mb.squeeze(x=batch_norm, name=node.name, axes=[-1])
context.add(batch_norm)
def _add_batch_norm_2d():
batch_norm = mb.batch_norm(
x=_input,
mean=running_mean,
variance=running_var,
gamma=weight,
beta=bias,
epsilon=eps,
name=name,
)
context.add(batch_norm)
def _add_batch_norm_3d():
        # if the input shape has more than one symbolic dimension, batch norm is computed
        # by decomposing it into elementwise ops;
        # otherwise the shape is known well enough at compile time to reshape the tensor
        # to a rank-4 tensor and call the standard mb.batch_norm
batch_size, channel, height, width, depth = _input.shape
assert not is_symbolic(channel), "Channel dimension must be known for batchnorm layer."
        symbolic_num = sum(is_symbolic(s) for s in _input.shape)
if symbolic_num > 1:
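            # with more than one symbolic dimension, a single -1 placeholder cannot
            # express the rank-4 reshape used below, so fall back to elementwise ops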
            weight_expand = mb.expand_dims(x=weight, axes=[0, 2, 3, 4], name=name + "_expand_weight_3d")
            bias_expand = mb.expand_dims(x=bias, axes=[0, 2, 3, 4], name=name + "_expand_bias_3d")
            running_mean_expand = mb.expand_dims(x=running_mean, axes=[0, 2, 3, 4], name=name + "_expand_mean_3d")
            running_var_expand = mb.expand_dims(x=running_var, axes=[0, 2, 3, 4], name=name + "_expand_var_3d")
# compute batch norm 3d by decomposing it into elementwise operations
numerator = mb.sub(x=_input, y=running_mean_expand)
denominator = mb.add(x=running_var_expand, y=eps)
denominator = mb.sqrt(x=denominator)
x = mb.real_div(x=numerator, y=denominator)
x = mb.mul(x=x, y=weight_expand)
            batch_norm = mb.add(x=x, y=bias_expand, name=name)
else:
            is_batch_symbolic = is_symbolic(batch_size)
is_height_symbolic = is_symbolic(height)
is_width_symbolic = is_symbolic(width)
is_depth_symbolic = is_symbolic(depth)
            if is_batch_symbolic:
shape1 = [-1, channel, height*width, depth]
shape2 = [-1, channel, height, width, depth]
elif is_height_symbolic:
shape1 = [batch_size, channel, -1, width*depth]
shape2 = [batch_size, channel, -1, width, depth]
elif is_width_symbolic:
shape1 = [batch_size, channel, -1, height*depth]
shape2 = [batch_size, channel, height, -1, depth]
elif is_depth_symbolic:
shape1 = [batch_size, channel, height*width, -1]
shape2 = [batch_size, channel, height, width, -1]
else:
shape1 = [batch_size, channel, height*width, depth]
shape2 = [batch_size, channel, height, width, depth]
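            # mb.batch_norm only needs batch at axis 0 and channel at axis 1, so the
            # H*W*D spatial block may be reshaped into any two-axis factorization
            # (shape1) and restored afterwards (shape2), e.g. for the static case
            # (B, C, H, W, D) = (1, 3, 4, 5, 6): shape1 = (1, 3, 20, 6), shape2 = (1, 3, 4, 5, 6)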
reshape_4d = mb.reshape(x=_input, shape=shape1, name=name + "_reshape_4d")
batch_norm = mb.batch_norm(
x=reshape_4d,
mean=running_mean,
variance=running_var,
gamma=weight,
beta=bias,
epsilon=eps,
name=name + "_batch_norm_4d",
)
batch_norm = mb.reshape(x=batch_norm, shape=shape2, name=name)
context.add(batch_norm)
is_batch_norm_1d = input_rank == 2
    is_batch_norm_2d = input_rank in (3, 4)
is_batch_norm_3d = input_rank == 5
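    # use the elementwise decomposition in training mode, when the running statistics
    # are not compile-time constants, or when weight/bias are absent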
if training or running_mean.val is None or running_var.val is None or weight is None or bias is None:
_add_batch_norm_dynamic()
elif is_batch_norm_1d:
_add_batch_norm_1d()
elif is_batch_norm_2d:
_add_batch_norm_2d()
elif is_batch_norm_3d:
_add_batch_norm_3d()