in mmdnn/conversion/tensorflow/tensorflow_emitter.py [0:0]
def emit_Pool(self, IR_node):
    """Emit the TensorFlow source line for an IR ``Pool`` node.

    Args:
        IR_node: IR pooling node. Its ``pooling_type`` ('MAX' or 'AVG'),
            ``strides``, ``kernel_shape``, ``pads`` and ``dilations``
            attributes are read, plus ``layer.attr['global_pooling'].b``.

    Returns:
        str: one assignment statement calling ``tf.nn.max_pool`` or
        ``tf.nn.avg_pool`` (the ``...3d`` variant when ``strides`` has
        5 entries).

    Raises:
        ValueError: if the pooling type is neither 'MAX' nor 'AVG', or if a
            non-global pool has any dilation other than 1.
    """
    pooling_type = IR_node.get_attr('pooling_type')
    if pooling_type == 'MAX':
        op = 'max_pool'
        # Handed to _defuse_padding so explicitly padded cells hold -inf and
        # can never be selected as the maximum.
        padding_const = ", constant_values=float('-Inf')"
    elif pooling_type == 'AVG':
        op = 'avg_pool'
        padding_const = ""
    else:
        raise ValueError("unknown pooling type [{}].".format(pooling_type))

    strides = IR_node.get_attr('strides')
    arrlen = len(strides)
    # A 5-entry stride list means 3 spatial axes -> use the *3d TF op.
    dim_str = '3d' if arrlen == 5 else ""

    if IR_node.layer.attr['global_pooling'].b:
        # The kernel covers the input's whole spatial extent, which is read
        # from the parent tensor's static shape when the emitted code runs.
        parent_name = self.parent_variable_name(IR_node)
        return ("{:<15} = tf.nn.{}{}({}, [1] + {}.get_shape().as_list()[1:-1] + [1], "
                "strides = [1] * {}, padding = 'VALID', name = '{}')").format(
            IR_node.variable_name,
            op,
            dim_str,
            parent_name,
            parent_name,
            arrlen,
            IR_node.name)

    spatial_dims = arrlen - 2

    dilations = IR_node.get_attr('dilations')
    # Explicit raise rather than `assert`, which disappears under `python -O`.
    if dilations and any(d != 1 for d in dilations):
        raise ValueError(
            "dilated pooling is not supported, got dilations {}.".format(list(dilations)))

    kernel_shape = IR_node.get_attr('kernel_shape')
    kernel_shape_str = ', '.join('%s' % k for k in kernel_shape)
    strides_str = ', '.join('%s' % s for s in strides)

    pool_size = kernel_shape[1:-1]
    spatial_strides = strides[1:-1]

    pads = IR_node.get_attr('pads')
    # NOTE(review): assumes the IR layout [begin pad per axis] + [end pad per
    # axis] with batch first and channels last (e.g. 2D:
    # [0, top, left, 0, 0, bottom, right, 0]) - TODO confirm against
    # _defuse_padding. The spatial begin pads are therefore indices
    # 1..spatial_dims (the old slice `[1:dim]` dropped the last one), and the
    # spatial end pads start right after the end pad of the batch axis.
    pads_begin = list(pads[1:spatial_dims + 1])
    pads_end = list(pads[spatial_dims + 3:2 * spatial_dims + 3])

    # An avg pool with uniform kernel k = 2p + 1, stride 1 and symmetric
    # uniform padding p keeps the spatial size unchanged, which is exactly
    # what TF's 'SAME' padding yields here, so no explicit pad op is needed.
    # Every failed check falls back to the general explicit-padding path.
    # NOTE(review): TF's SAME avg_pool averages only in-bounds cells; confirm
    # this matches the source framework's count-include-pad behaviour.
    use_same_padding = (
        pooling_type == 'AVG'
        and len(set(pool_size)) == 1
        and set(spatial_strides) == {1}
        and len(set(pads_begin)) == 1
        and pads_end == pads_begin
        and pool_size[0] == pads_begin[0] * 2 + 1
    )

    if use_same_padding:
        input_name = self.parent_variable_name(IR_node)
        tf_padding = 'SAME'
    else:
        # Returns (input expression, TF padding string); presumably it emits
        # an explicit pad when the node has non-zero pads - see that helper.
        input_name, tf_padding = self._defuse_padding(IR_node, padding_const)

    return "{:<15} = tf.nn.{}{}({}, [{}], [{}], padding='{}', name='{}')".format(
        IR_node.variable_name,
        op,
        dim_str,
        input_name,
        kernel_shape_str,
        strides_str,
        tf_padding,
        IR_node.name)