# tinynn/converter/operators/torch/aten.py
import warnings
import numpy as np
import torch
from tinynn.util.util import get_logger
from ...schemas.tflite import schema_generated as tfl_schema
from ...schemas.torch.aten_schema import *
from .. import CommonGraph
from .. import tflite as tfl
log = get_logger(__name__, 'INFO')
class ATenSignOperator(ATenSignSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.SignOperator, graph_converter)
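# aten::lstm is lowered either to the fused TFLite (Uni/Bi)directionalSequenceLstm ops
# or, when unroll_rnn is enabled, to an unrolled per-timestep graph of primitive ops.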
class ATenLstmOperator(ATenLstmSchema):
def lstm_input_helper(
self, input_tensors, params_tensors, has_biases, param_start_index, input_start_index, layer_idx, suffix
):
hybrid = isinstance(self, ATenQuantizedLstmOperator)
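        # PyTorch packs the gate weights as [W_ii | W_if | W_ig | W_io] along dim 0;
        # chunk them into the four per-gate tensors the TFLite LSTM op expects.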
weight_ih_slices = torch.chunk(params_tensors[param_start_index], 4, 0)
gates = ["input", "forget", "cell", "output"]
for idx, (weight_ih, gate) in enumerate(zip(weight_ih_slices, gates)):
input_tensors[input_start_index + idx] = self.create_attr_tensor(weight_ih, hybrid=hybrid)
weight_hh_slices = torch.chunk(params_tensors[param_start_index + 1], 4, 0)
for idx, (weight_hh, gate) in enumerate(zip(weight_hh_slices, gates)):
input_tensors[input_start_index + 4 + idx] = self.create_attr_tensor(weight_hh, hybrid=hybrid)
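        # The TFLite LSTM op takes a single bias per gate, so the input and hidden
        # biases are fused by summation.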
if has_biases:
assert params_tensors[param_start_index + 2].dtype == torch.float32
assert params_tensors[param_start_index + 3].dtype == torch.float32
fused_bias = params_tensors[param_start_index + 2] + params_tensors[param_start_index + 3]
fused_bias_slices = torch.chunk(fused_bias, 4, 0)
for idx, (bias, gate) in enumerate(zip(fused_bias_slices, gates)):
input_tensors[input_start_index + 11 + idx] = self.create_attr_tensor(bias)
else:
bias_shape = input_tensors[input_start_index + 3].shape[:1]
for idx, gate in enumerate(gates):
bias = torch.zeros(bias_shape, dtype=torch.float32)
input_tensors[input_start_index + 11 + idx] = self.create_attr_tensor(bias)
def lstm_hidden_state_helper(
self,
input_tensors,
hidden_state_tensors,
hidden_state_index,
input_index,
num_directions,
direction_idx,
num_layers,
layer_idx,
suffix,
state_type,
tf_state_tensors,
):
hidden_state_tensor = hidden_state_tensors[hidden_state_index]
tf_state_tensor = tf_state_tensors[hidden_state_index]
assert hidden_state_tensor.dim() == 3
slice_idx = layer_idx * num_directions + direction_idx
if tf_state_tensor[slice_idx] is None:
input_tensors[input_index] = self.create_attr_tensor(hidden_state_tensor[slice_idx])
input_tensors[input_index].is_variable = True
else:
assert self.unroll_rnn, "Input state tensors are only supported when unroll_rnn=True is specified"
input_tensors[input_index] = tf_state_tensor[slice_idx]
def parse_common(
self,
input_tensor,
hidden_state_tensors,
params_tensors,
has_biases,
num_layers,
dropout,
is_train,
bidirectional,
batch_first,
graph_converter,
):
assert is_train in (False, 0)
expected_num_params = 2 * num_layers
params_step = 2
if has_biases:
expected_num_params *= 2
params_step *= 2
if bidirectional:
expected_num_params *= 2
assert (
len(params_tensors) == expected_num_params
        ), f'Wrong number of params in LSTM: got {len(params_tensors)}, expected {expected_num_params}'
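        # Input layout of the TFLite (Uni/Bi)directionalSequenceLstm op: 24 tensors for
        # unidirectional (48 for bidirectional); forward weights start at index 1,
        # backward weights at index 18, and the states at index 18 (35 for bidirectional).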
num_input_tensors = 24
num_directions = 1
state_start_index = 18
if bidirectional:
num_input_tensors *= 2
num_directions *= 2
state_start_index = 35
suffixes = ["_fw", "_bw"]
state_kinds = ["act", "cell"]
param_start_indices = [0, params_step]
input_start_indices = [1, 18]
ops = []
names = graph_converter.get_list_expanded_names(self.input_names[1])
tf_in_state_tensors = [graph_converter.tensor_map.get(n, None) for n in names]
tf_state_tensors = []
unpacked_tensors = {}
for t in tf_in_state_tensors:
if t is not None and self.unroll_rnn:
tensors = [
self.create_transform_tensor(np.squeeze(x, 0))
for x in np.split(t.tensor, num_directions * num_layers, 0)
]
tf_state_tensors.append(tensors)
ops.append(tfl.UnpackOperator([t], tensors, len(tensors), 0))
else:
tf_state_tensors.append([None] * num_directions * num_layers)
current_input = self.find_or_create_input(0, graph_converter)
lstm_output = self.to_tfl_tensors(self.output_names[:1], self.output_tensors[:1])[0]
params_offset = 0
tf_out_state_tensors = [[], []]
for layer_idx in range(num_layers):
inputs = [current_input] + [tfl.OptionalTensorInstance] * (num_input_tensors - 1)
for direction_idx in range(num_directions):
self.lstm_input_helper(
inputs,
params_tensors,
has_biases,
params_offset + param_start_indices[direction_idx],
input_start_indices[direction_idx],
layer_idx,
suffixes[direction_idx],
)
for direction_idx in range(num_directions):
for state_kind_idx in range(len(state_kinds)):
self.lstm_hidden_state_helper(
inputs,
hidden_state_tensors,
state_kind_idx,
state_start_index + direction_idx * num_directions + state_kind_idx,
num_directions,
direction_idx,
num_layers,
layer_idx,
suffixes[direction_idx],
state_kinds[state_kind_idx],
tf_state_tensors,
)
if layer_idx == num_layers - 1:
layer_output = lstm_output
else:
output_shape = list(input_tensor.shape)
output_shape[-1] = inputs[6].shape[1] * num_directions
layer_output = self.create_transform_tensor(np.empty(output_shape, dtype=inputs[0].dtype))
outputs = [layer_output]
if self.unroll_rnn:
ts_axis = 1 if batch_first else 0
num_timestep = inputs[0].shape[ts_axis]
if inputs[0].name in unpacked_tensors:
input_ts = unpacked_tensors[inputs[0].name]
else:
input_ts = [
self.create_transform_tensor(np.squeeze(x, ts_axis))
for x in np.split(inputs[0].tensor, num_timestep, ts_axis)
]
ops.append(tfl.UnpackOperator([inputs[0]], input_ts, num_timestep, ts_axis))
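                # The forward pass consumes the timesteps in order; the backward pass
                # (bidirectional only) consumes them in reverse.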
strides = [1, -1]
output_ts = []
for direction_idx in range(num_directions):
input_start = input_start_indices[direction_idx]
if not self.separated_rnn_gate_calc:
w_i = self.create_attr_tensor(
np.concatenate([inputs[x].tensor for x in range(input_start, input_start + 4)], 0),
quantization=inputs[input_start].quantization,
)
w_r = self.create_attr_tensor(
np.concatenate([inputs[x].tensor for x in range(input_start + 4, input_start + 8)], 0),
quantization=inputs[input_start + 4].quantization,
)
b_i = self.create_attr_tensor(
np.concatenate([inputs[x].tensor for x in range(input_start + 11, input_start + 15)], 0)
)
b_r = self.create_attr_tensor(np.zeros_like(b_i.tensor))
else:
w_i_list = [inputs[x] for x in range(input_start, input_start + 4)]
w_r_list = [inputs[x] for x in range(input_start + 4, input_start + 8)]
b_i_list = [inputs[x] for x in range(input_start + 11, input_start + 15)]
b_r_list = [self.create_attr_tensor(np.zeros_like(b_i.tensor)) for b_i in b_i_list]
state_start = state_start_index + direction_idx * num_directions
h = inputs[state_start]
c = inputs[state_start + 1]
stride = strides[direction_idx]
# Skip some computations for the first timestep
compute_h = h.buffer is None or np.any(h.tensor)
compute_c = c.buffer is None or np.any(c.tensor)
stacked_hs = []
for i, t in enumerate(input_ts[::stride]):
if not self.separated_rnn_gate_calc:
input_mm = self.create_transform_tensor(
np.matmul(t.tensor, np.transpose(w_i.tensor, [1, 0])) + b_i.tensor
)
ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm]))
else:
input_mm_list = []
for j, (w_i, b_i) in enumerate(zip(w_i_list, b_i_list)):
if j == 1 and i == 0 and not compute_c:
input_mm_list.append(None)
continue
input_mm = self.create_transform_tensor(
np.matmul(t.tensor, np.transpose(w_i.tensor, [1, 0])) + b_i.tensor
)
ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm]))
input_mm_list.append(input_mm)
if i != 0 or compute_h:
if not self.separated_rnn_gate_calc:
hidden_mm = self.create_transform_tensor(
np.matmul(h.tensor, np.transpose(w_r.tensor, [1, 0])) + b_r.tensor
)
ops.append(tfl.FullyConnectedOperator([h, w_r, b_r], [hidden_mm]))
add_out = self.create_transform_tensor(input_mm.tensor + hidden_mm.tensor)
ops.append(tfl.AddOperator([input_mm, hidden_mm], [add_out]))
else:
hidden_mm_list = []
for j, (w_r, b_r) in enumerate(zip(w_r_list, b_r_list)):
if j == 1 and i == 0 and not compute_c:
hidden_mm_list.append(None)
continue
hidden_mm = self.create_transform_tensor(
np.matmul(h.tensor, np.transpose(w_r.tensor, [1, 0])) + b_r.tensor
)
ops.append(tfl.FullyConnectedOperator([h, w_r, b_r], [hidden_mm]))
hidden_mm_list.append(hidden_mm)
gate_outs = []
for input_mm, hidden_mm in zip(input_mm_list, hidden_mm_list):
if input_mm is not None and hidden_mm is not None:
add_out = self.create_transform_tensor(input_mm.tensor + hidden_mm.tensor)
ops.append(tfl.AddOperator([input_mm, hidden_mm], [add_out]))
gate_outs.append(add_out)
else:
if not self.separated_rnn_gate_calc:
add_out = input_mm
else:
gate_outs = input_mm_list
if not self.separated_rnn_gate_calc:
gate_outs = [self.create_transform_tensor(t) for t in np.split(add_out.tensor, 4, 1)]
split_dim_tensor = self.create_attr_tensor(np.array(1, dtype='int32'))
ops.append(tfl.SplitOperator([split_dim_tensor, add_out], gate_outs, 4))
gate_i = self.create_transform_tensor(
torch.sigmoid(torch.from_numpy(gate_outs[0].tensor)).numpy()
)
ops.append(tfl.LogisticOperator([gate_outs[0]], [gate_i]))
if i != 0 or compute_c:
gate_f = self.create_transform_tensor(
torch.sigmoid(torch.from_numpy(gate_outs[1].tensor)).numpy()
)
ops.append(tfl.LogisticOperator([gate_outs[1]], [gate_f]))
gate_g = self.create_transform_tensor(np.tanh(gate_outs[2].tensor))
ops.append(tfl.TanhOperator([gate_outs[2]], [gate_g]))
gate_o = self.create_transform_tensor(
torch.sigmoid(torch.from_numpy(gate_outs[3].tensor)).numpy()
)
ops.append(tfl.LogisticOperator([gate_outs[3]], [gate_o]))
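                        # Standard LSTM cell update: c_t = f * c_{t-1} + i * g and
                        # h_t = o * tanh(c_t); the first-step terms are skipped when the
                        # initial states are known to be all zeros.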
if i != 0 or compute_c:
c_left = self.create_transform_tensor(gate_f.tensor * c.tensor)
ops.append(tfl.MulOperator([gate_f, c], [c_left]))
c_right = self.create_transform_tensor(gate_i.tensor * gate_g.tensor)
ops.append(tfl.MulOperator([gate_i, gate_g], [c_right]))
if i != 0 or compute_c:
c = self.create_transform_tensor(c_left.tensor + c_right.tensor)
ops.append(tfl.AddOperator([c_left, c_right], [c]))
else:
c = c_right
c_act = self.create_transform_tensor(np.tanh(c.tensor))
ops.append(tfl.TanhOperator([c], [c_act]))
h = self.create_transform_tensor(gate_o.tensor * c_act.tensor)
ops.append(tfl.MulOperator([gate_o, c_act], [h]))
stacked_hs.append(h)
tf_out_state_tensors[0].append(h)
tf_out_state_tensors[1].append(c)
output_ts.extend(stacked_hs[::stride])
if bidirectional:
                    # For bidirectional LSTMs, the forward and backward output tensors are
                    # packed separately and then concatenated along the last dimension
fw_out = self.create_transform_tensor(
np.stack([x.tensor for x in output_ts[:num_timestep]], ts_axis)
)
ops.append(tfl.PackOperator(output_ts[:num_timestep], [fw_out], num_timestep, axis=ts_axis))
bw_out = self.create_transform_tensor(
np.stack([x.tensor for x in output_ts[num_timestep:]], ts_axis)
)
ops.append(tfl.PackOperator(output_ts[num_timestep:], [bw_out], num_timestep, axis=ts_axis))
ops.append(tfl.ConcatenationOperator([fw_out, bw_out], outputs, axis=2))
elif layer_idx != num_layers - 1:
# Reusing unpacked tensors for the logic in the next layer
unpacked_tensors[outputs[0].name] = output_ts
else:
                    # For the last layer, we have to pack the timestep outputs together
ops.append(tfl.PackOperator(output_ts, outputs, len(output_ts), axis=ts_axis))
elif bidirectional:
if not self.map_bilstm_to_lstm:
ops.append(
tfl.BidirectionalSequenceLstmOperator(
inputs,
outputs,
fusedActivationFunction=tfl_schema.ActivationFunctionType.TANH,
timeMajor=not batch_first,
mergeOutputs=True,
asymmetricQuantizeInputs=self.hybrid_asymmetric_inputs,
)
)
else:
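                    # map_bilstm_to_lstm: emulate the bidirectional op with two
                    # unidirectional LSTMs; the backward pass runs on a time-reversed
                    # input and its output is reversed back before concatenation.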
fw_i_end = input_start_indices[-1]
fw_s_start = state_start_index
fw_s_end = state_start_index + len(state_kinds)
fw_pad = num_input_tensors // 2 - fw_s_end
fw_lstm_inputs = (
inputs[:fw_i_end] + inputs[fw_s_start:fw_s_end] + [tfl.OptionalTensorInstance] * fw_pad
)
fw_out, bw_out = [
self.create_transform_tensor(t, quantization=outputs[0].quantization)
for t in np.split(outputs[0].tensor, 2, -1)
]
ops.append(
tfl.UnidirectionalSequenceLstmOperator(
fw_lstm_inputs,
[fw_out],
fusedActivationFunction=tfl_schema.ActivationFunctionType.TANH,
timeMajor=not batch_first,
asymmetricQuantizeInputs=self.hybrid_asymmetric_inputs,
)
)
time_dim = 1 if batch_first else 0
bw_in = self.create_transform_tensor(np.flip(current_input.tensor, time_dim))
bw_dim = self.create_attr_tensor(np.array([time_dim], dtype='int32'))
ops.append(tfl.ReverseV2Operator([current_input, bw_dim], [bw_in]))
bw_raw_out = self.create_transform_tensor(np.flip(bw_out.tensor, time_dim))
bw_o_start = input_start_indices[-1]
bw_o_end = state_start_index
bw_s_start = state_start_index + len(state_kinds)
bw_s_end = state_start_index + len(state_kinds) * num_directions
bw_pad = num_input_tensors // 2 - bw_s_end
bw_lstm_inputs = (
[bw_in]
+ inputs[bw_o_start:bw_o_end]
+ inputs[bw_s_start:bw_s_end]
+ [tfl.OptionalTensorInstance] * bw_pad
)
ops.append(
tfl.UnidirectionalSequenceLstmOperator(
bw_lstm_inputs,
[bw_raw_out],
fusedActivationFunction=tfl_schema.ActivationFunctionType.TANH,
timeMajor=not batch_first,
asymmetricQuantizeInputs=self.hybrid_asymmetric_inputs,
)
)
ops.append(tfl.ReverseV2Operator([bw_raw_out, bw_dim], [bw_out]))
ops.append(tfl.ConcatenationOperator([fw_out, bw_out], outputs, axis=2))
else:
ops.append(
tfl.UnidirectionalSequenceLstmOperator(
inputs,
outputs,
fusedActivationFunction=tfl_schema.ActivationFunctionType.TANH,
timeMajor=not batch_first,
asymmetricQuantizeInputs=self.hybrid_asymmetric_inputs,
)
)
current_input = outputs[0]
params_offset += params_step * num_directions
if self.unroll_rnn:
state_outputs = self.to_tfl_tensors(self.output_names[1:], self.output_tensors[1:])
for i, (orig, new) in enumerate(zip(tf_in_state_tensors, tf_out_state_tensors)):
if orig is not None:
pack_op = tfl.PackOperator(new, state_outputs[i : i + 1], len(new), 0)
pack_op.extra_hints['warn_on_unused'] = False
ops.append(pack_op)
else:
ops[-1].extra_hints['cell_output'] = self.output_names[-1]
common_names = set(self.output_names[1:]) & set(graph_converter.outputs)
assert len(common_names) == 0, (
f"Please remove the LSTM state outputs ({common_names}) from the model. Alternatively, you can try"
" unroll_rnn=True"
)
for op in ops:
graph_converter.add_operator(op)
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor, hidden_state_tensors, params_tensors = self.input_tensors[:3]
has_biases, num_layers, dropout, is_train, bidirectional, batch_first = self.input_tensors[3:]
self.parse_common(
input_tensor,
hidden_state_tensors,
params_tensors,
has_biases,
num_layers,
dropout,
is_train,
bidirectional,
batch_first,
graph_converter,
)
class ATenGruOperator(ATenGruSchema):
def gru_input_helper(
self, input_tensors, params_tensors, has_biases, param_start_index, input_start_index, layer_idx, suffix
):
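        # PyTorch packs the GRU weights as [W_ir | W_iz | W_in] along dim 0; chunk them
        # into the reset (r), update (z) and new (n) gate tensors.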
wir, wiz, win = torch.chunk(params_tensors[param_start_index], 3, 0)
whr, whz, whn = torch.chunk(params_tensors[param_start_index + 1], 3, 0)
wr = torch.cat((wir, whr), -1)
wz = torch.cat((wiz, whz), -1)
# [2*n_output, n_input+n_output]
input_tensors[input_start_index] = self.create_attr_tensor(torch.cat((wr, wz), 0))
# [n_output, n_input+n_output]
input_tensors[input_start_index + 2] = self.create_attr_tensor(torch.cat((win, whn), -1))
w_i_list = [self.create_attr_tensor(wir), self.create_attr_tensor(wiz), self.create_attr_tensor(win)]
w_r_list = [self.create_attr_tensor(whr), self.create_attr_tensor(whz), self.create_attr_tensor(whn)]
if has_biases:
assert params_tensors[param_start_index + 2].dtype == torch.float32
assert params_tensors[param_start_index + 3].dtype == torch.float32
bir, biz, bin = torch.chunk(params_tensors[param_start_index + 2], 3, 0)
bhr, bhz, bhn = torch.chunk(params_tensors[param_start_index + 3], 3, 0)
br = torch.cat((bir, bhr), -1)
bz = torch.cat((biz, bhz), -1)
input_tensors[input_start_index + 1] = self.create_attr_tensor(torch.cat((br, bz), -1)) # [2*n_output]
input_tensors[input_start_index + 3] = self.create_attr_tensor(torch.cat((bin, bhn), -1)) # [n_output]
b_i_list = [self.create_attr_tensor(bir), self.create_attr_tensor(biz), self.create_attr_tensor(bin)]
b_r_list = [self.create_attr_tensor(bhr), self.create_attr_tensor(bhz), self.create_attr_tensor(bhn)]
else:
bir = torch.zeros(input_tensors[input_start_index + 2].shape[0])
biz = torch.zeros_like(bir)
bin = torch.zeros_like(biz)
bhr = torch.zeros_like(bin)
bhz = torch.zeros_like(bhr)
bhn = torch.zeros_like(bhz)
input_tensors[input_start_index + 1] = self.create_attr_tensor(
torch.zeros(input_tensors[input_start_index].shape[0], dtype=torch.float32)
)
input_tensors[input_start_index + 3] = self.create_attr_tensor(
torch.zeros(input_tensors[input_start_index + 2].shape[0], dtype=torch.float32)
)
b_i_list = [self.create_attr_tensor(bir), self.create_attr_tensor(biz), self.create_attr_tensor(bin)]
b_r_list = [self.create_attr_tensor(bhr), self.create_attr_tensor(bhz), self.create_attr_tensor(bhn)]
return w_i_list, w_r_list, b_i_list, b_r_list
def gru_hidden_state_helper(
self,
input_tensors,
hidden_state_tensor,
input_index,
num_directions,
direction_idx,
num_layers,
layer_idx,
suffix,
state_type,
tf_state_tensors,
):
tf_state_tensor = tf_state_tensors[0]
assert hidden_state_tensor.dim() == 3
slice_idx = layer_idx * num_directions + direction_idx
if tf_state_tensor[slice_idx] is None:
input_tensors[input_index] = self.create_attr_tensor(hidden_state_tensor[slice_idx])
input_tensors[input_index].is_variable = True
else:
assert self.unroll_rnn, "Input state tensors are only supported when unroll_rnn=True is specified"
input_tensors[input_index] = tf_state_tensor[slice_idx]
def parse_common(
self,
input_tensor,
hidden_state_tensor,
params_tensors,
has_biases,
num_layers,
dropout,
is_train,
bidirectional,
batch_first,
graph_converter,
):
assert is_train in (False, 0)
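        # TFLite has no fused GRU op, so the GRU is always unrolled into primitive ops.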
self.unroll_rnn = True
expected_num_params = 2 * num_layers
params_step = 2
if has_biases:
expected_num_params *= 2
params_step *= 2
if bidirectional:
expected_num_params *= 2
assert (
len(params_tensors) == expected_num_params
        ), f'Wrong number of params in GRU: got {len(params_tensors)}, expected {expected_num_params}'
num_input_tensors = 7
num_directions = 1
state_start_index = [1, 8]
if bidirectional:
num_input_tensors *= 2
num_directions *= 2
suffixes = ["_fw", "_bw"]
state_kinds = ["hidden"]
param_start_indices = [0, params_step]
input_start_indices = [2, 9]
ops = []
name = self.input_names[1]
tf_in_state_tensors = [
self.find_or_create_input(1, graph_converter) if name in graph_converter.tensor_map else None
]
tf_state_tensors = []
unpacked_tensors = {}
for t in tf_in_state_tensors:
if t is not None and self.unroll_rnn:
tensors = [
self.create_transform_tensor(np.squeeze(x, 0))
for x in np.split(t.tensor, num_directions * num_layers, 0)
]
tf_state_tensors.append(tensors)
ops.append(tfl.UnpackOperator([t], tensors, len(tensors), 0))
else:
tf_state_tensors.append([None] * num_directions * num_layers)
current_input = self.find_or_create_input(0, graph_converter)
gru_output = self.to_tfl_tensors(self.output_names[:1], self.output_tensors[:1])[0]
params_offset = 0
tf_out_state_tensors = [[]]
for layer_idx in range(num_layers):
inputs = [current_input] + [tfl.OptionalTensorInstance] * (num_input_tensors - 1)
for direction_idx in range(num_directions):
w_i_list, w_r_list, b_i_list, b_r_list = self.gru_input_helper(
inputs,
params_tensors,
has_biases,
params_offset + param_start_indices[direction_idx],
input_start_indices[direction_idx],
layer_idx,
suffixes[direction_idx],
)
self.gru_hidden_state_helper(
inputs,
hidden_state_tensor,
state_start_index[direction_idx],
num_directions,
direction_idx,
num_layers,
layer_idx,
suffixes[direction_idx],
state_kinds[0],
tf_state_tensors,
)
if layer_idx == num_layers - 1:
layer_output = gru_output
else:
output_shape = list(input_tensor.shape)
output_shape[-1] = inputs[4].shape[0] * num_directions
layer_output = self.create_transform_tensor(np.empty(output_shape, dtype=inputs[0].dtype))
outputs = [layer_output]
if self.unroll_rnn:
ts_axis = 1 if batch_first else 0
num_timestep = inputs[0].shape[ts_axis]
if inputs[0].name in unpacked_tensors:
input_ts = unpacked_tensors[inputs[0].name]
else:
input_ts = [
self.create_transform_tensor(np.squeeze(x, ts_axis))
for x in np.split(inputs[0].tensor, num_timestep, ts_axis)
]
ops.append(tfl.UnpackOperator([inputs[0]], input_ts, num_timestep, ts_axis))
strides = [1, -1]
output_ts = []
for direction_idx in range(num_directions):
w_i_list, w_r_list, b_i_list, b_r_list = self.gru_input_helper(
inputs,
params_tensors,
has_biases,
params_offset + param_start_indices[direction_idx],
input_start_indices[direction_idx],
layer_idx,
suffixes[direction_idx],
)
state_start = state_start_index[direction_idx]
h = inputs[state_start]
stride = strides[direction_idx]
# Skip some computations for the first timestep
compute_h = h.buffer is None or np.any(h.tensor)
stacked_hs = []
for i, t in enumerate(input_ts[::stride]):
input_mm_list = []
if not self.separated_rnn_gate_calc:
wir, wiz, win = w_i_list
whr, whz, whn = w_r_list
bir, biz, bin = b_i_list
bhr, bhz, bhn = b_r_list
w_i = self.create_attr_tensor(np.concatenate([wir.tensor, wiz.tensor, win.tensor], 0))
w_h = self.create_attr_tensor(np.concatenate([whr.tensor, whz.tensor, whn.tensor], 0))
b_i = self.create_attr_tensor(np.concatenate([bir.tensor, biz.tensor, bin.tensor], 0))
b_h = self.create_attr_tensor(np.concatenate([bhr.tensor, bhz.tensor, bhn.tensor], 0))
input_mm = self.create_transform_tensor(
np.matmul(t.tensor, np.transpose(w_i.tensor, [1, 0])) + b_i.tensor
)
hidden_mm = self.create_transform_tensor(
np.matmul(h.tensor, np.transpose(w_h.tensor, [1, 0])) + b_h.tensor
)
ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm]))
ops.append(tfl.FullyConnectedOperator([h, w_h, b_h], [hidden_mm]))
left_in = np.split(input_mm.tensor, 3, axis=1)
dim_tensor = self.create_attr_tensor(np.array(1, dtype='int32'))
                            split_left_in = [self.create_transform_tensor(t) for t in left_in]
                            ops.append(tfl.SplitOperator([dim_tensor, input_mm], split_left_in, 3))
                            right_in = np.split(hidden_mm.tensor, 3, axis=-1)
                            split_right_in = [self.create_transform_tensor(t) for t in right_in]
                            ops.append(tfl.SplitOperator([dim_tensor, hidden_mm], split_right_in, 3))
                            rgate_left_in, zgate_left_in, ngate_left_in = split_left_in
                            rgate_right_in, zgate_right_in, ngate_right_in_b = split_right_in
rgate_in = self.create_transform_tensor(rgate_left_in.tensor + rgate_right_in.tensor)
ops.append(tfl.AddOperator([rgate_left_in, rgate_right_in], [rgate_in]))
rgate_out = self.create_transform_tensor(
torch.sigmoid(torch.from_numpy(rgate_in.tensor)).numpy()
)
ops.append(tfl.LogisticOperator([rgate_in], [rgate_out]))
zgate_in = self.create_transform_tensor(zgate_left_in.tensor + zgate_right_in.tensor)
ops.append(tfl.AddOperator([zgate_left_in, zgate_right_in], [zgate_in]))
zgate_out = self.create_transform_tensor(
torch.sigmoid(torch.from_numpy(zgate_in.tensor)).numpy()
)
ops.append(tfl.LogisticOperator([zgate_in], [zgate_out]))
ngate_right_in = self.create_transform_tensor(rgate_out.tensor * ngate_right_in_b.tensor)
ops.append(tfl.MulOperator([rgate_out, ngate_right_in_b], [ngate_right_in]))
ngate_in = self.create_transform_tensor(ngate_left_in.tensor + ngate_right_in.tensor)
ops.append(tfl.AddOperator([ngate_left_in, ngate_right_in], [ngate_in]))
ngate_out = self.create_transform_tensor(
torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()
)
ops.append(tfl.TanhOperator([ngate_in], [ngate_out]))
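                            # GRU state update: h_t = (1 - z) * n + z * h_{t-1}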
constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32))
h_left_0 = self.create_transform_tensor(constant_tensor.tensor - zgate_out.tensor)
ops.append(tfl.SubOperator([constant_tensor, zgate_out], [h_left_0]))
h_left = self.create_transform_tensor(h_left_0.tensor * ngate_out.tensor)
ops.append(tfl.MulOperator([h_left_0, ngate_out], [h_left]))
if i != 0 or compute_h:
h_right = self.create_transform_tensor(zgate_out.tensor * h.tensor)
ops.append(tfl.MulOperator([zgate_out, h], [h_right]))
h = self.create_transform_tensor(h_left.tensor + h_right.tensor)
ops.append(tfl.AddOperator([h_left, h_right], [h]))
elif i == 0 and not compute_h:
h = h_left
stacked_hs.append(h)
else:
for j, (w_i, b_i) in enumerate(zip(w_i_list, b_i_list)):
input_mm = self.create_transform_tensor(
np.matmul(t.tensor, np.transpose(w_i.tensor, [1, 0])) + b_i.tensor
)
ops.append(tfl.FullyConnectedOperator([t, w_i, b_i], [input_mm]))
input_mm_list.append(input_mm)
if i != 0 or compute_h:
hidden_mm_list = []
for j, (w_r, b_r) in enumerate(zip(w_r_list, b_r_list)):
hidden_mm = self.create_transform_tensor(
np.matmul(h.tensor, np.transpose(w_r.tensor, [1, 0])) + b_r.tensor
)
ops.append(tfl.FullyConnectedOperator([h, w_r, b_r], [hidden_mm]))
hidden_mm_list.append(hidden_mm)
else:
hidden_mm_list = b_r_list
# calculate r,z,n gates
rgate_in = self.create_transform_tensor(input_mm_list[0].tensor + hidden_mm_list[0].tensor)
ops.append(tfl.AddOperator([input_mm_list[0], hidden_mm_list[0]], [rgate_in]))
zgate_in = self.create_transform_tensor(input_mm_list[1].tensor + hidden_mm_list[1].tensor)
ops.append(tfl.AddOperator([input_mm_list[1], hidden_mm_list[1]], [zgate_in]))
zgate_out = self.create_transform_tensor(
torch.sigmoid(torch.from_numpy(zgate_in.tensor)).numpy()
)
ops.append(tfl.LogisticOperator([zgate_in], [zgate_out]))
rgate_out = self.create_transform_tensor(
torch.sigmoid(torch.from_numpy(rgate_in.tensor)).numpy()
)
ops.append(tfl.LogisticOperator([rgate_in], [rgate_out]))
ngate_in_hside = self.create_transform_tensor(rgate_out.tensor * hidden_mm_list[2].tensor)
ops.append(tfl.MulOperator([rgate_out, hidden_mm_list[2]], [ngate_in_hside]))
ngate_in = self.create_transform_tensor(input_mm_list[2].tensor + ngate_in_hside.tensor)
ops.append(tfl.AddOperator([input_mm_list[2], ngate_in_hside], [ngate_in]))
ngate_out = self.create_transform_tensor(
torch.tanh(torch.from_numpy(ngate_in.tensor)).numpy()
)
ops.append(tfl.TanhOperator([ngate_in], [ngate_out]))
constant_tensor = self.create_attr_tensor(torch.tensor(1, dtype=torch.float32))
h_left_0 = self.create_transform_tensor(constant_tensor.tensor - zgate_out.tensor)
ops.append(tfl.SubOperator([constant_tensor, zgate_out], [h_left_0]))
h_left = self.create_transform_tensor(h_left_0.tensor * ngate_out.tensor)
ops.append(tfl.MulOperator([h_left_0, ngate_out], [h_left]))
if i != 0 or compute_h:
h_right = self.create_transform_tensor(zgate_out.tensor * h.tensor)
ops.append(tfl.MulOperator([zgate_out, h], [h_right]))
h = self.create_transform_tensor(h_left.tensor + h_right.tensor)
ops.append(tfl.AddOperator([h_left, h_right], [h]))
elif i == 0 and not compute_h:
h = h_left
stacked_hs.append(h)
tf_out_state_tensors[0].append(h)
output_ts.extend(stacked_hs[::stride])
if bidirectional:
fw_out = self.create_transform_tensor(
np.stack([x.tensor for x in output_ts[:num_timestep]], ts_axis)
)
ops.append(tfl.PackOperator(output_ts[:num_timestep], [fw_out], num_timestep, axis=ts_axis))
bw_out = self.create_transform_tensor(
                        np.stack([x.tensor for x in output_ts[num_timestep:]], ts_axis)
)
ops.append(tfl.PackOperator(output_ts[num_timestep:], [bw_out], num_timestep, axis=ts_axis))
ops.append(tfl.ConcatenationOperator([fw_out, bw_out], outputs, axis=2))
elif layer_idx != num_layers - 1:
# Reusing unpacked tensors for the logic in the next layer
unpacked_tensors[outputs[0].name] = output_ts
else:
                    # For the last layer, we have to pack the timestep outputs together
ops.append(tfl.PackOperator(output_ts, outputs, len(output_ts), axis=ts_axis))
current_input = outputs[0]
params_offset += params_step * num_directions
if self.unroll_rnn:
state_outputs = self.to_tfl_tensors(self.output_names[1:], self.output_tensors[1:])
for i, (orig, new) in enumerate(zip(tf_in_state_tensors, tf_out_state_tensors)):
if orig is not None:
pack_op = tfl.PackOperator(new, state_outputs[i : i + 1], len(new), 0)
pack_op.extra_hints['warn_on_unused'] = False
ops.append(pack_op)
for op in ops:
graph_converter.add_operator(op)
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor, hidden_state_tensor, params_tensors = self.input_tensors[:3]
has_biases, num_layers, dropout, is_train, bidirectional, batch_first = self.input_tensors[3:]
self.parse_common(
input_tensor,
hidden_state_tensor,
params_tensors,
has_biases,
num_layers,
dropout,
is_train,
bidirectional,
batch_first,
graph_converter,
)
class ATenBatchNormOperator(ATenBatchNormSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
eps = self.input_tensors[args['eps']]
# weight
if self.input_tensors[1] is None:
self.input_names[1] = self.get_unique_attr_name()
self.input_tensors[1] = torch.ones(self.input_tensors[0].size(1), dtype=torch.float32)
# bias
if self.input_tensors[2] is None:
self.input_names[2] = self.get_unique_attr_name()
self.input_tensors[2] = torch.zeros(self.input_tensors[0].size(1), dtype=torch.float32)
# running mean & var
assert (
self.input_tensors[3] is not None and self.input_tensors[4] is not None
), "Running mean and variance should not be None for aten::batch_norm. Otherwise, use LayerNorm instead."
inputs = [self.find_or_create_input(i, graph_converter) for i in range(5)]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.BatchNormOperator(inputs, outputs, eps))
class ATenConstantPadNdOperator(ATenConstantPadNdSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
pads = self.input_tensors[1]
constant_value = self.input_tensors[2]
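        # PyTorch pad specs are (begin, end) pairs starting from the innermost dimension;
        # TFLite expects an [ndim, 2] array in regular dimension order, hence the reshape,
        # the zero-fill for the leading dims and the flip.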
orig_pad = np.array(pads, dtype='int32').reshape(-1, 2)
pad_fill = np.zeros((input_tensor.tensor.ndim - orig_pad.shape[0], 2), dtype='int32')
pad_arr = np.flip(np.concatenate((orig_pad, pad_fill)), 0)
pad_tensor = self.create_attr_tensor(pad_arr)
inputs = [input_tensor, pad_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
if constant_value not in (0, 0.0):
output = outputs[0]
if output.quantization is None:
constant_arr = np.array([constant_value], dtype='float32')
else:
float_arr = torch.tensor([constant_value], dtype=torch.float32)
constant_arr = torch.quantize_per_tensor(
float_arr, output.quantization.scale, output.quantization.zero_point, torch.quint8
)
inputs.append(self.create_attr_tensor(constant_arr))
graph_converter.add_operator(tfl.Padv2Operator(inputs, outputs))
else:
graph_converter.add_operator(tfl.PadOperator(inputs, outputs))
class ATenUpsampleNearest2dOperator(ATenUpsampleNearest2dSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
output_size = self.input_tensors[1]
if output_size is None:
scale_factors = np.array(self.input_tensors[2], dtype='float64')
input_sizes = np.array(input_tensor.shape[2:], dtype='float64')
output_size = (input_sizes * scale_factors).astype('int32')
output_sizes = self.create_attr_tensor(np.array(output_size, dtype='int32'))
inputs = [input_tensor, output_sizes]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops = [tfl.ResizeNearestNeighborOperator(inputs, outputs, halfPixelCenters=False)]
ops = self.wrap_ops_with_nhwc_nchw_transposes(ops)
for op in ops:
graph_converter.add_operator(op)
class ATenUpsampleBilinear2dOperator(ATenUpsampleBilinear2dSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
output_size = self.input_tensors[1]
if output_size is None:
scale_factors = np.array(self.input_tensors[3], dtype='float64')
input_sizes = np.array(input_tensor.shape[2:], dtype='float64')
output_size = (input_sizes * scale_factors).astype('int32')
output_sizes = self.create_attr_tensor(np.array(output_size, dtype='int32'))
align_corners = self.input_tensors[2] in (True, 1)
half_pixel_centers = not align_corners
inputs = [input_tensor, output_sizes]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops = [tfl.ResizeBilinearOperator(inputs, outputs, align_corners, half_pixel_centers)]
ops = self.wrap_ops_with_nhwc_nchw_transposes(ops)
for op in ops:
graph_converter.add_operator(op)
class ATenAvgPool2dOperator(ATenAvgPool2dSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
inputs = [self.find_or_create_input(0, graph_converter)]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
kernel_h, kernel_w = self.input_tensors[1]
stride_h, stride_w = self.input_tensors[2] or (kernel_h, kernel_w)
padding_h, padding_w = self.input_tensors[3]
ceil_mode = self.input_tensors[4] in (True, 1)
count_include_pad = self.input_tensors[5] in (True, 1)
divisor_override = self.input_tensors[6]
assert (
divisor_override is None or divisor_override == kernel_h == kernel_w
), "Only divisor_override == kernel_h == kernel_w is supported"
padding = tfl_schema.Padding.VALID
avgpool_op = tfl.AveragePool2dOperator(inputs, outputs, padding, stride_w, stride_h, kernel_w, kernel_h)
ops = self.wrap_ops_with_nhwc_nchw_transposes([avgpool_op])
self.handle_padding(padding_h, padding_w, 1, ops, ceil_mode)
if not count_include_pad:
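            # With an explicit pad op in front, TFLite's average divides by the full
            # kernel area including padded zeros, so rescale by kernel_area / valid_count
            # (the reciprocal of avg_pool over a ones tensor) to emulate
            # count_include_pad=False.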
mask = 1.0 / torch.nn.functional.avg_pool2d(
torch.ones_like(self.input_tensors[0]),
(kernel_h, kernel_w),
(stride_h, stride_w),
(padding_h, padding_w),
ceil_mode,
count_include_pad=True,
)
mask_permuted = mask.permute(0, 2, 3, 1)
mask_t = self.create_attr_tensor(mask_permuted)
before_mask = outputs[0].tensor / mask_permuted
before_mask_t = self.create_transform_tensor(before_mask)
actual_out = ops[-2].outputs[0]
ops[-2].outputs[0] = before_mask_t
ops.insert(-1, tfl.MulOperator([before_mask_t, mask_t], [actual_out]))
for op in ops:
graph_converter.add_operator(op)
class ATenAdaptiveAvgPool2dOperator(ATenAdaptiveAvgPool2dSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
output_h, output_w = self.input_tensors[1]
dim_h, dim_w = input_tensor.shape[2:]
assert (
dim_h % output_h == 0 and dim_w % output_w == 0
), f'not supported: input dim: [{dim_h}, {dim_w}], output size: [{output_h}, {output_w}]'
assert input_tensor.tensor.ndim == 4, 'Only 4D input is supported'
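        # A 1x1 target reduces to a mean over the spatial dims; any other evenly
        # divisible target is expressed as a plain average pool with derived
        # stride and kernel sizes.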
ops = []
dims = self.create_attr_tensor(np.array([1, 2], dtype='int32'))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
if output_h == 1 and output_w == 1:
inputs = [input_tensor, dims]
ops.append(tfl.MeanOperator(inputs, outputs, True))
else:
inputs = [input_tensor]
padding = tfl_schema.Padding.VALID
stride_h, stride_w = dim_h // output_h, dim_w // output_w
kernel_h, kernel_w = dim_h - (output_h - 1) * stride_h, dim_w - (output_w - 1) * stride_w
ops.append(tfl.AveragePool2dOperator(inputs, outputs, padding, stride_w, stride_h, kernel_w, kernel_h))
ops = self.wrap_ops_with_nhwc_nchw_transposes(ops)
for op in ops:
graph_converter.add_operator(op)
class ATenAdaptiveMaxPool2dOperator(ATenAdaptiveMaxPool2dSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
output_h, output_w = self.input_tensors[1]
dim_h, dim_w = input_tensor.shape[2:]
assert (
dim_h % output_h == 0 and dim_w % output_w == 0
), f'not supported: input dim: [{dim_h}, {dim_w}], output size: [{output_h}, {output_w}]'
assert input_tensor.tensor.ndim == 4, 'Only 4D input is supported'
ops = []
dims = self.create_attr_tensor(np.array([1, 2], dtype='int32'))
log.warning(
            'Ops like `F.adaptive_max_pool2d` have multiple outputs. However, only the first '
            'output will be preserved in our converter. If you need the indices tensor, '
            'please use `torch.argmax` instead.'
)
outputs = self.to_tfl_tensors(self.output_names[:1], self.output_tensors[:1])
if output_h == 1 and output_w == 1:
inputs = [input_tensor, dims]
ops.append(tfl.ReduceMaxOperator(inputs, outputs, True))
else:
inputs = [input_tensor]
padding = tfl_schema.Padding.VALID
stride_h, stride_w = dim_h // output_h, dim_w // output_w
kernel_h, kernel_w = dim_h - (output_h - 1) * stride_h, dim_w - (output_w - 1) * stride_w
ops.append(tfl.MaxPool2dOperator(inputs, outputs, padding, stride_w, stride_h, kernel_w, kernel_h))
ops = self.wrap_ops_with_nhwc_nchw_transposes(ops)
for op in ops:
graph_converter.add_operator(op)
class ATenLeakyReluOperator(ATenLeakyReluSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.LeakyReluOperator, graph_converter, self.input_tensors[1])
class ATenEluOperator(ATenEluSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
assert all(x == 1.0 for x in self.input_tensors[1:]), "Only alpha == scale == input_scale == 1 is supported"
self.elementwise_unary(tfl.EluOperator, graph_converter)
class ATenReciprocalOperator(ATenReciprocalSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
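        # aten::reciprocal is lowered as 1 / x: prepend a constant one and emit a Div.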
old_inp = self.input_tensors[0].to(dtype=torch.float32)
self.input_tensors.clear()
self.input_tensors.insert(0, torch.tensor([1], dtype=old_inp.dtype))
self.input_names.insert(0, self.get_unique_attr_name())
self.elementwise_binary(tfl.DivOperator, graph_converter, False)
class ATenRsqrtOperator(ATenRsqrtSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.RsqrtOperator, graph_converter)
class ATenHardtanhOperator(ATenHardtanhSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
min_value, max_value = self.input_tensors[1:]
if min_value == 0 and max_value == 6:
self.elementwise_unary(tfl.Relu6Operator, graph_converter)
else:
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
inter_tensor = self.create_transform_tensor(
np.where(input_tensor.tensor > min_value, input_tensor.tensor, min_value)
)
min_value_tensor = self.create_attr_tensor(np.array([min_value], dtype=input_tensor.dtype))
ops.append(tfl.MaximumOperator([input_tensor, min_value_tensor], [inter_tensor]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
max_value_tensor = self.create_attr_tensor(np.array([max_value], dtype=input_tensor.dtype))
ops.append(tfl.MinimumOperator([inter_tensor, max_value_tensor], outputs))
for op in ops:
graph_converter.add_operator(op)
class ATenSubOperator(ATenSubSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
other = self.input_tensors[1]
alpha = self.input_tensors[-1]
assert alpha == 1, "Only alpha == 1 is supported"
if type(other) in (int, float, bool):
self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
elif not isinstance(other, torch.Tensor):
assert False, "other should have type int, float, tensor in aten::sub(input, other)"
self.elementwise_binary(tfl.SubOperator, graph_converter, True)
class ATenRsubOperator(ATenRsubSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
other = self.input_tensors[1]
alpha = self.input_tensors[-1]
assert alpha == 1, "Only alpha == 1 is supported"
if type(other) in (int, float, bool):
self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
elif not isinstance(other, torch.Tensor):
assert False, "other should have type int, float, tensor in aten::rsub(input, other)"
# Swap the first two input tensors and their names
self.input_names[0], self.input_names[1] = self.input_names[1], self.input_names[0]
self.input_tensors[0], self.input_tensors[1] = self.input_tensors[1], self.input_tensors[0]
self.elementwise_binary(tfl.SubOperator, graph_converter, True)
class ATenTransposeOperator(ATenTransposeSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
dim_1, dim_2 = self.input_tensors[1:]
input_tensor = self.find_or_create_input(0, graph_converter)
perm = np.arange(input_tensor.tensor.ndim, dtype='int32')
perm[dim_1], perm[dim_2] = perm[dim_2], perm[dim_1]
perm_tensor = self.create_attr_tensor(perm)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.TransposeOperator([input_tensor, perm_tensor], outputs))
class ATenMulOperator(ATenMulSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
other = self.input_tensors[1]
if type(other) in (int, float):
self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
elif not isinstance(other, torch.Tensor):
assert False, "other should have type int, float, tensor in aten::mul(input, other)"
self.elementwise_binary(tfl.MulOperator, graph_converter, True)
class ATenDequantizeOperator(ATenDequantizeSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.DequantizeOperator, graph_converter)
class ATenQuantizePerTensorOperator(ATenQuantizePerTensorSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.QuantizeOperator, graph_converter)
class ATenFakeQuantizePerTensorAffineOperator(ATenFakeQuantizePerTensorAffineSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.passthrough(graph_converter)
class ATenFakeQuantizePerChannelAffineOperator(ATenFakeQuantizePerChannelAffineSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.passthrough(graph_converter)
class ATenFlipOperator(ATenFlipSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
n_dim = self.input_tensors[0].dim()
dims = [x + n_dim if x < 0 else x for x in self.input_tensors[1]]
self.input_tensors[1] = np.array(dims, dtype='int32')
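        # ReverseV2 is emitted for one axis at a time, so multi-axis flips become a chain.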
if len(dims) == 1:
self.elementwise_binary(tfl.ReverseV2Operator, graph_converter, False)
else:
actual_input = self.find_or_create_input(0, graph_converter)
for dim in dims[:-1]:
transform = self.create_transform_tensor(np.flip(actual_input.tensor, dim))
dim_tensor = self.create_attr_tensor(np.array([dim], dtype='int32'))
graph_converter.add_operator(tfl.ReverseV2Operator([actual_input, dim_tensor], [transform]))
actual_input = transform
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
dim_tensor = self.create_attr_tensor(np.array(dims[-1:], dtype='int32'))
graph_converter.add_operator(tfl.ReverseV2Operator([actual_input, dim_tensor], outputs))
class ATenDivOperator(ATenDivSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
other = self.input_tensors[1]
if type(other) in (int, float):
self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
elif not isinstance(other, torch.Tensor):
assert False, "other should have type int, float, tensor in aten::div(input, other)"
self.elementwise_binary(tfl.DivOperator, graph_converter, True)
class ATenMeanOperator(ATenMeanSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.handle_reduce(tfl.MeanOperator, args, graph_converter, True)
class ATenPowOperator(ATenPowSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
assert self.input_tensors[0].dtype in (
torch.float32,
torch.int32,
), "Input should be tensors of type torch.float32 or torch.int32"
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = torch.tensor([self.input_tensors[1]], dtype=self.input_tensors[0].dtype)
self.elementwise_binary(tfl.PowOperator, graph_converter, True)
class ATenMaxPool2dOperator(ATenMaxPool2dSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
inputs = [self.find_or_create_input(0, graph_converter)]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
kernel_h, kernel_w = self.input_tensors[1]
stride_h, stride_w = self.input_tensors[2] or (kernel_h, kernel_w)
pad_h, pad_w = self.input_tensors[3]
dilation_h, dilation_w = self.input_tensors[4]
ceil_mode = self.input_tensors[5]
assert dilation_h == dilation_w == 1, "Only dilation == 1 is supported"
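        # SAME padding only matches PyTorch semantics when stride == 1, the padding is
        # exactly half the kernel size and ceil_mode is off; otherwise use VALID padding
        # together with an explicit pad op.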
add_pad_op = not (
stride_h == stride_w == 1 and pad_h == kernel_h // 2 and pad_w == kernel_w // 2 and not ceil_mode
)
padding = tfl_schema.Padding.SAME
if add_pad_op:
padding = tfl_schema.Padding.VALID
maxpool_op = tfl.MaxPool2dOperator(inputs, outputs, padding, stride_w, stride_h, kernel_w, kernel_h)
ops = self.wrap_ops_with_nhwc_nchw_transposes([maxpool_op])
if add_pad_op:
self.handle_padding(pad_h, pad_w, 1, ops, ceil_mode)
for op in ops:
graph_converter.add_operator(op)
class ATenMatmulOperator(ATenMatmulSchema):
def parse_common(self, node, attrs, args, graph_converter):
input_tensor, weight_tensor = [self.find_or_create_input(i, graph_converter) for i in range(2)]
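        # A 2D weight lowers to FullyConnected, which expects [out_features, in_features]
        # weights (hence the transpose); higher-rank operands fall back to BatchMatmul.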
if input_tensor.tensor.ndim >= 2 and input_tensor.tensor.ndim <= 5:
if weight_tensor.tensor.ndim == 2:
bias_tensor = self.create_attr_tensor(np.zeros(weight_tensor.shape[1], dtype='float32'))
perm = [1, 0]
perm_tensor = self.create_attr_tensor(np.array(perm, dtype='int32'))
weight_transformed = self.create_transform_tensor(np.transpose(weight_tensor.tensor, perm))
graph_converter.add_operator(tfl.TransposeOperator([weight_tensor, perm_tensor], [weight_transformed]))
inputs = [input_tensor, weight_transformed, bias_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
keep_dims = len(outputs[0].shape) > 2
graph_converter.add_operator(tfl.FullyConnectedOperator(inputs, outputs, keepNumDims=keep_dims))
elif weight_tensor.tensor.ndim >= 2 and weight_tensor.tensor.ndim <= 5:
self.elementwise_binary(tfl.BatchMatmulOperator, graph_converter, False)
else:
self.unimplemented(node, attrs, args)
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(node, attrs, args, graph_converter)
class ATenFlattenOperator(ATenFlattenSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.reshape(graph_converter)
class ATenDropoutOperator(ATenDropoutSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
train = self.input_tensors[args['train']]
if train not in (0, False):
            log.warning('aten::dropout with train=True found; it will add randomness to the converted model.')
input_tensor = self.find_or_create_input(0, graph_converter)
            assert len(input_tensor.shape) in (2, 3), "Only 2d or 3d input is supported for dropout in training mode"
assert input_tensor.quantization is None, "Only supports dropout with floating input for training mode"
p = self.input_tensors[args['p']]
ops = []
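            # Training-mode dropout: sample a 0/1 keep mask via Multinomial over the
            # probabilities [p, 1 - p], scale it by 1 / (1 - p) and multiply it with the
            # input (inverted dropout).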
if len(input_tensor.shape) == 3:
assert (
input_tensor.shape[0] == 1
), "Only supports dropout with 3d input with batch_size=1 for training mode"
batch_size = input_tensor.shape[-2]
num_samples = input_tensor.shape[-1]
logits = self.create_attr_tensor(np.log(np.array([[p, 1 - p]] * batch_size, dtype='float32')))
num_samples_tensor = self.create_attr_tensor(np.array(num_samples, dtype='int32'))
multinomial_out = self.create_transform_tensor(np.empty((batch_size, num_samples), dtype='int32'))
ops.append(tfl.MultinomialOperator([logits, num_samples_tensor], [multinomial_out]))
casted = self.create_transform_tensor(np.empty((batch_size, num_samples), dtype='float32'))
ops.append(
tfl.CastOperator(
[multinomial_out],
[casted],
tfl.numpy_tflite_dtype_mappings[str(multinomial_out.dtype)],
tfl.numpy_tflite_dtype_mappings[str(casted.dtype)],
)
)
scale = self.create_attr_tensor(np.array([1.0 / (1.0 - p)], dtype='float32'))
scaled = self.create_transform_tensor(np.empty((batch_size, num_samples), dtype='float32'))
ops.append(tfl.MulOperator([casted, scale], [scaled]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.MulOperator([input_tensor, scaled], outputs))
for op in ops:
graph_converter.add_operator(op)
else:
self.passthrough(graph_converter)
class ATenFeatureDropoutOperator(ATenFeatureDropoutSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
train = self.input_tensors[args['train']]
if train not in (0, False):
            log.warning('aten::feature_dropout with train=True found. Please check your model.')
self.run(node)
self.passthrough(graph_converter)
class ATenSoftmaxOperator(ATenSoftmaxSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
dim = self.input_tensors[1]
if dim < 0:
dim += len(self.input_tensors[0].shape)
ops = []
inputs = [self.find_or_create_input(0, graph_converter)]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
softmax_op = tfl.SoftmaxOperator(inputs, outputs, 1.0)
ops.append(softmax_op)
ops = self.wrap_ops_with_last_dim_transposes(ops, dim)
for op in ops:
graph_converter.add_operator(op)
class ATenAtan2Operator(ATenAtan2Schema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_binary(tfl.Atan2Operator, graph_converter, False)
class ATenSqrtOperator(ATenSqrtSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.SqrtOperator, graph_converter)
class ATenAddmmOperator(ATenAddmmSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
bias_tensor, input_tensor, weight_tensor = [self.find_or_create_input(i, graph_converter) for i in range(3)]
assert len(weight_tensor.shape) == 2, "Weight of AddMM should be 2D"
perm = [1, 0]
perm_tensor = self.create_attr_tensor(np.array(perm, dtype='int32'))
weight_transformed = self.create_transform_tensor(np.transpose(weight_tensor.tensor, perm))
graph_converter.add_operator(tfl.TransposeOperator([weight_tensor, perm_tensor], [weight_transformed]))
inputs = [input_tensor, weight_transformed, bias_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
keep_dims = len(outputs[0].shape) > 2
graph_converter.add_operator(tfl.FullyConnectedOperator(inputs, outputs, keepNumDims=keep_dims))
class ATenStackOperator(ATenStackSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
dim = self.input_tensors[1]
assert type(dim) is int
if dim < 0:
dim += self.input_tensors[0][0].ndim + 1
names = graph_converter.get_list_expanded_names(self.input_names[0])
orig_inputs = self.to_tfl_tensors(
names, self.input_tensors[0], graph_converter=graph_converter, non_existent_as_buffer=True
)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
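        # Pack requires every input to share the output's quantization parameters;
        # otherwise fall back to expand_dims (Reshape) followed by Concatenation.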
as_unpack = True
for it in orig_inputs:
if it.quantization is not None and outputs[0].quantization is not None:
if (
it.quantization.scale != outputs[0].quantization.scale
or it.quantization.zero_point != outputs[0].quantization.zero_point
or it.quantization.dim != outputs[0].quantization.dim
):
as_unpack = False
break
elif it.quantization is not None or outputs[0].quantization is not None:
as_unpack = False
break
ops = []
if as_unpack:
ops.append(tfl.PackOperator(orig_inputs, outputs, len(orig_inputs), dim))
else:
inputs = [
self.create_transform_tensor(np.expand_dims(orig_inputs[i].tensor, dim))
for i in range(len(orig_inputs))
]
attrs = [self.create_attr_tensor(np.array(t.shape, dtype='int32')) for t in inputs]
ops.extend(
[
tfl.ReshapeOperator([orig, attr], [new], new.tensor.shape)
for orig, new, attr in zip(orig_inputs, inputs, attrs)
]
)
for op in ops:
op.extra_hints['direction'] = 'up'
ops.append(tfl.ConcatenationOperator(inputs, outputs, dim))
for op in ops:
graph_converter.add_operator(op)
class ATenCatOperator(ATenCatSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
dim = self.input_tensors[1]
assert type(dim) is int
if dim < 0:
dim += self.input_tensors[0][0].ndim
names = graph_converter.get_list_expanded_names(self.input_names[0])
inputs = self.to_tfl_tensors(
names, self.input_tensors[0], graph_converter=graph_converter, non_existent_as_buffer=True
)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.ConcatenationOperator(inputs, outputs, dim))
class ATenPreluOperator(ATenPreluSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
alpha = self.input_tensors[1]
weight_c = alpha.numel()
input_c = self.input_tensors[0].shape[1]
new_shape = [input_c] + [1] * (self.input_tensors[0].ndim - 2)
alpha_tensor = self.find_or_create_input(1, graph_converter)
shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
update_name = None
if weight_c == input_c:
new_alpha = self.create_transform_tensor(np.reshape(alpha_tensor.tensor, new_shape))
graph_converter.add_operator(tfl.ReshapeOperator([alpha_tensor, shape_tensor], [new_alpha], new_shape))
elif input_c != weight_c:
new_alpha = self.create_transform_tensor(np.tile(alpha_tensor.tensor, new_shape))
if alpha_tensor.buffer is None:
graph_converter.add_operator(tfl.TileOperator([alpha_tensor, shape_tensor], [new_alpha]))
else:
store = graph_converter.get_transform_store(alpha_tensor.name, str(input_c))
if store is None:
graph_converter.add_transform_store(alpha_tensor.name, str(input_c), new_alpha.name)
update_name = new_alpha.name
new_alpha = new_alpha.tensor
else:
update_name = store
self.input_tensors[1] = new_alpha
if update_name is None:
self.input_names[1] = new_alpha.name
else:
self.input_names[1] = update_name
self.elementwise_binary(tfl.PreluOperator, graph_converter, False)
class ATenToOperator(ATenToSchema):
def parse_common(self, node, attrs, args, graph_converter):
out_type = self.output_tensors[0].dtype
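        # A float64 target is patched to float32: the Cast op is emitted with a float32
        # output, and the original output tensor is restored on the schema afterwards.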
patch = False
if out_type == torch.float64:
patch = True
out_type = torch.float32
temp_tensor = self.output_tensors[0]
self.output_tensors[0] = temp_tensor.detach().clone().to(dtype=torch.float32)
self.elementwise_unary(
tfl.CastOperator,
graph_converter,
tfl.torch_tflite_dtype_mappings[self.input_tensors[0].dtype],
tfl.torch_tflite_dtype_mappings[out_type],
)
if patch:
self.output_tensors[0] = temp_tensor
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(node, attrs, args, graph_converter)
class ATenViewOperator(ATenViewSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.reshape(graph_converter)
class ATenSinOperator(ATenSinSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.SinOperator, graph_converter)
class ATenUnsqueezeOperator(ATenUnsqueezeSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.reshape(graph_converter)
class ATenFloorOperator(ATenFloorSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.FloorOperator, graph_converter)
class ATenFloorDivideOperator(ATenFloorDivideSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = torch.tensor([self.input_tensors[1]], dtype=self.input_tensors[0].dtype)
elif self.input_tensors[1].dtype != self.input_tensors[0].dtype:
other = self.find_or_create_input(1, graph_converter)
if other.buffer is None:
new_other = self.input_tensors[1].detach().clone().to(dtype=self.input_tensors[0].dtype)
new_other_t = self.create_transform_tensor(new_other)
graph_converter.add_operator(
tfl.CastOperator(
[other],
[new_other_t],
tfl.torch_tflite_dtype_mappings[self.input_tensors[1].dtype],
tfl.torch_tflite_dtype_mappings[self.input_tensors[0].dtype],
)
)
self.input_tensors[1] = new_other
self.input_names[1] = new_other_t.name
else:
self.input_tensors[1] = self.input_tensors[1].to(dtype=self.input_tensors[0].dtype)
assert all(
(not t.is_floating_point() for t in self.input_tensors[:2])
), "floor_divide for floats is not supported"
assert all(
((t >= 0).all() for t in self.input_tensors[:2])
), "floor_divide for negative numbers is not supported"
self.elementwise_binary(tfl.FloorDivOperator, graph_converter, False)
class ATenCosOperator(ATenCosSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.CosOperator, graph_converter)
class ATenConv2dOperator(ATenConv2dSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input, weight, bias, stride, padding, dilation, groups = self.input_tensors[:7]
if bias is None:
end_index = 2
else:
end_index = 3
output_padding = [0] * 2
inputs = [self.find_or_create_input(i, graph_converter) for i in range(end_index)]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(
tfl.GenericConvOperator(inputs, outputs, stride, padding, dilation, output_padding, groups)
)
class ATenConvolutionOperator(ATenConvolutionSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
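# aten::_convolution carries an explicit `transpose` flag, so this parser
# dispatches between the generic conv and transposed-conv lowerings; a
# length-1 stride is broadcast to the rank of `padding` first.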
input, weight, bias, stride, padding, dilation, transpose, output_padding, groups = self.input_tensors[:9]
if bias is None:
end_index = 2
else:
end_index = 3
inputs = [self.find_or_create_input(i, graph_converter) for i in range(end_index)]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
if len(stride) != len(padding) and len(stride) == 1:
stride = stride * len(padding)
if transpose == 0:
graph_converter.add_operator(
tfl.GenericConvOperator(inputs, outputs, stride, padding, dilation, output_padding, groups)
)
else:
graph_converter.add_operator(
tfl.GenericTransposeConvOperator(
inputs,
outputs,
stride,
padding,
dilation,
output_padding,
groups,
self.enable_mtk_ops,
self.conv_transpose_with_bias,
)
)
class ATenSliceOperator(ATenSliceSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
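# aten::slice lowers to either SLICE or STRIDED_SLICE. Negative `dim`,
# `start` and `end` values are normalized first; when `start` or `end` is
# produced at runtime (tracked in graph_converter.constant_mapping), the
# boundary vectors are assembled in-graph via RESHAPE, CAST and CONCAT.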
input_tensor = self.find_or_create_input(0, graph_converter)
dim, start, end, step = self.input_tensors[1:]
if start is None:
start = 0
if end is None:
end = input_tensor.tensor.shape[dim]
if start < 0:
start += input_tensor.tensor.shape[dim]
if end < 0:
end += input_tensor.tensor.shape[dim]
if dim < 0:
dim += input_tensor.tensor.ndim
if start > end:
end = start
if end >= input_tensor.tensor.shape[dim]:
end = input_tensor.tensor.shape[dim]
starts = np.zeros(input_tensor.tensor.ndim, dtype='int32')
starts[dim] = start
if self.input_names[2] in graph_converter.constant_mapping:
start_t = graph_converter.constant_mapping[self.input_names[2]]
new_shape_arr = np.array((1,), dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
start_reshaped = self.create_transform_tensor(np.reshape(start_t.tensor, new_shape_arr))
graph_converter.add_operator(
tfl.ReshapeOperator([start_t, new_shape_tensor], [start_reshaped], new_shape_arr)
)
start_casted = self.create_transform_tensor(start_reshaped.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[start_reshaped],
[start_casted],
tfl.numpy_tflite_dtype_mappings[str(start_reshaped.dtype)],
tfl.numpy_tflite_dtype_mappings[str(start_casted.dtype)],
)
)
start_tensor = self.create_transform_tensor(starts)
starts_left = starts[:dim]
starts_right = starts[dim + 1 :]
starts_tensors = []
if len(starts_left) > 0:
starts_tensors.append(self.create_attr_tensor(starts_left))
starts_tensors.append(start_casted)
if len(starts_right) > 0:
starts_tensors.append(self.create_attr_tensor(starts_right))
if len(starts_tensors) > 1:
graph_converter.add_operator(tfl.ConcatenationOperator(starts_tensors, [start_tensor], 0))
else:
start_tensor = starts_tensors[0]
else:
start_tensor = self.create_attr_tensor(starts)
ends = np.array(input_tensor.tensor.shape, dtype='int32')
if step != 1 or start_tensor.buffer is None or self.input_names[3] in graph_converter.constant_mapping:
ends[dim] = end
else:
ends[dim] = end - start
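# STRIDED_SLICE consumes absolute end indices, while SLICE consumes sizes,
# so `ends[dim]` holds `end` in the strided case and `end - start` when a
# plain SLICE can be emitted (static bounds and step == 1).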
if self.input_names[3] in graph_converter.constant_mapping:
end_t = graph_converter.constant_mapping[self.input_names[3]]
new_shape_arr = np.array((1,), dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
end_reshaped = self.create_transform_tensor(np.reshape(end_t.tensor, new_shape_arr))
graph_converter.add_operator(tfl.ReshapeOperator([end_t, new_shape_tensor], [end_reshaped], new_shape_arr))
end_casted = self.create_transform_tensor(end_reshaped.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[end_reshaped],
[end_casted],
tfl.numpy_tflite_dtype_mappings[str(end_reshaped.dtype)],
tfl.numpy_tflite_dtype_mappings[str(end_casted.dtype)],
)
)
end_tensor = self.create_transform_tensor(ends)
ends_left = ends[:dim]
ends_right = ends[dim + 1 :]
ends_tensors = []
if len(ends_left) > 0:
ends_tensors.append(self.create_attr_tensor(ends_left))
ends_tensors.append(end_casted)
if len(ends_right) > 0:
ends_tensors.append(self.create_attr_tensor(ends_right))
if len(ends_tensors) > 1:
graph_converter.add_operator(tfl.ConcatenationOperator(ends_tensors, [end_tensor], 0))
else:
end_tensor = ends_tensors[0]
else:
end_tensor = self.create_attr_tensor(ends)
if step != 1 or start_tensor.buffer is None or end_tensor.buffer is None:
strides = np.ones(input_tensor.tensor.ndim, dtype='int32')
strides[dim] = step
stride_tensor = self.create_attr_tensor(strides)
inputs = [input_tensor, start_tensor, end_tensor, stride_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.StridedSliceOperator(inputs, outputs))
else:
size_tensor = end_tensor
inputs = [input_tensor, start_tensor, size_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.SliceOperator(inputs, outputs))
class ATenContiguousOperator(ATenContiguousSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.passthrough(graph_converter)
class ATenTOperator(ATenTSchema):
def parse(self, node, attrs, args, graph_converter: CommonGraph):
super().parse(node, attrs, args, graph_converter)
self.run(node)
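# aten::t transposes matrices only, so it maps to TRANSPOSE with the
# dimension order reversed; 0-d and 1-d inputs pass through unchanged.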
input_tensor = self.input_tensors[0]
dims = len(input_tensor.shape)
if dims >= 2:
perm = torch.arange(dims).flip(dims=(0,))
inputs = [self.find_or_create_input(0, graph_converter), self.create_attr_tensor(perm)]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.TransposeOperator(inputs, outputs))
else:
self.passthrough(graph_converter)
class ATenSqueezeOperator(ATenSqueezeSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.reshape(graph_converter)
class ATenReshapeOperator(ATenReshapeSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.reshape(graph_converter)
class ATenPermuteOperator(ATenPermuteSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
attr_tensor = self.create_attr_tensor(np.array(self.input_tensors[1], dtype='int32'))
inputs = [input_tensor, attr_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.TransposeOperator(inputs, outputs))
class ATenAddOperator(ATenAddSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
other = self.input_tensors[1]
alpha = self.input_tensors[-1]
assert alpha == 1, "Only alpha == 1 is supported"
if type(other) in (int, float, bool):
self.input_tensors[1] = torch.tensor([other], dtype=self.input_tensors[0].dtype)
elif not isinstance(other, torch.Tensor):
assert False, "other should have type int, float or Tensor in aten::add(input, other)"
self.elementwise_binary(tfl.AddOperator, graph_converter, True)
class ATenReluOperator(ATenReluSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.ReluOperator, graph_converter)
class ATenRelu6Operator(ATenRelu6Schema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.Relu6Operator, graph_converter)
class ATenSigmoidOperator(ATenSigmoidSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
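# The TFLite 16x8 LOGISTIC kernel requires a fixed output scale of 1/32768
# with zero point 0, so the output quantization parameters are adjusted
# here for int16.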
if self.q_type == np.int16:
self.output_tensors[0] = torch.quantize_per_tensor(
self.output_tensors[0].dequantize(),
self.output_tensors[0].q_scale() * 2,
0,
self.output_tensors[0].dtype,
)
self.elementwise_unary(tfl.LogisticOperator, graph_converter)
class ATenSelectOperator(ATenSelectSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
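# aten::select(dim, index) becomes a GATHER of a single index along `dim`
# followed by a RESHAPE that drops the now size-1 gathered dimension.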
input_tensor = self.find_or_create_input(0, graph_converter)
dim, index = self.input_tensors[1:]
assert type(dim) is int
assert type(index) is int
if dim < 0:
dim += input_tensor.tensor.ndim
if index < 0:
index += input_tensor.tensor.shape[dim]
index_tensor = self.create_attr_tensor(np.array([index], dtype='int32'))
all_out = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
gather_out = self.create_transform_tensor(
np.expand_dims(all_out.tensor, dim), quantization=all_out.quantization
)
reshape_attr = self.create_attr_tensor(self.output_tensors[0].shape)
ops = []
gather_inputs = [input_tensor, index_tensor]
gather_outputs = [gather_out]
ops.append(tfl.GatherOperator(gather_inputs, gather_outputs, dim))
reshape_inputs = [gather_out, reshape_attr]
reshape_outputs = [all_out]
reshape_op = tfl.ReshapeOperator(reshape_inputs, reshape_outputs, reshape_attr.tensor)
reshape_op.extra_hints['direction'] = 'down'
ops.append(reshape_op)
for op in ops:
graph_converter.add_operator(op)
class ATenTanhOperator(ATenTanhSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.TanhOperator, graph_converter)
class ATenEmbeddingOperator(ATenEmbeddingSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
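# An embedding lookup is simply a GATHER of rows (axis 0) from the weight
# matrix.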
weight, indices = [self.find_or_create_input(i, graph_converter) for i in range(2)]
assert weight.tensor.ndim == 2, "Only 2D weight tensors are supported"
assert indices.dtype in (np.int32, np.int64), "Only integral indices are supported"
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.GatherOperator([weight, indices], outputs, 0))
class ATenLinearOperator(ATenLinearSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
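# aten::linear maps to FULLY_CONNECTED. TFLite expects a bias input, so a
# zero bias is synthesized when PyTorch has none; keepNumDims preserves the
# leading batch dimensions for inputs with more than two dimensions.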
ops = []
input_tensor, weight_tensor, bias_tensor = self.input_tensors
if input_tensor.dim() >= 2 and input_tensor.dim() <= 5:
assert len(weight_tensor.shape) == 2, "Weight of AddMM should be 2D"
if bias_tensor is not None:
input_tensor, weight_tensor, bias_tensor = [
self.find_or_create_input(i, graph_converter) for i in range(3)
]
else:
input_tensor, weight_tensor = [self.find_or_create_input(i, graph_converter) for i in range(2)]
bias_tensor = self.create_attr_tensor(np.zeros(weight_tensor.shape[0], dtype='float32'))
inputs = [input_tensor, weight_tensor, bias_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
keep_dims = len(outputs[0].shape) > 2
ops.append(tfl.FullyConnectedOperator(inputs, outputs, keepNumDims=keep_dims))
else:
log.error(
f'aten::linear is not supported for input shape {input_tensor.shape}, '
f'weight shape {weight_tensor.shape}, '
f'bias type {type(bias_tensor).__name__}'
)
self.unimplemented(node, attrs, args)
for op in ops:
graph_converter.add_operator(op)
class ATenClampOperator(ATenClampSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(node, attrs, args, graph_converter)
def parse_common(self, node, attrs, args, graph_converter):
if type(self) is ATenClampOperator:
min_value, max_value = self.input_tensors[1:]
elif type(self) is ATenClampMinOperator:
min_value, max_value = self.input_tensors[1], None
elif type(self) is ATenClampMaxOperator:
min_value, max_value = None, self.input_tensors[1]
has_min = min_value is not None
has_max = max_value is not None
assert has_min or has_max
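# clamp(0, 6) and clamp(min=0) coincide with RELU6 and RELU. Every other
# combination lowers to MAXIMUM(input, min) and/or MINIMUM(..., max), with
# the bounds quantized to the input's scale and zero point when needed.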
if min_value == 0 and max_value == 6:
self.elementwise_unary(tfl.Relu6Operator, graph_converter)
elif min_value == 0 and not has_max:
self.elementwise_unary(tfl.ReluOperator, graph_converter)
else:
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
if has_min:
if input_tensor.quantization is not None:
min_value_arr = np.array([min_value], dtype='float32')
min_value_tensor = self.create_attr_tensor(
self.quantize_numpy(
min_value_arr,
input_tensor.quantization.scale,
input_tensor.quantization.zero_point,
input_tensor.dtype,
),
quantization=input_tensor.quantization,
)
else:
min_value_arr = np.array([min_value], dtype=input_tensor.dtype)
min_value_tensor = self.create_attr_tensor(min_value_arr)
if has_max:
if input_tensor.quantization is not None:
max_value_arr = np.array([max_value], dtype='float32')
max_value_tensor = self.create_attr_tensor(
self.quantize_numpy(
max_value_arr,
input_tensor.quantization.scale,
input_tensor.quantization.zero_point,
input_tensor.dtype,
),
quantization=input_tensor.quantization,
)
else:
max_value_arr = np.array([max_value], dtype=input_tensor.dtype)
max_value_tensor = self.create_attr_tensor(max_value_arr)
if has_min and has_max:
inter_tensor = self.create_transform_tensor(
np.maximum(input_tensor.tensor, min_value_tensor.tensor), quantization=input_tensor.quantization
)
ops.append(tfl.MaximumOperator([input_tensor, min_value_tensor], [inter_tensor]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.MinimumOperator([inter_tensor, max_value_tensor], outputs))
elif has_min:
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.MaximumOperator([input_tensor, min_value_tensor], outputs))
else:
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.MinimumOperator([input_tensor, max_value_tensor], outputs))
for op in ops:
graph_converter.add_operator(op)
class ATenClampMinOperator(ATenClampMinSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenClampOperator.parse_common(self, node, attrs, args, graph_converter)
class ATenClampMaxOperator(ATenClampMaxSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenClampOperator.parse_common(self, node, attrs, args, graph_converter)
class ATenExpOperator(ATenExpSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.ExpOperator, graph_converter)
class ATenLogOperator(ATenLogSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.LogOperator, graph_converter)
class ATenNeOperator(ATenNeSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
self.elementwise_binary(tfl.NotEqualOperator, graph_converter, True)
class ATenSoftplusOperator(ATenSoftplusSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
beta = self.input_tensors[1]
assert beta == 1.0, "Only beta=1.0 is supported for aten::softplus"
warnings.warn('threshold is ignored when transforming aten::softplus')
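# softplus(x) = log(1 + exp(x)) for beta == 1, so the op expands into
# EXP -> ADD(1) -> LOG.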
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
exp_out = self.create_transform_tensor(np.exp(input_tensor.tensor))
ops.append(tfl.ExpOperator([input_tensor], [exp_out]))
one_tensor = self.create_attr_tensor(np.ones((1,), dtype=exp_out.dtype))
add_out = self.create_transform_tensor(exp_out.tensor + one_tensor.tensor)
ops.append(tfl.AddOperator([exp_out, one_tensor], [add_out]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.LogOperator([add_out], outputs))
for op in ops:
graph_converter.add_operator(op)
class ATenLayerNormOperator(ATenLayerNormSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
normalized_shape = self.input_tensors[1]
weight_tensor = self.find_or_create_input(2, graph_converter)
bias_tensor = self.find_or_create_input(3, graph_converter)
eps = self.input_tensors[4]
ops = []
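# Layer norm is expanded into primitives over the trailing
# `normalized_shape` axes:
#   y = (x - mean(x)) / sqrt(var(x) + eps) * weight + bias
# with var(x) computed as mean(squared_difference(x, mean(x))).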
axes = [input_tensor.tensor.ndim - i for i in range(len(normalized_shape), 0, -1)]
dims_tensor = self.create_attr_tensor(np.array(axes, dtype='int32'))
mean_tensor = self.create_transform_tensor(np.mean(input_tensor.tensor, axis=tuple(axes), keepdims=True))
ops.append(tfl.MeanOperator([input_tensor, dims_tensor], [mean_tensor], keepDims=True))
squared_diff = self.create_transform_tensor(np.power(input_tensor.tensor - mean_tensor.tensor, 2))
ops.append(tfl.SquaredDifferenceOperator([input_tensor, mean_tensor], [squared_diff]))
var_tensor = self.create_transform_tensor(np.mean(squared_diff.tensor, axis=tuple(axes), keepdims=True))
ops.append(tfl.MeanOperator([squared_diff, dims_tensor], [var_tensor], keepDims=True))
numerator = self.create_transform_tensor(input_tensor.tensor - mean_tensor.tensor)
ops.append(tfl.SubOperator([input_tensor, mean_tensor], [numerator]))
eps_tensor = self.create_attr_tensor(np.array([eps], dtype='float32'))
with_eps = self.create_transform_tensor(var_tensor.tensor + eps_tensor.tensor)
ops.append(tfl.AddOperator([var_tensor, eps_tensor], [with_eps]))
denominator = self.create_transform_tensor(np.sqrt(with_eps.tensor))
ops.append(tfl.SqrtOperator([with_eps], [denominator]))
norm = self.create_transform_tensor(numerator.tensor / denominator.tensor)
ops.append(tfl.DivOperator([numerator, denominator], [norm]))
mul_out = self.create_transform_tensor(norm.tensor * weight_tensor.tensor)
ops.append(tfl.MulOperator([norm, weight_tensor], [mul_out]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.AddOperator([mul_out, bias_tensor], outputs))
for op in ops:
graph_converter.add_operator(op)
class ATenInstanceNormOperator(ATenInstanceNormSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ops = []
inp = self.find_or_create_input(0, graph_converter)
eps = self.input_tensors[args['eps']]
weight, bias = self.input_tensors[1:3]
affine = False
track_running_stats = False
if weight is not None and bias is not None:
affine = True
weight, bias = [self.find_or_create_input(i, graph_converter) for i in range(1, 3)]
running_mean, running_var = self.input_tensors[3:5]
if running_mean is not None and running_var is not None:
track_running_stats = True
running_mean, running_var = [self.find_or_create_input(i, graph_converter) for i in range(3, 5)]
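# With affine parameters and running statistics both present, instance norm
# collapses into a single BATCH_NORM; otherwise per-sample statistics are
# computed over the spatial axes with the same MEAN / SQUARED_DIFFERENCE /
# SQRT expansion used for layer norm.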
if affine and track_running_stats:
inputs = [inp, weight, bias, running_mean, running_var]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.BatchNormOperator(inputs, outputs, eps))
else:
assert (
track_running_stats is False
), 'Instance norm with track_running_stats=True and affine=False is not supported'
dims = len(inp.shape)
axis = tuple(range(2, dims))
axis_tensor = self.create_attr_tensor(np.array(axis, dtype='int32'))
dim_ones = (1,) * (dims - 2)
mean = self.create_transform_tensor(np.mean(inp.tensor, axis=axis, keepdims=True))
ops.append(tfl.MeanOperator([inp, axis_tensor], [mean], keepDims=True))
squared_diff = self.create_transform_tensor(np.power(inp.tensor - mean.tensor, 2))
ops.append(tfl.SquaredDifferenceOperator([inp, mean], [squared_diff]))
var = self.create_transform_tensor(np.mean(squared_diff.tensor, axis=axis, keepdims=True))
ops.append(tfl.MeanOperator([squared_diff, axis_tensor], [var], keepDims=True))
numerator = self.create_transform_tensor(inp.tensor - mean.tensor)
ops.append(tfl.SubOperator([inp, mean], [numerator]))
eps_tensor = self.create_attr_tensor(np.array([eps], dtype='float32'))
with_eps = self.create_transform_tensor(var.tensor + eps_tensor.tensor)
ops.append(tfl.AddOperator([var, eps_tensor], [with_eps]))
denominator = self.create_transform_tensor(np.sqrt(with_eps.tensor))
ops.append(tfl.SqrtOperator([with_eps], [denominator]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
if affine is False:
ops.append(tfl.DivOperator([numerator, denominator], outputs))
else:
weight.tensor = weight.tensor.reshape(-1, *dim_ones)
bias.tensor = bias.tensor.reshape(-1, *dim_ones)
weight_tensor = self.create_attr_tensor(weight.tensor)
bias_tensor = self.create_attr_tensor(bias.tensor)
norm = self.create_transform_tensor(numerator.tensor / denominator.tensor)
ops.append(tfl.DivOperator([numerator, denominator], [norm]))
mul_out = self.create_transform_tensor(norm.tensor * weight_tensor.tensor)
ops.append(tfl.MulOperator([norm, weight_tensor], [mul_out]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.AddOperator([mul_out, bias_tensor], outputs))
for op in ops:
graph_converter.add_operator(op)
class ATenGroupNormOperator(ATenGroupNormSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
inp = self.find_or_create_input(0, graph_converter)
eps = self.input_tensors[args['eps']]
n_channels = inp.shape[1]
n_groups, weight, bias = self.input_tensors[1:4]
affine = False
if weight is not None and bias is not None:
affine = True
weight, bias = [self.find_or_create_input(i, graph_converter) for i in range(2, 4)]
ops = []
inputs = []
dims = len(inp.shape)
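# Three layouts are handled: one group per channel (normalize over spatial
# axes only), a single group (normalize over all non-batch axes), and the
# general case, which SPLITs the channel axis into `n_groups` chunks,
# normalizes each chunk independently, and concatenates the results.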
if n_channels == n_groups and n_groups > 1:
axis = tuple(range(2, dims))
axis_tensor = self.create_attr_tensor(np.array(axis, dtype='int32'))
inputs.append(inp)
elif n_groups == 1:
axis = tuple(range(1, dims))
axis_tensor = self.create_attr_tensor(np.array(axis, dtype='int32'))
inputs.append(inp)
else:
axis = tuple(range(1, dims))
axis_tensor = self.create_attr_tensor(np.array(axis, dtype='int32'))
split_dim_tensor = self.create_attr_tensor(np.array(1, dtype='int32'))
inputs = [self.create_transform_tensor(t) for t in np.split(inp.tensor, n_groups, axis=1)]
ops.append(tfl.SplitOperator([split_dim_tensor, inp], inputs, n_groups))
dim_ones = (1,) * (dims - 2)
norms = []
for input_t in inputs:
mean = self.create_transform_tensor(np.mean(input_t.tensor, axis=axis, keepdims=True))
ops.append(tfl.MeanOperator([input_t, axis_tensor], [mean], keepDims=True))
squared_diff = self.create_transform_tensor(np.power(input_t.tensor - mean.tensor, 2))
ops.append(tfl.SquaredDifferenceOperator([input_t, mean], [squared_diff]))
var = self.create_transform_tensor(np.mean(squared_diff.tensor, axis=axis, keepdims=True))
ops.append(tfl.MeanOperator([squared_diff, axis_tensor], [var], keepDims=True))
numerator = self.create_transform_tensor(input_t.tensor - mean.tensor)
ops.append(tfl.SubOperator([input_t, mean], [numerator]))
eps_tensor = self.create_attr_tensor(np.array([eps], dtype='float32'))
with_eps = self.create_transform_tensor(var.tensor + eps_tensor.tensor)
ops.append(tfl.AddOperator([var, eps_tensor], [with_eps]))
denominator = self.create_transform_tensor(np.sqrt(with_eps.tensor))
ops.append(tfl.SqrtOperator([with_eps], [denominator]))
norm = self.create_transform_tensor(numerator.tensor / denominator.tensor)
ops.append(tfl.DivOperator([numerator, denominator], [norm]))
norms.append(norm)
if len(norms) > 1:
cat_out = self.create_transform_tensor(np.concatenate([x.tensor for x in norms], 1))
ops.append(tfl.ConcatenationOperator(norms, [cat_out], 1))
else:
cat_out = norms[0]
if affine:
weight.tensor = weight.tensor.reshape(-1, *dim_ones)
bias.tensor = bias.tensor.reshape(-1, *dim_ones)
weight_tensor = self.create_attr_tensor(weight.tensor)
bias_tensor = self.create_attr_tensor(bias.tensor)
mul_out = self.create_transform_tensor(cat_out.tensor * weight_tensor.tensor)
ops.append(tfl.MulOperator([cat_out, weight_tensor], [mul_out]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.AddOperator([mul_out, bias_tensor], outputs))
else:
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops[-1].outputs = outputs
for op in ops:
graph_converter.add_operator(op)
class ATenIndexOperator(ATenIndexSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
indices = self.input_tensors[1]
filtered_dims = [i for i, idx in enumerate(indices) if idx is not None]
assert all((indices[i].dtype in (torch.int64, torch.int32) for i in filtered_dims))
input_tensor = self.find_or_create_input(0, graph_converter)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
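# Indexing along several axes is lowered to GATHER_ND: the index tensors
# are broadcast against each other, padded with the coordinates of any
# remaining axes, and stacked into an (..., ndim) coordinate tensor. A
# single indexed axis becomes a chain of per-axis GATHER ops instead.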
if len(filtered_dims) > 1:
if graph_converter.has_nested_names(self.input_names[1]):
input_names = graph_converter.get_list_expanded_names(self.input_names[1])
indices_tensors = self.to_tfl_tensors(
input_names, self.input_tensors[1], graph_converter=graph_converter, non_existent_as_buffer=True
)
else:
if type(self.input_tensors[1]) in (tuple, list):
indices_tensors = [self.create_attr_tensor(x) for x in self.input_tensors[1]]
else:
indices_tensors = [self.find_or_create_input(1, graph_converter)]
dim = input_tensor.tensor.ndim
indices_shape = [x.tensor.size for x in indices_tensors]
max_len = max(indices_shape)
indices_shape_tensor = torch.tensor(indices_shape)
left_indices = (
torch.arange(max_len).view(-1, 1).expand(-1, len(indices_shape)) % indices_shape_tensor
).int()
all_indices_shape = list(outputs[0].shape) + [dim]
if len(indices_tensors) < dim:
pad_shape = list(input_tensor.shape[len(indices_tensors) :])
pad_indices = torch.ones(pad_shape).nonzero().int()
left_len = len(indices_shape)
right_len = len(pad_shape)
left_size = left_indices.size(0)
right_size = pad_indices.size(0)
left_reshaped = (
left_indices.view(-1, 1, left_len).expand(-1, right_size, left_len).reshape(-1, left_len)
)
right_reshaped = (
pad_indices.view(1, -1, right_len).expand(left_size, -1, right_len).reshape(-1, right_len)
)
all_indices = torch.cat([left_reshaped, right_reshaped], 1).view(all_indices_shape).unbind(-1)
else:
all_indices = left_indices.view(all_indices_shape).unbind(-1)
new_indices = []
for i in range(dim):
if i < len(indices_tensors):
idx_tensor = indices_tensors[i]
actual_idx = np.take(idx_tensor.tensor, all_indices[i].numpy())
else:
actual_idx = all_indices[i].numpy()
if i < len(indices_tensors) and idx_tensor.buffer is None:
actual_idx_t = self.create_transform_tensor(actual_idx)
fake_idx_t = self.create_attr_tensor(all_indices[i].numpy())
graph_converter.add_operator(tfl.GatherOperator([idx_tensor, fake_idx_t], [actual_idx_t], axis=0))
if str(actual_idx_t.dtype) != 'int32':
index_casted = self.create_transform_tensor(actual_idx_t.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[actual_idx_t],
[index_casted],
tfl.numpy_tflite_dtype_mappings[str(actual_idx_t.dtype)],
tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
)
)
actual_idx_t = index_casted
new_indices.append(actual_idx_t)
else:
new_indices.append(self.create_attr_tensor(actual_idx.astype(np.int32)))
index_arr = np.stack([x.tensor for x in new_indices], -1)
if all((x.buffer is not None for x in new_indices)):
index_tensor = self.create_attr_tensor(index_arr)
else:
index_tensor = self.create_transform_tensor(index_arr)
graph_converter.add_operator(
tfl.PackOperator(new_indices, [index_tensor], dim, axis=index_tensor.tensor.ndim - 1)
)
graph_converter.add_operator(tfl.GatherNdOperator([input_tensor, index_tensor], outputs))
else:
try:
names = graph_converter.get_list_expanded_names(self.input_names[1])
except KeyError:
names = [self.get_unique_attr_name() for _ in indices]
filtered_names = [names[i] for i in filtered_dims]
filtered_tensors = [indices[i].to(dtype=torch.int32) for i in filtered_dims]
filtered_tensors = [
t + (t < 0).int() * input_tensor.shape[i] if n not in graph_converter.tensor_map else t
for i, n, t in zip(filtered_dims, filtered_names, filtered_tensors)
]
indice_tensors = self.to_tfl_tensors(
filtered_names, filtered_tensors, graph_converter=graph_converter, non_existent_as_buffer=True
)
actual_input = input_tensor
actual_output = None
for i, (dim, idx) in enumerate(zip(filtered_dims, indice_tensors)):
if i == len(filtered_dims) - 1:
actual_output = outputs[0]
else:
actual_output = self.create_transform_tensor(np.take(actual_input.tensor, idx.tensor, axis=dim))
graph_converter.add_operator(tfl.GatherOperator([actual_input, idx], [actual_output], axis=dim))
actual_input = actual_output
class ATenIndexSelectOperator(ATenIndexSelectSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
dim = self.input_tensors[1]
indices = self.input_tensors[2]
assert indices.dtype in (torch.int64, torch.int32)
input_tensor = self.find_or_create_input(0, graph_converter)
if dim < 0:
dim += len(input_tensor.shape)
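# index_select is a plain GATHER along `dim`; indices are cast to int32 and
# negative entries are wrapped to non-negative positions first.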
new_indices = indices.to(dtype=torch.int32)
new_indices = new_indices + (new_indices < 0).int() * input_tensor.shape[dim]
indices_tensor = self.to_tfl_tensors(
self.input_names[2:3], [new_indices], graph_converter=graph_converter, non_existent_as_buffer=True
)[0]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.GatherOperator([input_tensor, indices_tensor], outputs, axis=dim))
class ATenLogSoftmaxOperator(ATenLogSoftmaxSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
dim = self.input_tensors[1]
if dim < 0:
dim += len(self.input_tensors[0].shape)
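# TFLite LOG_SOFTMAX operates on the last axis only, so the op is wrapped
# with a pair of TRANSPOSEs whenever `dim` is not already the last
# dimension.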
ops = []
inputs = [self.find_or_create_input(0, graph_converter)]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
log_softmax_op = tfl.LogSoftmaxOperator(inputs, outputs)
ops.append(log_softmax_op)
ops = self.wrap_ops_with_last_dim_transposes(ops, dim)
for op in ops:
graph_converter.add_operator(op)
class ATenCloneOperator(ATenCloneSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.passthrough(graph_converter)
class ATenRepeatOperator(ATenRepeatSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
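# aten::repeat maps to TILE; when `repeats` has more entries than the input
# has dimensions, the input is first RESHAPEd with leading ones so that the
# ranks match.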
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
actual_input = input_tensor
if input_tensor.buffer is None:
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
repeats = self.input_tensors[1]
input_shape = input_tensor.shape
if len(repeats) > len(input_shape):
new_shape = [1] * (len(repeats) - len(input_shape)) + list(input_shape)
new_shape_arr = np.array(new_shape, dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
reshaped = self.create_transform_tensor(np.reshape(input_tensor.tensor, new_shape_arr))
actual_input = reshaped
ops.append(tfl.ReshapeOperator([input_tensor, new_shape_tensor], [reshaped], new_shape_arr))
repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32'))
ops.append(tfl.TileOperator([actual_input, repeat_tensor], outputs))
for op in ops:
graph_converter.add_operator(op)
class ATenRepeatInterleaveOperator(ATenRepeatInterleaveSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
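# repeat_interleave is emulated with a GATHER over a precomputed index
# pattern (e.g. repeats=2 over 3 rows yields [0, 0, 1, 1, 2, 2]), which is
# why `repeats` must be static.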
input_tensor = self.find_or_create_input(0, graph_converter)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
if 'dim' in args:
dim = self.input_tensors[args['dim']]
else:
dim = None
if 'repeats' in args:
repeats = self.input_tensors[args['repeats']]
else:
repeats = None
if repeats is None:
size_repeats = input_tensor.tensor.size
raw_indices = torch.arange(size_repeats, dtype=torch.int32)
repeats_tensor = input_tensor
elif type(repeats) is int:
if dim is None:
size_repeats = input_tensor.tensor.size
else:
size_repeats = input_tensor.shape[dim]
raw_indices = torch.arange(size_repeats, dtype=torch.int32)
repeats_arr = torch.tensor(repeats, dtype=torch.int32)
repeats_tensor = self.create_attr_tensor(repeats_arr)
else:
if dim is None:
size_repeats = input_tensor.tensor.size
else:
size_repeats = input_tensor.shape[dim]
raw_indices = torch.arange(size_repeats, dtype=torch.int32)
repeats_tensor = self.find_or_create_input(args['repeats'], graph_converter)
assert repeats_tensor.buffer is not None, "dynamic repeats_tensor is not supported"
actual_indices = self.create_attr_tensor(
torch.repeat_interleave(raw_indices, torch.from_numpy(repeats_tensor.tensor).long())
)
actual_input = input_tensor
if dim is None and len(input_tensor.shape) > 1:
new_shape = (input_tensor.tensor.size,)
shape_tensor = self.create_attr_tensor(np.array(new_shape, dtype='int32'))
actual_input = self.create_transform_tensor(np.reshape(input_tensor.tensor, new_shape))
graph_converter.add_operator(tfl.ReshapeOperator([input_tensor, shape_tensor], [actual_input], new_shape))
inputs = [actual_input, actual_indices]
gather_dim = dim
if gather_dim is None:
gather_dim = 0
if gather_dim < 0:
gather_dim += input_tensor.tensor.ndim
graph_converter.add_operator(tfl.GatherOperator(inputs, outputs, gather_dim))
class ATenMmOperator(ATenMmSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenMatmulOperator.parse_common(self, node, attrs, args, graph_converter)
class ATenHardswishOperator(ATenHardswishSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.HardSwishOperator, graph_converter)
class ATenHardsigmoidOperator(ATenHardsigmoidSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
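# hardsigmoid(x) = relu6(x + 3) / 6, expanded as ADD -> RELU6 -> DIV.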
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
three_tensor = self.create_attr_tensor(np.array([3], dtype=input_tensor.dtype))
plus_three = self.create_transform_tensor(input_tensor.tensor + three_tensor.tensor)
ops.append(tfl.AddOperator([input_tensor, three_tensor], [plus_three]))
relu6_tensor = self.create_transform_tensor(np.clip(plus_three.tensor, 0, 6))
ops.append(tfl.Relu6Operator([plus_three], [relu6_tensor]))
six_tensor = self.create_attr_tensor(np.array([6], dtype=input_tensor.dtype))
output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
ops.append(tfl.DivOperator([relu6_tensor, six_tensor], [output_tensor]))
for op in ops:
graph_converter.add_operator(op)
class ATenSiluOperator(ATenSiluSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
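# silu(x) = x * sigmoid(x), expanded as LOGISTIC followed by MUL.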
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
sigmoid_x = self.create_transform_tensor(torch.sigmoid(torch.from_numpy(input_tensor.tensor)).numpy())
ops.append(tfl.LogisticOperator([input_tensor], [sigmoid_x]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.MulOperator([input_tensor, sigmoid_x], outputs))
for op in ops:
graph_converter.add_operator(op)
class ATenVarOperator(ATenVarSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_dims = self.input_tensors[0].dim()
dims = self.input_tensors[args['dim']] if 'dim' in args else list(range(input_dims))
keep_dims = self.input_tensors[args['keepdim']] if 'keepdim' in args else False
unbiased = self.input_tensors[args['unbiased']] if 'unbiased' in args else True
correction = self.input_tensors[args['correction']] if 'correction' in args else 1
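# var(x) = E[(x - E[x])^2]. With Bessel's correction (unbiased and a
# non-zero `correction`), the squared differences are summed and divided by
# N - correction instead of being averaged.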
for i in range(len(dims)):
if dims[i] < 0:
dims[i] += input_dims
input_tensor = self.find_or_create_input(0, graph_converter)
output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
ops = []
sample_dims = [input_tensor.shape[i] for i in range(input_dims) if i in dims]
samples = np.prod(sample_dims, dtype='float32')
if unbiased and correction != 0:
samples -= correction
samples = samples.astype('float32')
samples_tensor = self.create_attr_tensor(samples)
dims_tensor = self.create_attr_tensor(np.array(dims, dtype='int32'))
mean_tensor = self.create_transform_tensor(np.mean(input_tensor.tensor, axis=tuple(dims), keepdims=True))
ops.append(tfl.MeanOperator([input_tensor, dims_tensor], [mean_tensor], keepDims=True))
squared_diff = self.create_transform_tensor(np.power(input_tensor.tensor - mean_tensor.tensor, 2))
ops.append(tfl.SquaredDifferenceOperator([input_tensor, mean_tensor], [squared_diff]))
if unbiased and correction != 0:
squared_diff_sum = self.create_transform_tensor(
np.sum(squared_diff.tensor, axis=tuple(dims), keepdims=keep_dims)
)
ops.append(tfl.SumOperator([squared_diff, dims_tensor], [squared_diff_sum], keepDims=keep_dims))
ops.append(tfl.DivOperator([squared_diff_sum, samples_tensor], [output_tensor]))
else:
ops.append(tfl.MeanOperator([squared_diff, dims_tensor], [output_tensor], keepDims=keep_dims))
for op in ops:
graph_converter.add_operator(op)
class ATenStdOperator(ATenStdSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_dims = self.input_tensors[0].dim()
dims = self.input_tensors[args['dim']] if 'dim' in args else list(range(input_dims))
keep_dims = self.input_tensors[args['keepdim']] if 'keepdim' in args else False
unbiased = self.input_tensors[args['unbiased']] if 'unbiased' in args else True
correction = self.input_tensors[args['correction']] if 'correction' in args else 1
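# Same expansion as aten::var, followed by a final SQRT since
# std(x) = sqrt(var(x)).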
for i in range(len(dims)):
if dims[i] < 0:
dims[i] += input_dims
input_tensor = self.find_or_create_input(0, graph_converter)
output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
ops = []
sample_dims = [input_tensor.shape[i] for i in range(input_dims) if i in dims]
samples = np.prod(sample_dims, dtype='float32')
if unbiased and correction != 0:
samples -= correction
samples = samples.astype('float32')
samples_tensor = self.create_attr_tensor(samples)
dims_tensor = self.create_attr_tensor(np.array(dims, dtype='int32'))
mean_tensor = self.create_transform_tensor(np.mean(input_tensor.tensor, axis=tuple(dims), keepdims=True))
ops.append(tfl.MeanOperator([input_tensor, dims_tensor], [mean_tensor], keepDims=True))
squared_diff = self.create_transform_tensor(np.power(input_tensor.tensor - mean_tensor.tensor, 2))
ops.append(tfl.SquaredDifferenceOperator([input_tensor, mean_tensor], [squared_diff]))
if unbiased and correction != 0:
squared_diff_sum = self.create_transform_tensor(
np.sum(squared_diff.tensor, axis=tuple(dims), keepdims=keep_dims)
)
ops.append(tfl.SumOperator([squared_diff, dims_tensor], [squared_diff_sum], keepDims=keep_dims))
var_tensor = self.create_transform_tensor(squared_diff_sum.tensor / samples_tensor.tensor)
ops.append(tfl.DivOperator([squared_diff_sum, samples_tensor], [var_tensor]))
else:
var_tensor = self.create_transform_tensor(
np.mean(squared_diff.tensor, axis=tuple(dims), keepdims=keep_dims)
)
ops.append(tfl.MeanOperator([squared_diff, dims_tensor], [var_tensor], keepDims=keep_dims))
ops.append(tfl.SqrtOperator([var_tensor], [output_tensor]))
for op in ops:
graph_converter.add_operator(op)
class ATenReflectionPad2dOperator(ATenReflectionPad2dSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
pads = self.input_tensors[1]
tfl_pads = np.array([[0, 0], [0, 0], [pads[2], pads[3]], [pads[0], pads[1]]], dtype='int32')
pad_tensor = self.create_attr_tensor(tfl_pads)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(
tfl.MirrorPadOperator([input_tensor, pad_tensor], outputs, tfl_schema.MirrorPadMode.REFLECT)
)
class ATenReflectionPad1dOperator(ATenReflectionPad1dSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
pads = self.input_tensors[1]
tfl_pads = np.array([[0, 0], [0, 0], [pads[0], pads[1]]], dtype='int32')
pad_tensor = self.create_attr_tensor(tfl_pads)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(
tfl.MirrorPadOperator([input_tensor, pad_tensor], outputs, tfl_schema.MirrorPadMode.REFLECT)
)
class ATenSplitOperator(ATenSplitSchema):
def parse_common(self, node, attrs, args, graph_converter):
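# Both aten::split and aten::split_with_sizes are emitted as SPLIT_V, with
# the per-output sizes read from the traced output tensors.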
input_tensor = self.find_or_create_input(0, graph_converter)
dim = self.input_tensors[2]
if dim < 0:
dim += len(self.input_tensors[0].shape)
dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))
size_splits = np.array([t.size(dim) for t in self.output_tensors[0]], dtype='int32')
chunks = len(size_splits)
split_tensor = self.create_attr_tensor(size_splits)
output_names = [f'{self.output_names[0]}:{i}' for i in range(chunks)]
graph_converter.add_iterable_pair(self.output_names, output_names, 'input')
outputs = self.to_tfl_tensors(output_names, self.output_tensors[0])
graph_converter.add_operator(tfl.SplitVOperator([input_tensor, split_tensor, dim_tensor], outputs, chunks))
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(node, attrs, args, graph_converter)
class ATenSplitWithSizesOperator(ATenSplitWithSizesSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenSplitOperator.parse_common(self, node, attrs, args, graph_converter)
class ATenChunkOperator(ATenChunkSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
chunks, dim = self.input_tensors[1:]
if dim < 0:
dim += len(self.input_tensors[0].shape)
dim_size = self.input_tensors[0].size(dim)
if chunks > dim_size:
chunks = dim_size
dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))
output_names = [f'{self.output_names[0]}:{i}' for i in range(len(self.output_tensors[0]))]
graph_converter.add_iterable_pair(self.output_names, output_names, 'input')
outputs = self.to_tfl_tensors(output_names, self.output_tensors[0])
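# Evenly divisible chunks map to SPLIT; otherwise SPLIT_V is emitted with
# explicit per-chunk sizes taken from the traced outputs.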
if dim_size % chunks != 0:
size_splits = np.array([t.size(dim) for t in self.output_tensors[0]], dtype='int32')
chunks = len(size_splits)
split_tensor = self.create_attr_tensor(size_splits)
graph_converter.add_operator(tfl.SplitVOperator([input_tensor, split_tensor, dim_tensor], outputs, chunks))
else:
graph_converter.add_operator(tfl.SplitOperator([dim_tensor, input_tensor], outputs, chunks))
class ATenPixelShuffleOperator(ATenPixelShuffleSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
upscale_factor = self.input_tensors[1]
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
# The implementations of tf.depth_to_space and torch.pixel_shuffle do not match:
# the former splits the input channel dimension as (block_size, block_size, new_channel_size),
# while the latter splits it as (new_channel_size, block_size, block_size).
# Since TFLite does not support transposing tensors with more than 5 dimensions,
# we use `tf.gather` to reorder the elements in the channel dimension instead.
ops.append(tfl.DepthToSpaceOperator([input_tensor], outputs, upscale_factor))
ops = self.wrap_ops_with_nhwc_nchw_transposes(ops)
c = input_tensor.shape[1]
bs = upscale_factor
perm = np.arange(c).reshape(c // (bs**2), bs, bs).transpose(1, 2, 0).flatten()
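# e.g. with 8 input channels and upscale_factor 2,
# perm = [0, 4, 1, 5, 2, 6, 3, 7].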
if not np.array_equal(np.sort(perm), perm):
reordered = self.create_transform_tensor(ops[0].outputs[0].tensor[:, :, :, perm])
indices = self.create_attr_tensor(perm.astype('int32'))
gather_op = tfl.GatherOperator([ops[0].outputs[0], indices], [reordered], axis=3)
ops[1].inputs[0] = reordered
ops.insert(1, gather_op)
for op in ops:
graph_converter.add_operator(op)
class ATenPixelUnshuffleOperator(ATenPixelUnshuffleSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
downscale_factor = self.input_tensors[1]
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
# The implementations of tf.space_to_depth and torch.pixel_unshuffle do not match:
# the former arranges the output channel dimension as (block_size, block_size, new_channel_size),
# while the latter arranges it as (new_channel_size, block_size, block_size).
# Since TFLite does not support transposing tensors with more than 5 dimensions,
# we use `tf.gather` to reorder the elements in the channel dimension instead.
ops.append(tfl.SpaceToDepthOperator([input_tensor], outputs, downscale_factor))
ops = self.wrap_ops_with_nhwc_nchw_transposes(ops)
c = input_tensor.shape[1]
bs = downscale_factor
perm = np.arange(c * (bs**2)).reshape(bs, bs, c).transpose(2, 0, 1).flatten()
if not np.array_equal(np.sort(perm), perm):
reordered = self.create_transform_tensor(ops[1].outputs[0].tensor[:, :, :, perm])
indices = self.create_attr_tensor(perm.astype('int32'))
gather_op = tfl.GatherOperator([reordered, indices], [ops[1].outputs[0]], axis=3)
ops.insert(2, gather_op)
ops[1].outputs[0] = reordered
for op in ops:
graph_converter.add_operator(op)
class ATenArgmaxOperator(ATenArgmaxSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
assert 'dim' in args and 'keepdim' in args, "aten::argmax(tensor) is not supported"
# Downcast to int32
self.output_tensors[0] = self.output_tensors[0].to(dtype=torch.int32)
self.handle_reduce(tfl.ArgMaxOperator, args, graph_converter, False, tfl_schema.TensorType.INT32)
class ATenArgminOperator(ATenArgminSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
assert 'dim' in args and 'keepdim' in args, "aten::argmin(tensor) is not supported"
# Downcast to int32
self.output_tensors[0] = self.output_tensors[0].to(dtype=torch.int32)
self.handle_reduce(tfl.ArgMinOperator, args, graph_converter, False, tfl_schema.TensorType.INT32)
class ATenExpandOperator(ATenExpandSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(node, attrs, args, graph_converter)
def parse_common(self, node, attrs, args, graph_converter):
input_tensor = self.find_or_create_input(0, graph_converter)
actual_input = input_tensor
if input_tensor.buffer is None:
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
input_shape = input_tensor.shape
output_shape = outputs[0].shape
# No-OP if input tensor is already of desired sizes
if output_shape == input_shape:
self.passthrough(graph_converter)
return
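# aten::expand is realized as an optional RESHAPE (aligning ranks with
# leading ones) followed by TILE; axes that already match use a repeat of
# 1, which reproduces broadcast semantics.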
ops = []
new_shape = input_shape
actual_input = input_tensor
if len(output_shape) > len(input_shape):
new_shape = [1] * (len(output_shape) - len(input_shape)) + list(input_shape)
new_shape_arr = np.array(new_shape, dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
reshaped = self.create_transform_tensor(np.reshape(input_tensor.tensor, new_shape_arr))
actual_input = reshaped
reshape_op = tfl.ReshapeOperator([input_tensor, new_shape_tensor], [reshaped], new_shape_arr)
reshape_op.extra_hints['direction'] = 'up'
ops.append(reshape_op)
repeats = []
for x, y in zip(new_shape, output_shape):
if x != y:
repeats.append(y)
else:
repeats.append(1)
repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32'))
ops.append(tfl.TileOperator([actual_input, repeat_tensor], outputs))
for op in ops:
graph_converter.add_operator(op)
class ATenExpandAsOperator(ATenExpandAsSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenExpandOperator.parse_common(self, node, attrs, args, graph_converter)
class ATenGatherOperator(ATenGatherSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
# torch.gather requires index tensor of type `torch.int64`
orig_type = self.input_tensors[2].dtype
self.input_tensors[2] = self.input_tensors[2].to(dtype=torch.int64)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
dim, index = self.input_tensors[1:3]
if dim < 0:
dim += input_tensor.tensor.ndim
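# torch.gather has no direct TFLite counterpart, so it is rewritten as
# GATHER_ND: gathering from an arange-filled tensor of the input's shape
# yields the flat source positions, and nonzero() converts those into
# per-axis coordinates.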
fake_input = torch.arange(input_tensor.tensor.size).reshape(input_tensor.shape)
fake_output = torch.gather(fake_input, dim, index)
indices = torch.nonzero(fake_input >= 0)[fake_output].to(dtype=torch.int32)
self.input_tensors[2] = self.input_tensors[2].to(dtype=orig_type)
index_tensor = self.find_or_create_input(2, graph_converter)
if index_tensor.buffer is None:
indices_per_dim = torch.split(indices, 1, dim=-1)
indices_tensors = [self.create_attr_tensor(t) for t in indices_per_dim]
index_shape = list(index_tensor.shape) + [1]
axis = len(index_shape) - 1
shape_tensor = self.create_attr_tensor(np.array(index_shape, dtype='int32'))
index_reshaped = self.create_transform_tensor(np.reshape(index_tensor.tensor, index_shape))
reshape_op = tfl.ReshapeOperator([index_tensor, shape_tensor], [index_reshaped], index_shape)
reshape_op.extra_hints['direction'] = 'up'
graph_converter.add_operator(reshape_op)
if str(index_reshaped.dtype) != 'int32':
index_casted = self.create_transform_tensor(index_reshaped.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[index_reshaped],
[index_casted],
tfl.numpy_tflite_dtype_mappings[str(index_reshaped.dtype)],
tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
)
)
index_reshaped = index_casted
indices_tensors[dim] = index_reshaped
indices_tensor = self.create_transform_tensor(np.concatenate([x.tensor for x in indices_tensors], axis=-1))
graph_converter.add_operator(tfl.ConcatenationOperator(indices_tensors, [indices_tensor], axis=axis))
else:
indices_tensor = self.create_attr_tensor(indices)
graph_converter.add_operator(tfl.GatherNdOperator([input_tensor, indices_tensor], [output_tensor]))
class ATenScatterOperator(ATenScatterSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
assert not any(
(torch.is_nonzero(v) for v in self.input_tensors[0].flatten())
), "aten::scatter with non-zero input is not supported"
# torch.scatter requires index tensor of type `torch.int64`
orig_type = self.input_tensors[2].dtype
self.input_tensors[2] = self.input_tensors[2].to(dtype=torch.int64)
self.run(node)
assert 'reduce' not in args, "aten::scatter with reduction is not supported"
input_tensor = self.find_or_create_input(0, graph_converter)
assert input_tensor.buffer is not None, "aten::scatter with variable input is not supported"
output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
dim, index = self.input_tensors[1:3]
if dim < 0:
dim += input_tensor.tensor.ndim
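# Mirrors the index construction in ATenGatherOperator: the gathered flat
# positions give the SCATTER_ND coordinates, and the scattered values come
# from a filled constant (scalar src) or a sliced value tensor.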
fake_input = torch.arange(input_tensor.tensor.size).reshape(input_tensor.shape)
fake_output = torch.gather(fake_input, dim, index)
indices = torch.nonzero(fake_input >= 0)[fake_output].to(dtype=torch.int32)
self.input_tensors[2] = self.input_tensors[2].to(dtype=orig_type)
index_tensor = self.find_or_create_input(2, graph_converter)
if index_tensor.buffer is None:
indices_per_dim = torch.split(indices, 1, dim=-1)
indices_tensors = [self.create_attr_tensor(t) for t in indices_per_dim]
index_shape = list(index_tensor.shape) + [1]
axis = len(index_shape) - 1
shape_tensor = self.create_attr_tensor(np.array(index_shape, dtype='int32'))
index_reshaped = self.create_transform_tensor(np.reshape(index_tensor.tensor, index_shape))
reshape_op = tfl.ReshapeOperator([index_tensor, shape_tensor], [index_reshaped], index_shape)
reshape_op.extra_hints['direction'] = 'up'
graph_converter.add_operator(reshape_op)
if str(index_reshaped.dtype) != 'int32':
index_casted = self.create_transform_tensor(index_reshaped.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[index_reshaped],
[index_casted],
tfl.numpy_tflite_dtype_mappings[str(index_reshaped.dtype)],
tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
)
)
index_reshaped = index_casted
indices_tensors[dim] = index_reshaped
indices_tensor = self.create_transform_tensor(np.concatenate([x.tensor for x in indices_tensors], axis=-1))
graph_converter.add_operator(tfl.ConcatenationOperator(indices_tensors, [indices_tensor], axis=axis))
else:
indices_tensor = self.create_attr_tensor(indices)
if isinstance(self.input_tensors[3], (int, float)):
fill_arr = np.zeros(indices_tensor.shape[:-1], dtype=input_tensor.dtype)
fill_arr.fill(self.input_tensors[3])
fill_tensor = self.create_attr_tensor(fill_arr)
else:
val_tensor = self.find_or_create_input(3, graph_converter)
val_slices = []
for i in indices_tensor.shape:
val_slices.append(slice(i))
val_slices = tuple(val_slices[: len(val_tensor.shape)])
val_sliced = val_tensor.tensor.__getitem__(val_slices)
if val_tensor.buffer is None:
if val_tensor.shape != indices_tensor.shape:
sizes = np.array(indices_tensor.tensor.shape[:-1], dtype='int32')
starts = np.zeros(indices_tensor.tensor.ndim - 1, dtype='int32')
size_tensor = self.create_attr_tensor(sizes)
start_tensor = self.create_attr_tensor(starts)
fill_tensor = self.create_transform_tensor(val_sliced)
graph_converter.add_operator(
tfl.SliceOperator([val_tensor, start_tensor, size_tensor], [fill_tensor])
)
else:
fill_tensor = self.create_attr_tensor(val_sliced)
shape_tensor = self.create_attr_tensor(np.array(input_tensor.shape, dtype='int32'))
graph_converter.add_operator(
tfl.ScatterNdOperator([indices_tensor, fill_tensor, shape_tensor], [output_tensor])
)
class ATenIndexPutOperator(ATenIndexPutSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
# torch.Tensor.index_put_ requires index tensor of type `torch.int64`
accumulate = self.input_tensors[3]
assert not accumulate, "aten::index_put_ with accumulate=True is not supported"
orig_type = self.input_tensors[1][0].dtype
self.input_tensors[1] = tuple([x.to(dtype=torch.int64) for x in self.input_tensors[1]])
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
self.input_tensors[1] = tuple([x.to(dtype=orig_type) for x in self.input_tensors[1]])
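# index_put_ without accumulation reuses the coordinate expansion from
# aten::index to build an (..., ndim) index tensor, then broadcasts or
# tiles the value tensor so there is one value per coordinate before
# scattering.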
if graph_converter.has_nested_names(self.input_names[1]):
input_names = graph_converter.get_list_expanded_names(self.input_names[1])
indices_tensors = self.to_tfl_tensors(
input_names, self.input_tensors[1], graph_converter=graph_converter, non_existent_as_buffer=True
)
else:
if type(self.input_tensors[1]) in (tuple, list):
indices_tensors = [self.create_attr_tensor(x) for x in self.input_tensors[1]]
else:
indices_tensors = [self.find_or_create_input(1, graph_converter)]
dim = input_tensor.tensor.ndim
indices_shape = [x.tensor.size for x in indices_tensors]
max_len = max(indices_shape)
indices_shape_tensor = torch.tensor(indices_shape)
left_indices = (torch.arange(max_len).view(-1, 1).expand(-1, len(indices_shape)) % indices_shape_tensor).int()
if len(indices_tensors) < dim:
pad_shape = list(input_tensor.shape[len(indices_tensors) :])
pad_indices = torch.ones(pad_shape).nonzero().int()
left_len = len(indices_shape)
right_len = len(pad_shape)
left_size = left_indices.size(0)
right_size = pad_indices.size(0)
left_reshaped = left_indices.view(-1, 1, left_len).expand(-1, right_size, left_len).reshape(-1, left_len)
right_reshaped = pad_indices.view(1, -1, right_len).expand(left_size, -1, right_len).reshape(-1, right_len)
all_indices = torch.cat([left_reshaped, right_reshaped], 1).unbind(1)
else:
all_indices = left_indices.unbind(1)
new_indices = []
for i in range(dim):
if i < len(indices_tensors):
idx_tensor = indices_tensors[i]
actual_idx = np.take(idx_tensor.tensor, all_indices[i].numpy())
else:
actual_idx = all_indices[i].numpy()
if i < len(indices_tensors) and idx_tensor.buffer is None:
actual_idx_t = self.create_transform_tensor(actual_idx)
fake_idx_t = self.create_attr_tensor(all_indices[i].numpy())
graph_converter.add_operator(tfl.GatherOperator([idx_tensor, fake_idx_t], [actual_idx_t], axis=0))
if str(actual_idx_t.dtype) != 'int32':
index_casted = self.create_transform_tensor(actual_idx_t.tensor.astype('int32'))
graph_converter.add_operator(
tfl.CastOperator(
[actual_idx_t],
[index_casted],
tfl.numpy_tflite_dtype_mappings[str(actual_idx_t.dtype)],
tfl.numpy_tflite_dtype_mappings[str(index_casted.dtype)],
)
)
actual_idx_t = index_casted
new_indices.append(actual_idx_t)
else:
new_indices.append(self.create_attr_tensor(actual_idx.astype(np.int32)))
index_arr = np.stack([x.tensor for x in new_indices], 1)
if all((x.buffer is not None for x in new_indices)):
index_tensor = self.create_attr_tensor(index_arr)
else:
index_tensor = self.create_transform_tensor(index_arr)
graph_converter.add_operator(tfl.PackOperator(new_indices, [index_tensor], dim, axis=1))
val_tensor = self.find_or_create_input(2, graph_converter)
actual_val = val_tensor
orig_val_shape = val_tensor.shape
target_val_shape = index_tensor.shape[:-1]
if orig_val_shape != target_val_shape:
if val_tensor.buffer is None:
new_shape = orig_val_shape
val_reshaped = val_tensor
if len(target_val_shape) > len(orig_val_shape):
new_shape = [1] * (len(target_val_shape) - len(orig_val_shape)) + list(orig_val_shape)
new_shape_arr = np.array(new_shape, dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
reshaped = self.create_transform_tensor(np.reshape(val_tensor.tensor, new_shape_arr))
val_reshaped = reshaped
reshape_op = tfl.ReshapeOperator([val_tensor, new_shape_tensor], [reshaped], new_shape_arr)
reshape_op.extra_hints['direction'] = 'up'
graph_converter.add_operator(reshape_op)
repeats = []
for x, y in zip(new_shape, target_val_shape):
if x != y:
repeats.append(y // x)
else:
repeats.append(1)
actual_val = self.create_transform_tensor(np.tile(val_reshaped.tensor, repeats))
repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32'))
graph_converter.add_operator(tfl.TileOperator([val_reshaped, repeat_tensor], [actual_val]))
else:
actual_val = self.create_attr_tensor(np.broadcast_to(val_tensor.tensor, target_val_shape))
shape_tensor = self.create_attr_tensor(np.array(input_tensor.shape, dtype='int32'))
if input_tensor.buffer is None or index_tensor.buffer is None:
old_val_tensor = self.create_transform_tensor(actual_val.tensor)
graph_converter.add_operator(tfl.GatherNdOperator([input_tensor, index_tensor], [old_val_tensor]))
else:
transformed_index = tuple(index_tensor.tensor[..., i] for i in range(index_tensor.shape[-1]))
old_val_tensor = self.create_attr_tensor(input_tensor.tensor[transformed_index])
if actual_val.buffer is None:
update_tensor = self.create_transform_tensor(actual_val.tensor - old_val_tensor.tensor)
graph_converter.add_operator(tfl.SubOperator([actual_val, old_val_tensor], [update_tensor]))
else:
update_tensor = self.create_attr_tensor(actual_val.tensor - old_val_tensor.tensor)
updated_tensor = self.create_transform_tensor(input_tensor.tensor)
graph_converter.add_operator(
tfl.ScatterNdOperator([index_tensor, update_tensor, shape_tensor], [updated_tensor])
)
graph_converter.add_operator(tfl.AddOperator([input_tensor, updated_tensor], [output_tensor]))
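# Lowering notes for aten::gelu below. With legacy_gelu=True, the op is emitted
# as the sigmoid approximation gelu(x) ~= x * sigmoid(1.702 * x), i.e.
# MUL -> LOGISTIC -> MUL, wrapped in DEQUANTIZE/QUANTIZE for quantized tensors.
# Otherwise the native TFLite GELU op is used, with its tanh approximation
# enabled whenever the PyTorch `approximate` argument is not "none".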
class ATenGeluOperator(ATenGeluSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
approximate = "none"
if 'approximate' in args:
approximate = self.input_tensors[args['approximate']] or "none"
if self.legacy_gelu:
if approximate == "none":
warnings.warn('aten::gelu[approximate="none"] cannot be lowered exactly with legacy_gelu=True, falling back to the sigmoid approximation')
constant_tensor = self.create_attr_tensor(np.array([1.702], dtype='float32'))
sigmoid_in = self.create_transform_tensor(input_tensor.tensor * constant_tensor.tensor)
actual_input = input_tensor
if input_tensor.quantization is not None:
actual_input = self.create_transform_tensor(actual_input.tensor.astype('float32'))
ops.append(tfl.DequantizeOperator([input_tensor], [actual_input]))
ops.append(tfl.MulOperator([actual_input, constant_tensor], [sigmoid_in]))
sigmoid_out = self.create_transform_tensor(torch.sigmoid(torch.from_numpy(sigmoid_in.tensor)).numpy())
ops.append(tfl.LogisticOperator([sigmoid_in], [sigmoid_out]))
if input_tensor.quantization is not None:
actual_output = self.create_transform_tensor(output_tensor.tensor.astype('float32'))
ops.append(tfl.MulOperator([sigmoid_out, actual_input], [actual_output]))
ops.append(tfl.QuantizeOperator([actual_output], [output_tensor]))
else:
ops.append(tfl.MulOperator([sigmoid_out, actual_input], [output_tensor]))
else:
op = tfl.GeluOperator([input_tensor], [output_tensor], approximate != "none")
ops.append(op)
for op in ops:
graph_converter.add_operator(op)
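# Lowering notes for aten::copy_ below: a dynamic source is CAST to the target
# dtype if needed, then either reshaped as-is (identical shapes) or broadcast to
# the target shape via a leading-ones RESHAPE followed by TILE. For constant
# sources no ops are emitted; the precomputed result is used directly.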
class ATenCopyOperator(ATenCopySchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
other = self.input_tensors[1]
output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
ops = []
if isinstance(other, torch.Tensor):
other_tensor = self.find_or_create_input(1, graph_converter)
if other_tensor.buffer is None:
other_shape = other_tensor.shape
output_shape = output_tensor.shape
actual_input = other_tensor
if other_tensor.dtype != output_tensor.dtype:
casted = self.create_transform_tensor(other_tensor.tensor.astype(output_tensor.dtype))
actual_input = casted
ops.append(
tfl.CastOperator(
[other_tensor],
[casted],
inDataType=tfl.numpy_tflite_dtype_mappings[str(other_tensor.dtype)],
outDataType=tfl.numpy_tflite_dtype_mappings[str(output_tensor.dtype)],
)
)
if other_shape == output_shape:
shape_tensor = self.create_attr_tensor(np.array(other_shape, dtype='int32'))
ops.append(tfl.ReshapeOperator([actual_input, shape_tensor], [output_tensor], shape_tensor.tensor))
else:
new_shape = other_shape
if len(output_shape) > len(other_shape):
new_shape = [1] * (len(output_shape) - len(other_shape)) + list(other_shape)
new_shape_arr = np.array(new_shape, dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
reshaped = self.create_transform_tensor(np.reshape(actual_input.tensor, new_shape_arr))
reshape_op = tfl.ReshapeOperator([actual_input, new_shape_tensor], [reshaped], new_shape_arr)
reshape_op.extra_hints['direction'] = 'up'
ops.append(reshape_op)
actual_input = reshaped
repeats = []
for x, y in zip(new_shape, output_shape):
if x != y:
repeats.append(y)
else:
repeats.append(1)
repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32'))
ops.append(tfl.TileOperator([actual_input, repeat_tensor], [output_tensor]))
for op in ops:
graph_converter.add_operator(op)
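# ATenQuantizedLstmOperator below unpacks the packed quantized LSTM parameters
# (nested per-layer containers of weights and optional biases) into the flat
# weight-then-bias list that parse_common expects.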
class ATenQuantizedLstmOperator(ATenQuantizedLstmSchema, ATenLstmOperator):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor, hidden_state_tensors, params_tensors = self.input_tensors[:3]
has_biases, num_layers, dropout, is_train, bidirectional, batch_first = self.input_tensors[3:9]
params_l = []
for t in params_tensors:
weight_l = []
bias_l = []
params = self.unpack_params(t)[1][0]
inner_params = params[-1]
for p in inner_params:
unpacked = self.unpack_params(p)[1]
w = unpacked[0]
weight_l.append(w[0])
if len(w) > 1:
bias_l.append(w[1])
params_l.extend(weight_l)
params_l.extend(bias_l)
self.parse_common(
input_tensor,
hidden_state_tensors,
params_l,
has_biases,
num_layers,
dropout,
is_train,
bidirectional,
batch_first,
graph_converter,
)
class ATenBmmOperator(ATenBmmSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenMatmulOperator.parse_common(self, node, attrs, args, graph_converter)
class ATenEqOperator(ATenEqSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
self.elementwise_binary(tfl.EqualOperator, graph_converter, True)
class ATenNegOperator(ATenNegSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.NegOperator, graph_converter)
class ATenBitwiseNotOperator(ATenBitwiseNotSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
assert self.input_tensors[0].dtype == torch.bool, "Only bools are supported in aten::bitwise_not"
self.elementwise_unary(tfl.LogicalNotOperator, graph_converter)
class ATenBitwiseAndOperator(ATenBitwiseAndSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(graph_converter)
def parse_common(self, graph_converter):
other = self.input_tensors[1]
if not isinstance(other, torch.Tensor):
self.input_tensors[1] = torch.tensor([other]).repeat(self.input_tensors[0].shape)
assert all((t.dtype == torch.bool for t in self.input_tensors)), "Only bools are supported in aten::bitwise_not"
self.elementwise_binary(tfl.LogicalAndOperator, graph_converter, False)
class ATenBitwiseOrOperator(ATenBitwiseOrSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(graph_converter)
def parse_common(self, graph_converter):
other = self.input_tensors[1]
if not isinstance(other, torch.Tensor):
self.input_tensors[1] = torch.tensor([other]).repeat(self.input_tensors[0].shape)
assert all((t.dtype == torch.bool for t in self.input_tensors)), "Only bools are supported in aten::bitwise_not"
self.elementwise_binary(tfl.LogicalOrOperator, graph_converter, False)
class ATenAndOperator(ATenAndSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenBitwiseAndOperator.parse_common(self, graph_converter)
class ATenOrOperator(ATenOrSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenBitwiseOrOperator.parse_common(self, graph_converter)
class ATenSumOperator(ATenSumSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.handle_reduce(tfl.SumOperator, args, graph_converter, False)
class ATenProdOperator(ATenProdSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.handle_reduce(tfl.ReduceProdOperator, args, graph_converter, False)
class ATenMinOperator(ATenMinSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if 'other' in args:
self.elementwise_binary(tfl.MinimumOperator, graph_converter, True)
else:
self.handle_reduce(tfl.ReduceMinOperator, args, graph_converter, False)
class ATenMaxOperator(ATenMaxSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if 'other' in args:
self.elementwise_binary(tfl.MaximumOperator, graph_converter, True)
else:
self.handle_reduce(tfl.ReduceMaxOperator, args, graph_converter, False)
class ATenAminOperator(ATenAminSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.handle_reduce(tfl.ReduceMinOperator, args, graph_converter, False)
class ATenAmaxOperator(ATenAmaxSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.handle_reduce(tfl.ReduceMaxOperator, args, graph_converter, False)
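# Lowering notes for aten::glu below, which has no direct TFLite counterpart:
#   a, b = SPLIT(x, 2, dim)
#   out  = MUL(a, LOGISTIC(b))   # glu(x) = a * sigmoid(b)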
class ATenGluOperator(ATenGluSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
dim = self.input_tensors[1]
if dim < 0:
dim += input_tensor.tensor.ndim
ops = []
mid_arrs = np.split(input_tensor.tensor, 2, axis=dim)
dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))
mid_tensors = [self.create_transform_tensor(t) for t in mid_arrs]
ops.append(tfl.SplitOperator([dim_tensor, input_tensor], mid_tensors, 2))
with_act = self.create_transform_tensor(torch.sigmoid(torch.from_numpy(mid_tensors[1].tensor)))
ops.append(tfl.LogisticOperator([mid_tensors[1]], [with_act]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.MulOperator([mid_tensors[0], with_act], outputs))
for op in ops:
graph_converter.add_operator(op)
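# Lowering notes for aten::masked_fill below (also reused by aten::where). The
# boolean mask is cast to the compute dtype and the result is assembled
# arithmetically; a sketch assuming mask values in {0, 1}:
#   out = input * (1 - mask) + value * mask
# Infinite fill values are clamped to the dtype's finite range with a warning,
# presumably because inf * 0 would yield NaN in this MUL-based formulation.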
class ATenMaskedFillOperator(ATenMaskedFillSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(graph_converter)
def parse_common(self, graph_converter, input_idx=0, mask_idx=1, other_idx=2, out_idx=0):
for i in (input_idx, other_idx):
t = self.input_tensors[i]
if type(t) is torch.Tensor:
if t.dtype == torch.float64:
self.input_tensors[i] = t.to(dtype=torch.float32)
elif t.dtype == torch.int64:
self.input_tensors[i] = t.to(dtype=torch.int32)
if self.output_tensors[out_idx].dtype == torch.float64:
self.output_tensors[out_idx] = self.output_tensors[out_idx].to(dtype=torch.float32)
elif self.output_tensors[out_idx].dtype == torch.int64:
self.output_tensors[out_idx] = self.output_tensors[out_idx].to(dtype=torch.int32)
mask = self.input_tensors[mask_idx]
other = self.input_tensors[other_idx]
out = self.output_tensors[out_idx]
input_tensor, mask_tensor = [self.find_or_create_input(i, graph_converter) for i in (input_idx, mask_idx)]
ops = []
if type(other) is torch.Tensor:
other_t = self.find_or_create_input(other_idx, graph_converter)
if out.dtype != other.dtype:
casted = other.clone().to(dtype=out.dtype)
if other_t.buffer is None:
new_other = self.create_transform_tensor(casted)
ops.append(
tfl.CastOperator(
[other_t],
[new_other],
tfl.torch_tflite_dtype_mappings[other.dtype],
tfl.torch_tflite_dtype_mappings[out.dtype],
)
)
other_t = new_other
# TODO: +/- inf check for variable tensors
else:
if hasattr(torch.functional, 'atleast_1d'):
casted = torch.functional.atleast_1d(casted)
elif len(casted.shape) == 0:
casted = casted.reshape(1)
if torch.isinf(casted).any():
log.warning(
'aten::masked_fill(input, mask, value) where value=[+/-]inf is not supported, '
'trying to convert it to the nearest value'
)
type_info = torch.finfo(casted.dtype)
clamped = torch.clamp(casted, type_info.min, type_info.max)
other_t = self.create_attr_tensor(clamped, name=self.input_names[other_idx])
else:
other_t = self.create_attr_tensor(casted, name=self.input_names[other_idx])
elif type(other) in (int, float):
other_a = np.array([other], dtype=self.input_tensors[input_idx].detach().numpy().dtype)
if np.isinf(other_a).any():
log.warning(
'aten::masked_fill(input, mask, value) where value=[+/-]inf is not supported, '
'trying to convert it to the nearest value'
)
type_info = np.finfo(other_a.dtype)
other_a = np.clip(other_a, type_info.min, type_info.max)
other_t = self.create_attr_tensor(other_a)
else:
assert False, "value should have type float, tensor in aten::masked_fill(input, mask, value)"
if mask_tensor.buffer is None:
input_mask = self.create_transform_tensor(mask_tensor.tensor.astype(input_tensor.dtype))
ops.append(
tfl.CastOperator(
[mask_tensor],
[input_mask],
tfl.torch_tflite_dtype_mappings[mask.dtype],
tfl.torch_tflite_dtype_mappings[out.dtype],
)
)
else:
input_mask = self.create_attr_tensor(mask_tensor.tensor.astype(input_tensor.dtype))
if mask_tensor.buffer is None or other_t.buffer is None:
masked = self.create_transform_tensor(other_t.tensor * mask_tensor.tensor)
ops.append(tfl.MulOperator([other_t, input_mask], [masked]))
else:
masked = self.create_attr_tensor(other_t.tensor * mask_tensor.tensor)
one_tensor = self.create_attr_tensor(np.array([1], dtype=input_tensor.dtype))
if mask_tensor.buffer is None:
rev_mask = self.create_transform_tensor(one_tensor.tensor - mask_tensor.tensor)
ops.append(tfl.SubOperator([one_tensor, input_mask], [rev_mask]))
else:
rev_mask = self.create_attr_tensor(one_tensor.tensor - mask_tensor.tensor)
non_masked = self.create_transform_tensor(input_tensor.tensor * rev_mask.tensor)
ops.append(tfl.MulOperator([input_tensor, rev_mask], [non_masked]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.AddOperator([non_masked, masked], outputs))
for op in ops:
graph_converter.add_operator(op)
class ATenMaximumOperator(ATenMaximumSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
self.elementwise_binary(tfl.MaximumOperator, graph_converter, True)
class ATenMinimumOperator(ATenMinimumSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
self.elementwise_binary(tfl.MinimumOperator, graph_converter, True)
class ATenGtOperator(ATenGtSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
self.elementwise_binary(tfl.GreaterOperator, graph_converter, True)
class ATenLtOperator(ATenLtSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
self.elementwise_binary(tfl.LessOperator, graph_converter, True)
class ATenGeOperator(ATenGeSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
self.elementwise_binary(tfl.GreaterEqualOperator, graph_converter, True)
class ATenLeOperator(ATenLeSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = self.torch_tensor_from_scalar(self.input_tensors[0], self.input_tensors[1])
self.elementwise_binary(tfl.LessEqualOperator, graph_converter, True)
class ATenRemainderOperator(ATenRemainderSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
if not isinstance(self.input_tensors[1], torch.Tensor):
self.input_tensors[1] = torch.tensor([self.input_tensors[1]], dtype=self.input_tensors[0].dtype)
self.elementwise_binary(tfl.FloorModOperator, graph_converter, True)
class ATenWhereOperator(ATenWhereSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
assert 'self' in args and 'other' in args, "aten::where(condition) is not supported"
if not isinstance(self.input_tensors[2], torch.Tensor):
self.input_tensors[2] = torch.tensor([self.input_tensors[2]])
ATenMaskedFillOperator.parse_common(self, graph_converter, input_idx=2, mask_idx=0, other_idx=1)
class ATenTypeAsOperator(ATenTypeAsSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenToOperator.parse_common(self, node, attrs, args, graph_converter)
class ATenTopkOperator(ATenTopkSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor, k, dim, largest, sorted = self.input_tensors[:5]
assert dim in (input_tensor.ndim - 1, -1), 'tflite topk only supports the last dim'
assert largest in (1, True) and sorted in (1, True), 'tflite topk only supports largest=True and sorted=True'
input_tensor = self.find_or_create_input(0, graph_converter)
k = self.create_attr_tensor(np.array([k], dtype='int32'))
inputs = [input_tensor, k]
self.output_tensors[1] = self.output_tensors[1].to(dtype=torch.int32)
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
op = tfl.TopkV2Operator(inputs, outputs)
graph_converter.add_operator(op)
class ATenCumsumOperator(ATenCumsumSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor, dim = self.input_tensors[:2]
if dim < 0:
dim += input_tensor.ndim
input_tensor = self.find_or_create_input(0, graph_converter)
dim_tensor = self.create_attr_tensor(np.array([dim], dtype='int32'))
inputs = [input_tensor, dim_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
graph_converter.add_operator(tfl.CumsumOperator(inputs, outputs))
class ATenMeshgridOperator(ATenMeshgridSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
assert False, "aten::meshgrid for dynamic tensors is not supported"
class ATenFillOperator(ATenFillSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
assert False, "aten::fill_ for dynamic tensors is not supported"
class ATenUnbindOperator(ATenUnbindSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
dim = self.input_tensors[1]
if dim < 0:
dim += len(self.input_tensors[0].shape)
chunks = self.input_tensors[0].shape[dim]
output_names = [f'{self.output_names[0]}:{i}' for i in range(chunks)]
graph_converter.add_iterable_pair(self.output_names, output_names, 'input')
outputs = self.to_tfl_tensors(output_names, self.output_tensors[0])
if str(input_tensor.dtype) == 'int64' and input_tensor.tensor.ndim == 1 and input_tensor.tensor.size == 1:
shape_tensor = self.create_attr_tensor(np.array((), dtype='int32'))
graph_converter.add_operator(tfl.ReshapeOperator([input_tensor, shape_tensor], outputs, []))
else:
graph_converter.add_operator(tfl.UnpackOperator([input_tensor], outputs, chunks, dim))
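# Lowering notes for aten::roll below, a sketch per (shift, dim) pair after
# normalizing negative shifts (s = shift % size):
#   front, back = SPLIT_V(x, [size - s, s], dim)
#   out = CONCATENATION([back, front], dim)   # equals torch.roll(x, s, dim)
# When dims is empty, the input is flattened first and reshaped back at the end.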
class ATenRollOperator(ATenRollSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
shifts, dims = self.input_tensors[1:3]
ops = []
actual_input = input_tensor
if len(dims) == 0:
assert len(shifts) == 1
shift = shifts[0]
if len(input_tensor.shape) != 1:
flat_input = self.create_transform_tensor(
input_tensor.tensor.ravel(), quantization=input_tensor.quantization
)
flat_shape = self.create_attr_tensor(np.array(flat_input.shape, dtype='int32'))
prev_reshape_op = tfl.ReshapeOperator([input_tensor, flat_shape], [flat_input], flat_shape.tensor)
prev_reshape_op.extra_hints['direction'] = 'up'
ops.append(prev_reshape_op)
actual_input = flat_input
dims.append(0)
assert len(shifts) == len(dims)
for shift, dim in zip(shifts, dims):
if dim < 0:
dim += len(actual_input.shape)
dim_size = actual_input.shape[dim]
if shift < 0:
shift += dim_size
actual_shift = shift % dim_size
if actual_shift != 0:
split_sizes = self.create_attr_tensor(np.array([dim_size - actual_shift, actual_shift], dtype='int32'))
dim_tensor = self.create_attr_tensor(np.array(dim, dtype='int32'))
chunks = 2
splitted = [
self.create_transform_tensor(x, quantization=actual_input.quantization)
for x in np.split(actual_input.tensor, [dim_size - actual_shift], dim)
]
ops.append(tfl.SplitVOperator([actual_input, split_sizes, dim_tensor], splitted, chunks))
reversed_s = splitted[::-1]
outputs = [
self.create_transform_tensor(
np.concatenate([s.tensor for s in reversed_s], dim), quantization=actual_input.quantization
)
]
ops.append(tfl.ConcatenationOperator(reversed_s, outputs, dim))
else:
inputs = [actual_input, self.create_attr_tensor(np.array(actual_input.shape, dtype='int32'))]
outputs = [
self.create_transform_tensor(actual_input.tensor.copy(), quantization=actual_input.quantization)
]
ops.append(tfl.ReshapeOperator(inputs, outputs, actual_input.shape))
actual_input = outputs[0]
output_tensor = self.to_tfl_tensors(self.output_names, self.output_tensors)[0]
if len(actual_input.shape) != len(output_tensor.shape):
output_shape = self.create_attr_tensor(np.array(output_tensor.shape, dtype='int32'))
post_reshape_op = tfl.ReshapeOperator([actual_input, output_shape], [output_tensor], output_shape.tensor)
post_reshape_op.extra_hints['direction'] = 'down'
ops.append(post_reshape_op)
else:
ops[-1].outputs[0] = output_tensor
for op in ops:
graph_converter.add_operator(op)
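# Lowering notes for aten::pad below. PyTorch pads are flat (before, after)
# pairs starting from the last dimension, while TFLite PAD/PADV2/MIRROR_PAD
# expect an [ndim, 2] matrix in dimension order, hence the reshape(-1, 2) +
# zero-fill + flip. E.g. pads=[1, 2] on a 4-D NCHW tensor becomes
# [[0, 0], [0, 0], [0, 0], [1, 2]].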
class ATenPadOperator(ATenPadSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
pads = self.input_tensors[1]
mode = self.input_tensors[2]
constant_value = self.input_tensors[3]
op_cls_dict = {'constant': (tfl.PadOperator, tfl.Padv2Operator), 'reflect': (tfl.MirrorPadOperator, None)}
assert mode in op_cls_dict, f"Unknown mode for aten::pad : {mode}"
orig_pad = np.array(pads, dtype='int32').reshape(-1, 2)
pad_fill = np.zeros((input_tensor.tensor.ndim - orig_pad.shape[0], 2), dtype='int32')
pad_arr = np.flip(np.concatenate((orig_pad, pad_fill)), 0)
pad_tensor = self.create_attr_tensor(pad_arr)
inputs = [input_tensor, pad_tensor]
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
if constant_value not in (0, 0.0, None):
output = outputs[0]
if output.quantization is None:
constant_arr = np.array([constant_value], dtype='float32')
else:
float_arr = torch.tensor([constant_value], dtype=torch.float32)
constant_arr = torch.quantize_per_tensor(
float_arr, output.quantization.scale, output.quantization.zero_point, torch.quint8
)
inputs.append(self.create_attr_tensor(constant_arr))
graph_converter.add_operator(op_cls_dict[mode][1](inputs, outputs))
else:
graph_converter.add_operator(op_cls_dict[mode][0](inputs, outputs))
class ATenRoundOperator(ATenRoundSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.RoundOperator, graph_converter)
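# Lowering notes for the norm family below (aten::norm, aten::frobenius_norm,
# aten::linalg_vector_norm), which supports p=1 and p=2 only:
#   p=1: SUM(ABS(x), dims)
#   p=2: POW(SUM(POW(x, 2), dims), 0.5)
# Full reductions (no dims given, or all dims listed) produce a 1-element output.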
class ATenNormOperator(ATenNormSchema):
def parse_common(self, node, attrs, args, graph_converter):
p = self.input_tensors[1]
assert p in (1, 2), "only torch.norm with p=1,2 is supported"
input_t = self.find_or_create_input(0, graph_converter)
if 'dim' in args and 'keepdim' in args and self.input_tensors[args['dim']] is not None:
dims, keep_dim = self.input_tensors[2:4]
if type(dims) not in (list, tuple):
dims = [dims]
if len(dims) == 0:
dims = list(range(input_t.tensor.ndim))
self.output_tensors[0] = self.output_tensors[0].view(1)
elif len(dims) == input_t.tensor.ndim:
self.output_tensors[0] = self.output_tensors[0].view(1)
else:
dims = list(range(input_t.tensor.ndim))
keep_dim = False
self.output_tensors[0] = self.output_tensors[0].view(1)
for idx, dim in enumerate(dims):
if dim < 0:
dims[idx] += input_t.tensor.ndim
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
dim_t = self.create_attr_tensor(np.array(dims, dtype='int32'))
ops = []
if p == 1:
tgt_t = self.create_transform_tensor(np.abs(input_t.tensor))
ops.append(tfl.AbsOperator([input_t], [tgt_t]))
actual_output = outputs[0]
else:
tgt_t = self.create_transform_tensor(np.power(input_t.tensor, 2))
two_t = self.create_attr_tensor(np.array([2.0], dtype='float32'))
ops.append(tfl.PowOperator([input_t, two_t], [tgt_t]))
actual_output = self.create_transform_tensor(outputs[0].tensor)
ops.append(tfl.SumOperator([tgt_t, dim_t], [actual_output], keepDims=keep_dim))
if actual_output != outputs[0]:
half_t = self.create_attr_tensor(np.array([0.5], dtype='float32'))
ops.append(tfl.PowOperator([actual_output, half_t], outputs))
for op in ops:
graph_converter.add_operator(op)
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(node, attrs, args, graph_converter)
class ATenFrobeniusNormOperator(ATenFrobeniusNormSchema):
def parse_common(self, node, attrs, args, graph_converter):
assert 'p' not in args
self.input_tensors.insert(1, 2)
ATenNormOperator.parse_common(self, node, attrs, args, graph_converter)
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.parse_common(node, attrs, args, graph_converter)
class ATenLinalgVectorNormOperator(ATenLinalgVectorNormSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ATenNormOperator.parse_common(self, node, attrs, args, graph_converter)
class ATenAbsOperator(ATenAbsSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
self.elementwise_unary(tfl.AbsOperator, graph_converter)
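# Lowering notes for aten::im2col below. TFLite has no unfold op, so the
# input-to-output position mapping is precomputed at conversion time: an arange
# tensor shaped like the padded input is unfolded with torch, the resulting
# linear indices are turned back into N-D coordinates via nonzero(), and a
# single GATHER_ND replays the unfold on the runtime tensor. This relies on the
# shapes being static, which holds since they are fixed during conversion.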
class ATenIm2colOperator(ATenIm2colSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
assert input_tensor.tensor.ndim == 4, "only 4-D input tensors (batched image-like tensors) are supported"
output_tensors = self.to_tfl_tensors(self.output_names, self.output_tensors)
kernel_h, kernel_w = self.input_tensors[1]
dilation_h, dilation_w = self.input_tensors[2]
padding_h, padding_w = self.input_tensors[3]
stride_h, stride_w = self.input_tensors[4]
orig_pad = np.array([padding_w, padding_w, padding_h, padding_h], dtype='int32').reshape(-1, 2)
pad_fill = np.zeros((input_tensor.tensor.ndim - orig_pad.shape[0], 2), dtype='int32')
pad_arr = np.flip(np.concatenate((orig_pad, pad_fill)), 0)
pad_tensor = self.create_attr_tensor(pad_arr)
inter_tensor = self.create_transform_tensor(
np.pad(input_tensor.tensor, ((0, 0), (0, 0), (padding_h, padding_h), (padding_w, padding_w)))
)
graph_converter.add_operator(tfl.PadOperator([input_tensor, pad_tensor], [inter_tensor]))
fake_input = torch.arange(0.0, inter_tensor.tensor.size).reshape(inter_tensor.shape)
fake_output = torch.nn.functional.unfold(
fake_input, (kernel_h, kernel_w), (dilation_h, dilation_w), (0, 0), (stride_h, stride_w)
).to(dtype=torch.int64)
indices = torch.nonzero(fake_input >= 0)[fake_output].to(dtype=torch.int32)
indices_tensor = self.create_attr_tensor(indices)
graph_converter.add_operator(tfl.GatherNdOperator([inter_tensor, indices_tensor], output_tensors))
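# Lowering notes for aten::col2im below, the inverse of the im2col trick above:
# precomputed coordinates drive a SCATTER_ND that places the column tensor onto
# a padded output canvas, then a second precomputed index set with GATHER_ND
# crops the padding away to the requested output size.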
class ATenCol2imOperator(ATenCol2imSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor = self.find_or_create_input(0, graph_converter)
assert input_tensor.tensor.ndim in (
2,
3,
), "Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported"
output_tensors = self.to_tfl_tensors(self.output_names, self.output_tensors)
output_size_h, output_size_w = self.input_tensors[1]
kernel_h, kernel_w = self.input_tensors[2]
dilation_h, dilation_w = self.input_tensors[3]
padding_h, padding_w = self.input_tensors[4]
stride_h, stride_w = self.input_tensors[5]
fold_out = torch.nn.functional.fold(
torch.from_numpy(input_tensor.tensor),
(output_size_h, output_size_w),
(kernel_h, kernel_w),
(dilation_h, dilation_w),
(padding_h, padding_w),
(stride_h, stride_w),
)
padded_fold_out = torch.nn.functional.pad(fold_out, (padding_w, padding_w, padding_h, padding_h)).numpy()
fake_input = torch.arange(0.0, padded_fold_out.size).reshape(padded_fold_out.shape)
if input_tensor.tensor.ndim == 2:
fake_input = fake_input.unsqueeze(0)
fake_output = torch.nn.functional.unfold(
fake_input, (kernel_h, kernel_w), (dilation_h, dilation_w), (0, 0), (stride_h, stride_w)
).to(dtype=torch.int64)
if input_tensor.tensor.ndim == 2:
fake_input = fake_input.squeeze(0)
fake_output = fake_output.squeeze(0)
indices = torch.nonzero(fake_input >= 0)[fake_output].to(dtype=torch.int32)
indices_tensor = self.create_attr_tensor(indices)
shape_tensor = self.create_attr_tensor(np.array(padded_fold_out.shape, dtype='int32'))
padded_fold_out_tensor = self.create_transform_tensor(padded_fold_out)
graph_converter.add_operator(
tfl.ScatterNdOperator([indices_tensor, input_tensor, shape_tensor], [padded_fold_out_tensor])
)
fake_input = torch.arange(0.0, padded_fold_out.size).reshape(padded_fold_out.shape)
fake_output = fake_input[..., padding_h : output_size_h + padding_h, padding_w : output_size_w + padding_w].to(
dtype=torch.int64
)
indices = torch.nonzero(fake_input >= 0)[fake_output].to(dtype=torch.int32)
indices_tensor = self.create_attr_tensor(indices)
graph_converter.add_operator(tfl.GatherNdOperator([padded_fold_out_tensor, indices_tensor], output_tensors))
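# Lowering notes for aten::addbmm below, a sketch assuming the default
# beta=1, alpha=1 scaling:
#   out = ADD(input, SUM(BATCH_MATMUL(batch1, batch2), dim=0))
# aten::baddbmm is the same graph without the batch-dimension reduction.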
class ATenAddbmmOperator(ATenAddbmmSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor, batch1_tensor, batch2_tensor = [self.find_or_create_input(i, graph_converter) for i in range(3)]
output_tensors = self.to_tfl_tensors(self.output_names, self.output_tensors)
assert (
batch1_tensor.tensor.ndim == batch2_tensor.tensor.ndim == 3
), "batch1 and batch2 must be 3-D tensors each containing the same number of matrices"
bmm_out = torch.bmm(torch.from_numpy(batch1_tensor.tensor), torch.from_numpy(batch2_tensor.tensor))
bmm_out_tensor = self.create_transform_tensor(bmm_out)
graph_converter.add_operator(tfl.BatchMatmulOperator([batch1_tensor, batch2_tensor], [bmm_out_tensor]))
sum_bmm_out = torch.sum(bmm_out, dim=0)
sum_bmm_out_tensor = self.create_transform_tensor(sum_bmm_out)
dim_t = self.create_attr_tensor(np.array([0], dtype='int32'))
graph_converter.add_operator(tfl.SumOperator([bmm_out_tensor, dim_t], [sum_bmm_out_tensor], keepDims=False))
graph_converter.add_operator(tfl.AddOperator([input_tensor, sum_bmm_out_tensor], output_tensors))
class ATenBaddbmmOperator(ATenBaddbmmSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_tensor, batch1_tensor, batch2_tensor = [self.find_or_create_input(i, graph_converter) for i in range(3)]
output_tensors = self.to_tfl_tensors(self.output_names, self.output_tensors)
assert (
batch1_tensor.tensor.ndim == batch2_tensor.tensor.ndim == 3
), "batch1 and batch2 must be 3-D tensors each containing the same number of matrices"
bmm_out = torch.bmm(torch.from_numpy(batch1_tensor.tensor), torch.from_numpy(batch2_tensor.tensor))
bmm_out_tensor = self.create_transform_tensor(bmm_out)
graph_converter.add_operator(tfl.BatchMatmulOperator([batch1_tensor, batch2_tensor], [bmm_out_tensor]))
graph_converter.add_operator(tfl.AddOperator([input_tensor, bmm_out_tensor], output_tensors))
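# Lowering notes for aten::mish below, built from primitives since TFLite has
# no MISH op:
#   softplus(x) = LOG(ADD(EXP(x), 1))
#   mish(x)     = MUL(x, TANH(softplus(x)))
# Note this is the textbook softplus; EXP may overflow for large inputs.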
class ATenMishOperator(ATenMishSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
ops = []
input_tensor = self.find_or_create_input(0, graph_converter)
exp_out = self.create_transform_tensor(np.exp(input_tensor.tensor))
ops.append(tfl.ExpOperator([input_tensor], [exp_out]))
one_tensor = self.create_attr_tensor(np.ones((1,), dtype=exp_out.dtype))
add_out = self.create_transform_tensor(exp_out.tensor + one_tensor.tensor)
ops.append(tfl.AddOperator([exp_out, one_tensor], [add_out]))
softplus_out = self.create_transform_tensor(np.log(add_out.tensor))
ops.append(tfl.LogOperator([add_out], [softplus_out]))
tanh_out = self.create_transform_tensor(np.tanh(softplus_out.tensor))
ops.append(tfl.TanhOperator([softplus_out], [tanh_out]))
outputs = self.to_tfl_tensors(self.output_names, self.output_tensors)
ops.append(tfl.MulOperator([input_tensor, tanh_out], outputs))
for op in ops:
graph_converter.add_operator(op)
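# Lowering notes for aten::broadcast_tensors below: each input is expanded to
# the common shape with a leading-ones RESHAPE plus TILE (or an identity
# RESHAPE when it already matches), mirroring the expand-via-tile pattern used
# by aten::copy_ above.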
class ATenBroadcastTensorsOperator(ATenBroadcastTensorsSchema):
def parse(self, node, attrs, args, graph_converter):
super().parse(node, attrs, args, graph_converter)
self.run(node)
input_names = graph_converter.get_list_expanded_names(self.input_names[0])
inputs = self.to_tfl_tensors(
input_names, self.input_tensors[0], graph_converter=graph_converter, non_existent_as_buffer=True
)
output_names = [f'{self.output_names[0]}:{i}' for i in range(len(input_names))]
outputs = self.to_tfl_tensors(output_names, self.output_tensors[0])
graph_converter.add_iterable_pair(self.output_names, output_names, 'input')
ops = []
for inp, outp in zip(inputs, outputs):
input_shape = inp.shape
output_shape = outp.shape
# No-OP if input tensor is already of desired sizes
if output_shape == input_shape:
inputs = [inp, self.create_attr_tensor(np.array(inp.shape, dtype='int32'))]
ops.append(tfl.ReshapeOperator(inputs, [outp], inp.shape))
continue
new_shape = input_shape
actual_input = inp
if len(output_shape) > len(input_shape):
new_shape = [1] * (len(output_shape) - len(input_shape)) + list(input_shape)
new_shape_arr = np.array(new_shape, dtype='int32')
new_shape_tensor = self.create_attr_tensor(new_shape_arr)
reshaped = self.create_transform_tensor(np.reshape(inp.tensor, new_shape_arr))
actual_input = reshaped
reshape_op = tfl.ReshapeOperator([inp, new_shape_tensor], [reshaped], new_shape_arr)
reshape_op.extra_hints['direction'] = 'up'
ops.append(reshape_op)
repeats = []
for x, y in zip(new_shape, output_shape):
if x != y:
repeats.append(y)
else:
repeats.append(1)
repeat_tensor = self.create_attr_tensor(np.array(repeats, dtype='int32'))
ops.append(tfl.TileOperator([actual_input, repeat_tensor], [outp]))
for op in ops:
graph_converter.add_operator(op)