in mmdnn/conversion/coreml/coreml_emitter.py [0:0]
def emit_Pool(self, IR_node):
    """
    Convert pooling layer to coreml.

    Reads the node's ``pooling_type`` ('MAX' or 'AVG'), ``global_pooling``,
    ``strides``, ``kernel_shape`` and its padding (via ``self._get_padding``)
    and appends a single pooling layer with ``self.builder.add_pooling``.
    The layer name and output blob are both ``IR_node.name``; the input blob
    is the ``real_name`` of the node on the first incoming edge.

    Only 2-D pooling is supported. ``strides`` / ``kernel_shape`` are sliced
    with ``[1:-1]``, so they are presumably laid out as [N, spatial..., C]
    (TODO confirm against the IR spec).

    Raises:
        TypeError: if ``pooling_type`` is neither 'MAX' nor 'AVG'.
        NotImplementedError: if the pooling is not 2-D (global or not).
    """
    # Get input name; the output blob is named after this node.
    input_name = self.IR_graph.get_node(IR_node.in_edges[0]).real_name

    # Map the IR pooling type onto Core ML's layer type string.
    pooling_type = IR_node.get_attr('pooling_type')
    if pooling_type == 'MAX':
        layer_type_str = 'MAX'
    elif pooling_type == 'AVG':
        layer_type_str = 'AVERAGE'
    else:
        raise TypeError("Pooling type %s not supported" % pooling_type)

    # If it's global, Core ML gets the global flag.
    global_pooling = IR_node.get_attr('global_pooling', False)
    strides = IR_node.get_attr('strides')
    # strides carries one leading and one trailing non-spatial entry.
    dim = len(strides) - 2
    if dim != 2:
        # Previously the non-global path crashed here with an opaque
        # tuple-unpacking ValueError; fail explicitly for both paths.
        raise NotImplementedError(
            "CoreML emitter only supports 2-D pooling, got %d-D pooling" % dim)

    stride_height, stride_width = tuple(strides[1:-1])
    if global_pooling:
        # Kernel size is only a placeholder here; with is_global=True Core ML
        # presumably pools over the whole spatial extent -- confirm.
        # TODO global pooling modification
        height, width = 1, 1
    else:
        height, width = tuple(IR_node.get_attr('kernel_shape')[1:-1])

    # Padding: a list means explicit per-side pads, anything else is SAME.
    padding = self._get_padding(IR_node)
    if isinstance(padding, list):
        padding_type = "VALID"
        # see protobuf; indices 1, 2, 5, 6 are read as top, left, bottom,
        # right (presumably begin/end pads over [N, H, W, C] -- confirm).
        padding_top, padding_left, padding_bottom, padding_right = (
            padding[1], padding[2], padding[5], padding[6])
    else:
        padding_type = "SAME"
        padding_top, padding_left, padding_bottom, padding_right = 0, 0, 0, 0

    self.builder.add_pooling(name=IR_node.name,
                             height=height,
                             width=width,
                             stride_height=stride_height,
                             stride_width=stride_width,
                             layer_type=layer_type_str,
                             padding_type=padding_type,
                             padding_top=padding_top,
                             padding_left=padding_left,
                             padding_bottom=padding_bottom,
                             padding_right=padding_right,
                             input_name=input_name,
                             output_name=IR_node.name,
                             exclude_pad_area=True,
                             is_global=global_pooling)