def import_node()

in tensorwatch/model_graph/hiddenlayer/tf_builder.py [0:0]


def _int_list_attr(tf_node, key):
    """Return the integer-list attribute ``key`` of ``tf_node``, or None if absent."""
    # Check membership first: indexing a protobuf message map with a missing
    # key would insert a default AttrValue into the NodeDef.
    if key not in tf_node.attr:
        return None
    return [int(a) for a in tf_node.attr[key].list.i]


def _spatial_slice(tf_node):
    """Return the slice selecting the (height, width) entries of a 4-element
    per-dimension attribute (``strides``, ``ksize``) of ``tf_node``.

    TensorFlow orders these lists like the input tensor, as given by the
    node's ``data_format`` attribute: NHWC (the default when the attribute is
    absent) puts spatial dims at indices 1-2, NCHW-style formats at 2-3.
    """
    data_format = b"NHWC"
    if "data_format" in tf_node.attr:
        data_format = tf_node.attr["data_format"].s or b"NHWC"
    if isinstance(data_format, str):
        data_format = data_format.encode("ascii")
    return slice(2, 4) if data_format.startswith(b"NC") else slice(1, 3)


def import_node(tf_node, tf_graph, verbose=False):
    """Extract display information from a TensorFlow NodeDef.

    Args:
        tf_node: NodeDef-like object exposing ``op``, ``name``, ``input`` and
            the ``attr`` map.
        tf_graph: The graph ``tf_node`` belongs to; used to look up tensor
            shapes.
        verbose: If True, log (with traceback) any failure to read a shape;
            otherwise such failures are ignored silently.

    Returns:
        Tuple ``(op, uid, name, shape, params)`` where:
          - ``op`` is ``tf_node.op``;
          - ``uid`` is ``tf_node.name``;
          - ``name`` is always None here;
          - ``shape`` is the output shape as a list when its rank is known,
            the raw shape object when the rank is unknown, and None for
            ``NoOp`` nodes or when the lookup fails;
          - ``params`` is a dict that may hold ``"kernel_shape"`` and
            ``"stride"`` (each a ``[height, width]`` list) for Conv2D,
            DepthwiseConv2dNative, MaxPool and AvgPool nodes.
    """
    # Operation type and name
    op = tf_node.op
    uid = tf_node.name
    name = None

    # Shape (best effort: failures leave shape as None)
    shape = None
    if op != "NoOp":
        try:
            shape = tf.graph_util.tensor_shape_from_node_def_name(tf_graph, tf_node.name)
            # If the rank is known, convert to a plain list (unknown dims -> None)
            if shape.ndims is not None:
                shape = shape.as_list()
        except Exception:
            if verbose:
                logging.exception("Error reading shape of {}".format(tf_node.name))

    # Parameters: kernel size and stride, as [height, width].
    params = {}
    if op in ("Conv2D", "DepthwiseConv2dNative"):
        # The kernel size is not stored on the convolution node but in its
        # weights input, shaped [kernel_h, kernel_w, in_channels, filters],
        # so we fish for it. Best effort, like the output shape above.
        try:
            kernel_dims = tf.graph_util.tensor_shape_from_node_def_name(
                tf_graph, tf_node.input[1]).as_list()
            params["kernel_shape"] = [int(d) for d in kernel_dims[0:2]]
        except Exception:
            if verbose:
                logging.exception("Error reading kernel shape of {}".format(tf_node.name))
    elif op in ("MaxPool", "AvgPool"):
        # Pooling nodes carry the window size directly in the 'ksize' attr.
        # See https://stackoverflow.com/questions/44124942/how-to-access-values-in-protos-in-tensorflow
        ksize = _int_list_attr(tf_node, "ksize")
        if ksize is not None:
            params["kernel_shape"] = ksize[_spatial_slice(tf_node)]

    if op in ("Conv2D", "DepthwiseConv2dNative", "MaxPool", "AvgPool"):
        strides = _int_list_attr(tf_node, "strides")
        if strides is not None:
            params["stride"] = strides[_spatial_slice(tf_node)]

    return op, uid, name, shape, params