def get_shape()

in src/exporters/coreml/convert.py [0:0]


def get_shape(config, input_desc, dummy_input, axis=-1):
    """
    Build the ``ct.Shape`` describing one model input.

    The static shape is read from ``dummy_input[0].shape``. One dimension,
    selected by ``axis`` (the last one by default), is made flexible when:

    - ``config.use_past`` is truthy: it becomes an unbounded ``ct.RangeDim()``;
    - otherwise, ``input_desc.sequence_length`` is a ``(min, max)`` tuple:
      it becomes ``ct.RangeDim(min, max)``.

    If a dimension was made flexible, no explicit default shape is passed to
    ``ct.Shape``. Otherwise the static shape is passed as the default.

    Args:
        config: object exposing a ``use_past`` attribute.
        input_desc: object exposing ``sequence_length``; only a ``tuple``
            value is treated as a ``(min_length, max_length)`` range.
        dummy_input: indexable whose first element has a ``.shape``
            (presumably a tensor/array; confirm against callers).
        axis: index of the dimension to make flexible (presumably the
            sequence axis; defaults to the last dimension).

    Returns:
        A ``ct.Shape`` built from the (possibly flexible) dimensions.
    """
    static_shape = dummy_input[0].shape
    dims = list(static_shape)

    # TODO: the batch dimension (dims[0]) could also be made a ct.RangeDim().
    if config.use_past:
        flexible_dim = ct.RangeDim()
    elif isinstance(input_desc.sequence_length, tuple):
        lower, upper = input_desc.sequence_length
        flexible_dim = ct.RangeDim(lower, upper)
    else:
        # Fully static input: the dummy input's shape doubles as the default.
        return ct.Shape(dims, default=static_shape)

    dims[axis] = flexible_dim
    return ct.Shape(dims, default=None)