def lstm()

in coremltools/converters/mil/backend/nn/op_mapping.py [0:0]


def lstm(const_context, builder, op):
    """Lower an LSTM op onto ``builder`` as a uni- or bidirectional LSTM layer.

    The input and the initial hidden/cell states are expanded with extra
    axes, the LSTM layer is emitted with ``builder.add_unilstm`` (direction
    ``"forward"`` / ``"reverse"``) or ``builder.add_bidirlstm`` (direction
    ``"bidirectional"``), and the layer outputs are squeezed into blobs named
    after ``op.outputs``.

    Args:
        const_context: Passed through unchanged to ``make_input``.
        builder: Network builder that receives the generated layers.
        op: The LSTM op to convert. Its ``x``, ``initial_h``, ``initial_c``,
            ``weight_ih``, ``weight_hh``, ``direction``, ``output_sequence``,
            activation attributes and ``outputs`` are read. ``bias``,
            ``peephole`` and ``clip`` may be ``None``. The ``*_back``
            attributes are read only for the bidirectional case, where
            ``bias_back`` and ``peephole_back`` may be ``None``.

    Raises:
        ValueError: If ``op.direction.val`` is not ``"forward"``,
            ``"reverse"`` or ``"bidirectional"``.
    """
    make_input(const_context, builder, [op.x, op.initial_h, op.initial_c])
    # NOTE(review): the original comment gave the input shape as [b, s, I],
    # but the output layouts documented below put the sequence axis first
    # ([s/1, b, h, 1, 1]), which suggests [s, b, I] -- confirm.
    input_name = op.x.name
    # Shape: [b, DIRECTION*H]
    initial_h = op.initial_h.name
    initial_c = op.initial_c.name

    wt_ih = op.weight_ih.val
    wt_hh = op.weight_hh.val
    b = op.bias.val if op.bias is not None else None
    direction = op.direction.val
    output_sequence = op.output_sequence.val
    peephole = op.peephole.val if op.peephole is not None else None
    # High enough clip value to be ineffective!
    clip = 500.0 if op.clip is None else op.clip.val

    # Expand the input along axes [3, 4] before handing it to the LSTM layer.
    _expand_dim(builder, input_name + "_expanded", input_name, [3, 4])
    input_name += "_expanded"

    if direction in {"forward", "reverse"}:
        # Expand initial_h and initial_c,
        # from shape (B, H) to shape (1, Batch, H, 1, 1)
        _expand_dim(builder, initial_h + "_expanded", initial_h, [0, 3, 4])
        initial_h += "_expanded"
        # initial_h may have the same name as initial_c (e.g., same Var).
        # Append a different string to avoid conflict
        _expand_dim(builder, initial_c + "_expanded2", initial_c, [0, 3, 4])
        initial_c += "_expanded2"

        # w_x: [H*I, H*I, H*I, H*I]
        # w_h: [H*H, H*H, H*H, H*H]
        # where format is, [input gate, forget gate, output gate, cell gate]
        w_x = _split(wt_ih, sections=4)
        w_h = _split(wt_hh, sections=4)
        # bias format: [4*H], ifoz layout.
        # Bias and peephole are optional, so only split them when present
        # (same handling as the bidirectional branch below).
        b = _split(b, sections=4) if b is not None else None
        # peephole format: [3*H]
        # where format is, [input gate, forget gate, output gate]
        peephole = _split(peephole, sections=3) if peephole is not None else None

        input_size = w_x[0].shape[1]
        hidden_size = w_h[0].shape[1]

        # 3 outputs
        # Y  : [s/1, b, h, 1, 1]
        # Y_h: [  1, b, h, 1, 1]
        # Y_c: [  1, b, h, 1, 1]
        output_names = [_output.name + "_5d" for _output in op.outputs]
        builder.add_unilstm(
            name=op.name,
            W_h=w_h,
            W_x=w_x,
            b=b,
            hidden_size=hidden_size,
            input_size=input_size,
            input_names=[input_name, initial_h, initial_c],
            output_names=output_names,
            inner_activation=op.recurrent_activation.val,
            cell_state_update_activation=op.cell_activation.val,
            output_activation=op.activation.val,
            peep=peephole,
            output_all=output_sequence,
            cell_clip_threshold=clip,
            reverse_input=(direction == "reverse"),
        )

        # Squeeze Output
        # to output shape of [Seq Len or 1, Batch Size, Hidden Size]
        _squeeze(builder, op.outputs[0].name, output_names[0], axes=[3, 4])
        # Squeeze Output H and Output C
        # to output shape of [Batch Size, Hidden Size]
        _squeeze(builder, op.outputs[1].name, output_names[1], axes=[0, 3, 4])
        _squeeze(builder, op.outputs[2].name, output_names[2], axes=[0, 3, 4])

    elif direction == "bidirectional":
        # Expand initial_h and initial_c
        # Issue #810: the current layer count is appended to the expanded
        # names, and the expand layer is only added when that name is not
        # already present in builder.layers.
        num_layer = len(builder.layers)
        initial_h_expand = initial_h + "_expanded" + "_" + str(num_layer)
        # from shape (B, 2*H) to shape (1, Batch, 2*H, 1, 1)
        if initial_h_expand not in builder.layers:
            _expand_dim(builder, initial_h_expand, initial_h, [0, 3, 4])
        initial_h = initial_h_expand

        # initial_h may have the same name as initial_c (e.g., same Var)
        initial_c_expand = initial_c + "_expanded2" + "_" + str(num_layer)
        if initial_c_expand not in builder.layers:
            _expand_dim(builder, initial_c_expand, initial_c, [0, 3, 4])
        initial_c = initial_c_expand

        initial_h_f = initial_h + "_forward"
        initial_h_r = initial_h + "_reverse"
        initial_c_f = initial_c + "_forward"
        initial_c_r = initial_c + "_reverse"

        # split input_h and input_c into two parts (forward / reverse)
        builder.add_split_nd(
            name=op.name + "_split_h",
            input_name=initial_h,
            output_names=[initial_h_f, initial_h_r],
            axis=2,
        )
        builder.add_split_nd(
            name=op.name + "_split_c",
            input_name=initial_c,
            output_names=[initial_c_f, initial_c_r],
            axis=2,
        )

        wt_ih_back = op.weight_ih_back.val
        wt_hh_back = op.weight_hh_back.val
        # Sizes are read from the second axis of the unsplit forward weights.
        hidden_size = wt_hh.shape[1]
        input_size = wt_ih.shape[1]

        # f_w_x and r_w_x: [H*I, H*I, H*I, H*I]
        # f_w_h and r_w_h: [H*H, H*H, H*H, H*H]
        # where format is, [input gate, forget gate, output gate, cell gate]
        w_x = _split(wt_ih, sections=4)
        w_h = _split(wt_hh, sections=4)
        r_w_x = _split(wt_ih_back, sections=4)
        r_w_h = _split(wt_hh_back, sections=4)

        # f_b and r_b format: [4*H]; both are optional.
        b_back = op.bias_back.val if op.bias_back is not None else None
        f_b = _split(b, sections=4) if b is not None else None
        r_b = _split(b_back, sections=4) if b_back is not None else None

        # f_peephole and r_peephole format: [3*H]; both are optional.
        peephole_back = op.peephole_back.val if op.peephole_back is not None else None
        f_peephole = _split(peephole, sections=3) if peephole is not None else None
        r_peephole = (
            _split(peephole_back, sections=3) if peephole_back is not None else None
        )

        # Names of the intermediate concatenated H / C blobs.
        h_5d = op.outputs[1].name + "_5d"
        c_5d = op.outputs[2].name + "_5d"

        # NOTE(review): "_5d_foward" is a misspelling of "forward". It is kept
        # byte-for-byte so the blob names emitted into the model do not change.
        output_names = [
            op.outputs[0].name + "_5d",  # Output Y           [s/1, b, 2*h, 1, 1]
            op.outputs[1].name + "_5d_foward",  # Output Y_h         [  1, b,   h, 1, 1]
            op.outputs[2].name + "_5d_forward",  # Output Y_c         [  1, b,   h, 1, 1]
            op.outputs[1].name + "_5d_reverse",  # Output Y_h_reverse [  1, b,   h, 1, 1]
            op.outputs[2].name + "_5d_reverse",  # Output Y_c_reverse [  1, b,   h, 1, 1]
        ]

        builder.add_bidirlstm(
            name=op.name,
            W_h=w_h,
            W_x=w_x,
            b=f_b,
            W_h_back=r_w_h,
            W_x_back=r_w_x,
            b_back=r_b,
            hidden_size=hidden_size,
            input_size=input_size,
            input_names=[
                input_name,
                initial_h_f,
                initial_c_f,
                initial_h_r,
                initial_c_r,
            ],
            output_names=output_names,
            inner_activation=op.recurrent_activation.val,
            cell_state_update_activation=op.cell_activation.val,
            output_activation=op.activation.val,
            peep=f_peephole,
            peep_back=r_peephole,
            output_all=output_sequence,
            cell_clip_threshold=clip,
        )

        # Squeeze Output
        # to output shape of [Seq Len or 1, Batch Size, 2*Hidden Size]
        _squeeze(builder, op.outputs[0].name, output_names[0], axes=[3, 4])

        # Output H is of format
        # 1, Batch_Size, Hidden_Size, 1, 1
        # Concat forward and reverse to make it
        # 1, Batch_Size, 2*Hidden_Size, 1, 1
        builder.add_elementwise(
            name=h_5d,
            input_names=[output_names[1], output_names[3]],
            output_name=h_5d,
            mode="CONCAT",
        )
        # Output C is of format
        # 1, Batch_Size, Hidden_Size, 1, 1
        # and is concatenated the same way.
        builder.add_elementwise(
            name=c_5d,
            input_names=[output_names[2], output_names[4]],
            output_name=c_5d,
            mode="CONCAT",
        )

        # Squeeze Output H and Output C
        # to output shape of [Batch Size, 2*Hidden Size]
        _squeeze(builder, op.outputs[1].name, h_5d, axes=[0, 3, 4])
        _squeeze(builder, op.outputs[2].name, c_5d, axes=[0, 3, 4])
    else:
        raise ValueError(
            "Unknown direction {} for LSTM layer. Supported are forward, reverse or bidirectional".format(
                direction
            )
        )