in coremltools/converters/mil/frontend/torch/ops.py [0:0]
def lstm(context, node):
    inputs = _get_inputs(context, node, expected=9)
    _input = inputs[0]
    # There are two cases here:
    # (1) the input tensor is a PackedSequence object;
    #     in this case, the second input of the lstm layer is batch_sizes (a MIL Var).
    # (2) the input tensor is a normal tensor;
    #     in this case, the second input is an array.
    # As a result, we can use the second input to distinguish between the two cases.
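    # For reference, the two layouts of `inputs` as unpacked below (the index skipped in
    # each case is presumably torch's `train` flag, which is irrelevant for conversion):
    #   packed-sequence case: data, batch_sizes, (h0, c0), weights, has_biases,
    #                         num_layers, dropout, <skipped>, bidirectional
    #   plain-tensor case:    input, (h0, c0), weights, has_biases, num_layers,
    #                         dropout, <skipped>, bidirectional, batch_first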
    has_batch_sizes = not isinstance(inputs[1], Iterable)
    if has_batch_sizes:
        batch_sizes = inputs[1]
        h0, c0 = inputs[2]
        weights_list = inputs[3]
        has_bias = inputs[4].val
        num_layers = inputs[5].val
        dropout = inputs[6]
        bidirectional = inputs[8].val
        # the output of _pack_padded_sequence is always in batch-first layout
        batch_first = True
    else:
        h0, c0 = inputs[1]
        weights_list = inputs[2]
        has_bias = inputs[3].val
        num_layers = inputs[4].val
        dropout = inputs[5]
        bidirectional = inputs[7].val
        batch_first = inputs[8].val
    '''
    Torch LSTM layer's input shapes:

    (1) first input
        (Seq, B, C) : if batch_first = False
        (B, Seq, C) : if batch_first = True

    (2) & (3) initialization states
        (num_layers, B, H) : if bidirectional = False
        (num_layers * 2, B, H) : if bidirectional = True

    For the MIL LSTM layer, these are the input shapes:

    (1) first input: (Seq, B, C)
        this means that, if batch_first=True, we need to insert a transpose op first

    (2) & (3) initialization states
        MIL's LSTM layer does not natively support the "num_layers" parameter.
        So, when num_layers > 1, we add multiple MIL LSTM ops in a sequence.
        Each of these LSTM ops takes in initialization states of the following shape:
        (B, H) if bidirectional = False
        (B, 2*H) if bidirectional = True
    '''
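    # Illustration with hypothetical sizes (Seq=10, B=4, C=32, H=64, num_layers=2,
    # bidirectional=True, batch_first=True):
    #   torch input : (4, 10, 32), transposed below to (10, 4, 32)
    #   torch h0/c0 : (4, 4, 64) == (num_layers * 2, B, H), split further down into
    #                 num_layers slices of shape (2, 4, 64), one per MIL LSTM op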
    if batch_first:
        _input = mb.transpose(x=_input, perm=[1, 0, 2], name=_input.name + "_batch_first_transpose")

    expected_num_weights = 2 * num_layers * (int(bidirectional) + 1) * (int(has_bias) + 1)
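    # For example, num_layers=2, bidirectional=True, has_bias=True gives
    # 2 * 2 * 2 * 2 = 16 tensors: weight_ih and weight_hh (plus bias_ih and bias_hh when
    # has_bias is True) for each direction of each layer.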
    if len(weights_list) != expected_num_weights:
        raise ValueError(
            "Incorrect weights shape for lstm layer: Expected: {}. Received: {}.".format(
                expected_num_weights, len(weights_list)
            )
        )
    # shapes of h0 and c0 are (num_layers * n_directions, B, H)
    if num_layers == 1:
        all_initial_h = [h0]  # [(n_directions, B, H)]
        all_initial_c = [c0]  # [(n_directions, B, H)]
    else:
        all_initial_h = mb.split(x=h0, num_splits=num_layers, axis=0)  # [(n_directions, B, H)]
        all_initial_c = mb.split(x=c0, num_splits=num_layers, axis=0)  # [(n_directions, B, H)]

    n_weights_per_layer = int(len(weights_list) / num_layers)
    x = _input
    h_out_list = []
    c_out_list = []

    for i in range(num_layers):
        if i < num_layers - 1:
            op_name = node.name + "_lstm_layer_{}".format(i)
        else:
            if batch_first:
                op_name = node.name + "_batch_first"
            else:
                op_name = node.name

        lstm_out = _add_mil_lstm(input=x,
                                 initial_h=all_initial_h[i],
                                 initial_c=all_initial_c[i],
                                 weights=weights_list[i * n_weights_per_layer : (i+1) * n_weights_per_layer],
                                 has_bias=has_bias,
                                 bidirectional=bidirectional,
                                 name=op_name)
        x = lstm_out[0]  # shape of lstm_out[0] == (Seq, B, H) if bidirectional = False else (Seq, B, 2*H)
        h_out_list.append(lstm_out[1])  # shape of lstm_out[1] == (B, H) if bidirectional = False else (B, 2*H)
        c_out_list.append(lstm_out[2])  # shape of lstm_out[2] == (B, H) if bidirectional = False else (B, 2*H)
    '''
    For torch, these are the dimensions of the 3 output tensors:

    (1) output[0]:
        (Seq, B, H) if batch_first = False, bidirectional = False
        (Seq, B, 2*H) if batch_first = False, bidirectional = True
        (B, Seq, H) if batch_first = True, bidirectional = False
        (B, Seq, 2*H) if batch_first = True, bidirectional = True

    (2) & (3) these are the state outputs:
        (num_layers, B, H) if bidirectional = False
        (num_layers * 2, B, H) if bidirectional = True

    MIL LSTM layer's output shapes:

    (1) output[0]:
        (Seq, B, H) if bidirectional = False
        (Seq, B, 2*H) if bidirectional = True
        This means we need a transpose op if batch_first is True.

    (2) & (3) shapes of the state outputs:
        each MIL LSTM op will produce final state tensors with the following shape:
        (B, H) if bidirectional = False
        (B, 2*H) if bidirectional = True
        We therefore stack/expand the final state tensors to match the Torch output shapes.
    '''
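    # Continuing the hypothetical example above (Seq=10, B=4, H=64, num_layers=2,
    # bidirectional=True, batch_first=True), the loop below produces:
    #   output[0]: (10, 4, 128) from the last MIL LSTM op, transposed back to (4, 10, 128)
    #   h_n / c_n: per-layer (4, 128) states, each split into two (4, 64) halves, stacked
    #              into (2, 4, 64), then concatenated into (4, 4, 64) == (num_layers * 2, B, H)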
    for index, (name, output) in enumerate(zip(node.outputs, lstm_out)):
        if index > 0:
            # index > 0 ===> it's one of the state outputs (h or c)
            if bidirectional:
                if num_layers == 1:
                    out1, out2 = mb.split(x=output, num_splits=2, axis=1)  # each output of shape (B, H) after the split
                    final_out = mb.stack(values=[out1, out2], axis=0, name=name)  # (2, B, H)
                    context.add(final_out, name)
                else:
                    out_state_tensors_list = h_out_list if index == 1 else c_out_list  # each tensor in the list is of shape (B, 2*H)
                    list_of_tensors_to_stack = []
                    for i in range(num_layers):
                        out1, out2 = mb.split(x=out_state_tensors_list[i], num_splits=2, axis=1)  # each output of shape (B, H) after the split
                        out = mb.stack(values=[out1, out2], axis=0)  # (2, B, H)
                        list_of_tensors_to_stack.append(out)
                    final_out = mb.concat(values=list_of_tensors_to_stack, axis=0, name=name)  # output of shape (num_layers * 2, B, H)
                    context.add(final_out, name)
            else:
                if num_layers == 1:
                    unsqueeze = mb.expand_dims(x=output, axes=[0], name=name)  # (1, B, H)
                    context.add(unsqueeze, name)
                else:
                    out = mb.stack(values=h_out_list if index == 1 else c_out_list, axis=0, name=name)  # (num_layers, B, H)
                    context.add(out, name)
        else:
            if batch_first:
                output = mb.transpose(x=output, perm=[1, 0, 2], name=name)
            context.add(output, name)