tinynn/graph/quantization/quantizer.py
def rewrite_quantize_graph(self, graph: TraceGraph) -> None:
"""Rewrites the computation graph for quantization"""
if graph.quantized:
return
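# Detect graphs that have already been rewritten: QuantStub/DeQuantStub nodes among the forward
# nodes or nnq.FloatFunctional modules among the other init nodes mark a quantized graph.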
graph_quantized = False
for n in graph.forward_nodes:
if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub):
graph_quantized = True
break
for n in graph.other_init_nodes:
if n.type() is nnq.FloatFunctional:
graph_quantized = True
break
if graph_quantized:
log.warning(
'Graph is quantized. No need to rewrite. Please pass in `config={"rewrite_graph": False}` to suppress'
' this warning'
)
return
if self.backend == 'tensorrt':
return self.rewrite_quantize_graph_for_tensorrt(graph)
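# When layerwise quantization is disabled by default, opt every forward node out unless the
# user has explicitly enabled it in layerwise_config.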
if self.layerwise_default is False:
for n in graph.forward_nodes:
self.layerwise_config.setdefault(n.unique_name, False)
creation_func_names = load_creation_func_names()
def _is_extra_constant_nodes(node, custom_data):
return node.full_name() in creation_func_names
extra_constant_nodes = graph.filter_forward_nodes(_is_extra_constant_nodes)
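# Int-to-float conversion nodes turn integer tensors into floating-point ones, so they also
# need a QuantStub inserted after them (see the loop below).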
def _is_int_to_float_nodes(node, custom_data):
if node.full_name() in creation_func_names:
return False
if len(node.prev_nodes) == 1 and len(node.next_nodes) == 1:
if len(node.prev_tensors) == 1 and len(node.next_tensors) == 1:
if node.prev_nodes[0].kind() == 'shape' and node.prev_nodes[0].module.is_property:
return False
if (
isinstance(node.prev_tensors[0], torch.Tensor)
and isinstance(node.next_tensors[0], torch.Tensor)
and node.prev_tensors[0].dtype in (torch.int32, torch.int64)
and torch.is_floating_point(node.next_tensors[0])
):
return True
else:
return False
int_to_float_nodes = graph.filter_forward_nodes(_is_int_to_float_nodes)
# When converting a float tensor to int16/int32/int64, we need to add a 'fake_dequant' node before the conversion node.
def _is_float_to_non_float_nodes(node, custom_data):
if isinstance(node.module, TraceFunction) and node.module.kind in ('shape', 'size'):
return False
return (
len(node.next_tensors) == 1
and len(node.prev_tensors) > 0
and isinstance(node.prev_tensors[0], torch.Tensor)
and isinstance(node.next_tensors[0], torch.Tensor)
and torch.is_floating_point(node.prev_tensors[0])
and not torch.is_floating_point(node.next_tensors[0])
)
float_to_non_float_nodes = graph.filter_forward_nodes(_is_float_to_non_float_nodes)
def _is_params_in_module(node, custom_data):
if len(node.prev_nodes) == 1 and len(node.next_nodes) == 1:
if len(node.prev_tensors) == 1 and len(node.next_tensors) == 1:
if isinstance(node.prev_tensors[0], nn.Module) and not isinstance(
node.prev_tensors[0], nnq.FloatFunctional
):
return True
return False
param_nodes = graph.filter_forward_nodes(_is_params_in_module)
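# Accesses to .weight/.bias on known quantizable modules are rewritten to go through
# tinynn.graph.quantization.utils.get_parameter (presumably so that the access keeps working
# once the module is replaced by its quantized counterpart).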
for node in param_nodes:
prev_node = node.prev_nodes[0]
is_known_mod = prev_node.kind().__name__ in (
'Conv1d',
'Conv2d',
'Linear',
'ConvTranspose1d',
'ConvTranspose2d',
)
prop = node.module.full_name
if is_known_mod and prop in ('weight', 'bias') and self.layerwise_config.get(prev_node.unique_name, True):
node.module.is_class = False
node.module.is_property = False
node.module.full_name = 'tinynn.graph.quantization.utils.get_parameter'
node.module.args_template = f'{{}}, "{prop}"'
node.module.args_template_no_self = f'"{prop}"'
node.module.args_offset = 0
node.module.update_args_string()
# First, we insert the QuantStub nodes for every input/constant node
for idx, node in reversed(
list(
enumerate(
graph.input_nodes + graph.constant_nodes + extra_constant_nodes + int_to_float_nodes + param_nodes
)
)
):
if node.next_tensors[0].dtype in (torch.int32, torch.int64):
continue
fake_quant = torch_q.QuantStub()
graph.module_unique_name_dict[id(fake_quant)] = f'fake_quant_{idx}'
graph.module_original_name_dict[id(fake_quant)] = f'fake_quant_{idx}'
fake_quant_cls = type(fake_quant)
module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()'
graph.insert_after(node, fake_quant)
# Second, we insert the DeQuantStub nodes for every output node
for idx, node in enumerate(graph.output_nodes + float_to_non_float_nodes):
fake_dequant_cls = torch_q.DeQuantStub
if node.rev_index:
modules = []
for rev_idx in range(len(node.prev_nodes)):
if node.prev_tensors[rev_idx].dtype in (torch.int32, torch.int64):
fake_dequant = nn.Identity()
else:
fake_dequant = fake_dequant_cls()
graph.module_unique_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}_{rev_idx}'
graph.module_original_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}_{rev_idx}'
module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'
modules.append(fake_dequant)
graph.insert_before(node, modules, move_idx=True)
else:
if node.prev_tensors[0].dtype in (torch.int32, torch.int64):
continue
fake_dequant = fake_dequant_cls()
graph.module_unique_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}'
graph.module_original_name_dict[id(fake_dequant)] = f'fake_dequant_{idx}'
module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'
if len(node.prev_nodes) > 1:
# Insert 'fake_dequant' node before type conversion operators
graph.insert_between(node.prev_nodes[0], node, fake_dequant, move_idx=True)
else:
graph.insert_before(node, fake_dequant, move_idx=True)
# Third, we rewrite neg/sub/div using supported functions (e.g. add, mul)
def _is_neg_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
if cur_class == TraceFunction:
return (
cur_module.kind == 'neg'
and torch.is_floating_point(cur_module.prev_tensors[0])
and self.layerwise_config.get(node.unique_name, True)
)
neg_nodes = graph.filter_forward_nodes(_is_neg_node)
log.info(f'rewriting neg for {[node.unique_name for node in neg_nodes]}')
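# e.g. x = -a  =>  x = a * -1.0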
for idx, node in enumerate(neg_nodes):
node.module.func_type = '__mul__'
node.module.kind = 'mul'
full_name_parts = node.module.full_name.split('.')
full_name_parts[-1] = node.module.func_type
node.module.full_name = '.'.join(full_name_parts)
with override_current_trace_graph(graph):
node.module.parse_args(node.prev_tensors[0], -1.0)
def _is_div_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
# Currently, only the following pattern can be handled:
#   a / constant => a * (1 / constant)
if cur_class == TraceFunction:
return (
(cur_module.kind == 'truediv' or cur_module.func_type in ('div', 'div_'))
and len(cur_module.prev_tensors) == 1
and torch.is_floating_point(cur_module.prev_tensors[0])
and cur_module.func_type != '__rtruediv__'
and torch.is_floating_point(node.next_tensors[0])
and node.prev_nodes[0].kind() not in ('size', 'shape')
and self.layerwise_config.get(node.unique_name, True)
)
div_nodes = graph.filter_forward_nodes(_is_div_node)
log.info(f'rewriting div for {[node.unique_name for node in div_nodes]}')
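# e.g. x = a / 2.0  =>  x = a * 0.5, and the inplace a.div_(4) becomes a.__imul__(0.25)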
for idx, node in enumerate(div_nodes):
op_type = node.module.func_type
inplace = op_type in ('__itruediv__', 'itruediv', 'div_')
if inplace:
node.module.func_type = '__imul__'
else:
node.module.func_type = '__mul__'
node.module.kind = 'mul'
full_name_parts = node.module.full_name.split('.')
full_name_parts[-1] = node.module.func_type
node.module.full_name = '.'.join(full_name_parts)
# Here we make a simple guess: if there is a dot in the string, then it's a floating-point
# number. Otherwise, it is an integer.
if '.' in node.module.args_string_no_self:
other_arg = float(node.module.args_string_no_self)
else:
other_arg = int(node.module.args_string_no_self)
with override_current_trace_graph(graph):
node.module.parse_args(node.prev_tensors[0], 1.0 / other_arg)
def _is_sub_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
if cur_class == TraceFunction:
return (
cur_module.kind == 'sub'
and torch.is_floating_point(cur_module.prev_tensors[0])
and torch.is_floating_point(node.next_tensors[0])
and self.layerwise_config.get(node.unique_name, True)
)
sub_nodes = graph.filter_forward_nodes(_is_sub_node)
log.info(f'rewriting sub for {[node.unique_name for node in sub_nodes]}')
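# e.g. a - 2.0        =>  a + (-2.0)
#      a - b          =>  a + (b * -1)
#      2.0 - a (rsub) =>  (a * -1).__radd__(2.0), i.e. 2.0 + (-a)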
for idx, node in enumerate(sub_nodes):
op_type = node.module.func_type
full_name_parts = node.module.full_name.split('.')
full_name_parts[-1] = node.module.func_type
if len(node.module.prev_tensors) == 1 and node.module.func_type != '__rsub__':
node.module.func_type = '__add__'
node.module.kind = 'add'
node.module.full_name = '.'.join(full_name_parts)
# Here we make a simple guess: if there is a dot in the string, then it's a floating-point
# number. Otherwise, it is an integer.
if '.' in node.module.args_string_no_self or 'e' in node.module.args_string_no_self:
other_arg = float(node.module.args_string_no_self)
else:
other_arg = int(node.module.args_string_no_self)
with override_current_trace_graph(graph):
node.module.parse_args(node.prev_tensors[0], -other_arg)
elif len(node.module.prev_tensors) == 2 and len(node.prev_nodes) == 2:
new_fullname_parts = copy.deepcopy(full_name_parts)
new_fullname_parts[-1] = '__mul__'
new_fullname = '.'.join(new_fullname_parts)
current_tensor = node.prev_tensors[0]
input_node = node.prev_nodes[1]
input_tensor = node.prev_tensors[1]
output_tensor = input_tensor * -1
with override_current_trace_graph(graph):
trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1)
graph.insert_between(input_node, node, trace_func, [output_tensor], True)
node.module.func_type = '__add__'
node.module.kind = 'add'
full_name_parts[-1] = node.module.func_type
node.module.full_name = '.'.join(full_name_parts)
node.module.parse_args(current_tensor, output_tensor)
elif node.module.func_type == '__rsub__' and len(node.module.prev_tensors) == 1:
new_fullname_parts = copy.deepcopy(full_name_parts)
new_fullname_parts[-1] = '__mul__'
new_fullname = '.'.join(new_fullname_parts)
input_node = node.prev_nodes[0]
input_tensor = node.prev_tensors[0]
output_tensor = input_tensor * -1
if '.' in node.module.args_string_no_self:
other_arg = float(node.module.args_string_no_self)
else:
other_arg = int(node.module.args_string_no_self)
with override_current_trace_graph(graph):
trace_func = TraceFunction(new_fullname, True, prefix='fuse_').parse_args(input_tensor, -1)
graph.insert_between(input_node, node, trace_func, [output_tensor], True)
node.module.func_type = '__radd__'
node.module.kind = 'add'
full_name_parts[-1] = node.module.func_type
node.module.full_name = '.'.join(full_name_parts)
# Then, we rewrite torch.stack nodes as torch.unsqueeze + torch.cat
def _is_stack_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
return (
cur_class == TraceFunction
and cur_module.kind == 'stack'
and torch.is_floating_point(node.next_tensors[0])
and self.layerwise_config.get(node.unique_name, True)
)
stack_nodes = graph.filter_forward_nodes(_is_stack_node)
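# e.g. torch.stack([a, b], 1)  =>  torch.cat([a.unsqueeze(1), b.unsqueeze(1)], 1)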
for idx, node in enumerate(stack_nodes):
args = getattr(node.module, 'args_string_no_self', '')
if ',' in args:
log.error('rewrite doesn\'t support multiple args for torch.stack')
assert False
if len(args) > 0:
if '=' in args:
k, v = args.split('=')
assert k in ('axis', 'dim')
dim = int(v)
if k == 'axis':
node.module.args_template_no_self = node.module.args_template_no_self.replace('axis', 'dim')
node.module.args_template = node.module.args_template.replace('axis', 'dim')
node.module.update_args_string()
else:
dim = int(args)
else:
dim = 0
unique_prev_nodes = {n.unique_name: n for n in node.prev_nodes}.values()
for n in unique_prev_nodes:
shared_tensors = list(set(node.prev_tensors).intersection(set(n.next_tensors)))
if len(shared_tensors) == 0:
log.debug('tensor rewrite already done, skipping')
continue
for t in shared_tensors:
with override_current_trace_graph(graph):
trace_func = TraceFunction('torch.unsqueeze', prefix='fuse_').parse_args(t, dim)
next_tensors = [torch.unsqueeze(t, dim)]
graph.insert_between(n, node, trace_func, next_tensors, move_idx=True, tensor_ptrs=set([id(t)]))
node.module.func_type = 'cat'
node.module.kind = 'cat'
full_name_parts = node.module.full_name.split('.')
full_name_parts[-1] = node.module.func_type
node.module.full_name = '.'.join(full_name_parts)
# Next, we rewrite add/mul/cat with one float32 output using torch.nn.quantized.FloatFunctional
def _is_convertible_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
if cur_class == TraceFunction:
return (
cur_module.kind in ('add', 'mul', 'cat')
and torch.is_floating_point(node.next_tensors[0])
and self.layerwise_config.get(node.unique_name, True)
and all((torch.is_floating_point(pt) for pt in node.prev_tensors))
)
convertible_nodes = graph.filter_forward_nodes(_is_convertible_node)
unfusable_add_nodes = [] # noqa: F841
log.info(f'rewriting add/mul/cat for {[node.unique_name for node in convertible_nodes]}')
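# e.g. x = a + b  =>  x = self.float_functional_simple_0.add(a, b); a scalar operand maps to
# add_scalar/mul_scalar instead.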
for idx, node in enumerate(convertible_nodes):
old_full_name = node.module.full_name # noqa: F841
old_is_class = node.module.is_class # noqa: F841
op_kind = node.module.kind
float_functional = nnq.FloatFunctional()
module_name = f'float_functional_simple_{idx}'
graph.module_unique_name_dict[id(float_functional)] = module_name
graph.module_original_name_dict[id(float_functional)] = module_name
float_functional_cls = type(float_functional)
module_constructor_lines[id(float_functional)] = f'{qualified_name(float_functional_cls, short=True)}()'
new_node = TraceNode(float_functional, cur_graph=graph)
graph.nodes_map[new_node.unique_name] = new_node
graph.other_init_nodes.append(new_node)
node.module.is_class = False
prev_tensor_size = len(node.prev_tensors)
if op_kind in ('add', 'mul'):
if prev_tensor_size == 2:
if node.prev_nodes[1].prev_nodes[0].kind() not in ('shape', 'size'):
op_type = op_kind
else:
op_type = f'{op_kind}_scalar'
elif prev_tensor_size == 1:
op_type = f'{op_kind}_scalar'
else:
log.error(f'Unknown add/mul type for {node.unique_name}, prev tensor size: {prev_tensor_size}')
assert False
else:
# Don't check anything for other OPs.
# It is simply too complex for us.
op_type = op_kind
old_func_type = node.module.func_type
node.module.func_type = op_type
node.module.full_name = f'self.{module_name}.{op_type}'
# Inplace operations (checked against the original func_type, which is overwritten above)
if old_func_type in ['__iadd__', '__imul__', 'add_', 'mul_']:
node.module.add_alias(node.module.tensor_names[0])
q = queue.Queue()
q.put(node.prev_nodes[0])
while not q.empty():
n = q.get()
if type(n.module) is TraceFunction:
prev_aliases = n.module.get_aliases()
if prev_aliases is not None:
for pa in reversed(prev_aliases):
node.module.add_alias(pa, head=True)
else:
if getattr(n.module, 'inplace', False):
q.put(n.prev_nodes[0])
node.module.add_alias(n.prev_node_unique_name(0), head=True)
# We need to convert radd to normal add here
if old_func_type in ['__radd__', '__rmul__']:
if '=' in node.module.args_string_no_self or ', ' in node.module.args_string_no_self:
log.error(f'Don\'t know how to translate {node.module.args_string_no_self} for __radd__/__rmul__')
assert False
if prev_tensor_size == 1:
# Here we make a simple guess: if there is a dot in the string, then it's a floating-point
# number. Otherwise, it is an integer.
if '.' in node.module.args_string_no_self or 'e' in node.module.args_string_no_self:
other_arg = float(node.module.args_string_no_self)
else:
other_arg = int(node.module.args_string_no_self)
with override_current_trace_graph(graph):
node.module.parse_args(node.prev_tensors[0], other_arg)
else:
with override_current_trace_graph(graph):
# It is even simpler here. We only need to swap the order of the tensors.
node.module.parse_args(node.prev_tensors[1], node.prev_tensors[0])
# Rewrite torch.clamp_{min,max} to torch.clamp
def _is_clamp_min_max_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
if cur_class == TraceFunction:
return (
cur_module.kind in ('clamp_min', 'clamp_max')
and torch.is_floating_point(node.next_tensors[0])
and self.layerwise_config.get(node.unique_name, True)
)
clamp_min_max_nodes = graph.filter_forward_nodes(_is_clamp_min_max_node)
log.info(f'rewriting clamp_{{min,max}} for {[node.unique_name for node in clamp_min_max_nodes]}')
for node in clamp_min_max_nodes:
kw = node.module.kind[-3:]
_parse_args = eval(f'lambda {kw}, *args, **kwargs: {kw}')
val = eval(f'_parse_args({node.module.args_string_no_self})')
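# Only clamp_min with a zero bound is rewritten (to clamp(0.0)), so that the clamp pass
# below can map it to ReLU.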
if kw != 'min' or val != 0.0:
continue
if kw == 'min':
arg_str = f'{val}'
else:
arg_str = f'None, {val}'
sub_str = f'_{kw}'
node.module.kind = node.module.kind.replace(sub_str, '')
node.module.func_type = node.module.func_type.replace(sub_str, '')
node.module.full_name = '.'.join(node.module.full_name.split('.')[:-1] + [node.module.func_type])
node.module.args_template_no_self = arg_str
node.module.args_template = f'{{}}, {arg_str}'
node.module.update_args_string()
# Rewrite torch.clamp to relu, relu6, clamp_{min,max} or clamp_with_fusion
def _is_clamp_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
if cur_class == TraceFunction:
return (
cur_module.kind == 'clamp'
and torch.is_floating_point(node.next_tensors[0])
and self.layerwise_config.get(node.unique_name, True)
)
clamp_nodes = graph.filter_forward_nodes(_is_clamp_node)
log.info(f'rewriting clamp for {[node.unique_name for node in clamp_nodes]}')
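# Mapping: clamp(0, None) -> relu, clamp(0, 6) -> relu6, clamp(x, None) -> clamp_min,
# clamp(None, x) -> clamp_max, and clamp(min, max) -> clamp_with_fusion, a quantization-aware
# helper in tinynn.graph.quantization.utils.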
for node in clamp_nodes:
def _parse_args(min=None, max=None, *args, **kwargs):
return min, max
min, max = eval(f'_parse_args({node.module.args_string_no_self})')
kind = None
if min == 0.0:
if max is None:
kind = 'relu'
elif max == 6.0:
kind = 'relu6'
if kind is None:
if max is None:
kind = 'clamp_min'
elif min is None:
kind = 'clamp_max'
else:
kind = 'clamp_with_fusion'
if kind in ('clamp_min', 'clamp_max'):
node.module.kind = kind
node.module.func_type = node.module.func_type.replace('clamp', kind)
node.module.full_name = '.'.join(node.module.full_name.split('.')[:-1] + [node.module.func_type])
if max is None:
arg_str = f'{min}'
else:
arg_str = f'{max}'
elif kind == 'clamp_with_fusion':
node.module.is_class = False
node.module.kind = kind
node.module.func_type = node.module.func_type.replace('clamp', kind)
node.module.full_name = f'tinynn.graph.quantization.utils.{node.module.func_type}'
arg_str = f'{min}, {max}'
else:
inplace = node.module.func_type == f'{node.module.kind}_'
node.module.kind = kind
node.module.func_type = kind
node.module.full_name = f'torch.nn.functional.{kind}'
if inplace:
arg_str = 'inplace=True'
else:
arg_str = ''
node.module.args_template_no_self = arg_str
node.module.args_template = f'{{}}, {arg_str}'
node.module.update_args_string()
# Rewrite other fusable functions
# e.g. add_relu(x, y) =>
# r = torch.add(x, y)
# r = torch.nn.functional.relu(r)
def _is_add_relu_fusable_node(node: TraceNode, custom_data) -> bool:
cur_module = node.module
cur_class = type(cur_module)
visited_nodes = [node]
if self.layerwise_config.get(node.unique_name, True) is False:
return False
if cur_class == TraceFunction:
# The intermediate result can no longer be used once the nodes are fused.
# So we skip fusion when the add result has more than one consumer.
if cur_module.kind != 'add' or len(node.next_nodes) != 1:
return False
# The input for the add operations should be two tensors.
if len(node.prev_tensors) != 2:
return False
# We accept inplace operations for both add and relu.
# The inplace property can be eliminated because we track the tensors
# instead of their names.
next_node = node.next_nodes[0]
next_module = next_node.module
next_class = type(next_module)
if next_class == TraceFunction:
fuse = next_module.kind == 'relu'
else:
while next_class == nn.Identity:
cur_node = next_node
visited_nodes.append(cur_node)
if len(cur_node.next_nodes) != 1:
return False
next_node = cur_node.next_nodes[0]
next_module = next_node.module
next_class = type(next_module)
fuse = next_class.__name__ == 'ReLU'
if not fuse:
return False
if type(next_node.module) is TraceFunction:
inplace = next_node.module.func_type == 'relu_' or 'True' in next_node.module.args_string
else:
inplace = getattr(next_node.module, 'inplace', False)
# Inplace check
# If add is inplace and relu is not inplace, we need to ensure that all the aliases of
# the first operand of add are not used when relu is called.
if not inplace:
aliases = cur_module.get_aliases()
if aliases:
q = queue.Queue()
q.put(node.prev_nodes[0])
while not q.empty():
n = q.get()
last_order = max((x.forward_order for x in n.next_nodes))
if last_order > node.forward_order:
fuse = False
break
if type(n.module) is TraceFunction and n.module.get_aliases():
q.put(n.prev_nodes[0])
elif getattr(n.module, 'inplace', False):
q.put(n.prev_nodes[0])
return fuse
add_relu_fusable_nodes = graph.filter_forward_nodes(_is_add_relu_fusable_node)
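# e.g. r = self.float_functional_simple_0.add(x, y); r = relu(r)
#   =>  r = self.float_functional_simple_0.add_relu(x, y)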
for node in add_relu_fusable_nodes:
full_name = node.module.full_name.replace('add', 'add_relu')
next_node = node.next_nodes[0]
kind = 'add_relu'
func_type = kind
is_class = False
nodes_to_fuse = [node, next_node]
while next_node.type() is nn.Identity:
next_node = next_node.next_nodes[0]
nodes_to_fuse.append(next_node)
if type(next_node.module) is TraceFunction:
inplace = next_node.module.func_type == 'relu_' or 'True' in next_node.module.args_string
else:
inplace = next_node.module.inplace
graph.fuse_nodes_to_func(nodes_to_fuse, full_name, kind, func_type, is_class)
# Propagate aliases for inplace nodes
if inplace:
aliases = node.module.get_aliases()
if aliases:
node.module.add_alias(node.module.tensor_names[0])
else:
node.module.aliases = None
# Rewrite functional ops such as relu and relu6 as modules (nn.ReLU(), nn.ReLU6()) so that module-based fusion rules apply
def _is_functional_rewrite_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
if cur_class == TraceFunction:
return cur_module.kind in FUNCTIONAL_MODULE_MAPPING and self.layerwise_config.get(
node.unique_name, True
)
func_nodes_to_rewrite = graph.filter_forward_nodes(_is_functional_rewrite_node)
log.info(f'rewriting functional to module for {[node.unique_name for node in func_nodes_to_rewrite]}')
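# e.g. F.relu(x, inplace=True)  =>  an nn.ReLU(inplace=True) module, so that module-based
# fusion rules can match it.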
for idx, node in enumerate(func_nodes_to_rewrite):
kind = node.module.kind
inplace = node.module.func_type == f'{kind}_' or 'True' in node.module.args_string
klass = FUNCTIONAL_MODULE_MAPPING[kind]
arguments = getattr(klass, '__constants__', None)
if arguments is None:
new_func = klass()
elif node.module.kind in ('relu', 'relu6', 'silu', 'hardswish'):
new_func = klass(inplace=inplace)
elif node.module.kind in ('elu', 'leaky_relu'):
if hasattr(node.module, 'args_string_no_self'):
def _parse_args(alpha=1.0, *args, **kwargs): # noqa: F811
return alpha
alpha = eval(f'_parse_args({node.module.args_string_no_self})')
new_func = klass(alpha, inplace=inplace)
else:
alpha = None
if node.module.kind == 'leaky_relu':
new_func = klass(inplace=inplace)
else:
new_func = klass()
elif node.module.kind == 'prelu':
weight_t = node.prev_tensors[1]
weight_node = node.prev_nodes[1]
if weight_node.type() is torch_q.QuantStub:
weight_node = weight_node.prev_nodes[0]
if weight_node.type() is not ConstantNode or not weight_node.module.is_parameter:
log.warning('Rewrite for F.prelu(x, buffer) to nn.PReLU is skipped as it changes the semantics')
continue
num_parameters = weight_t.nelement()
new_func = klass(num_parameters)
new_func.weight.data.copy_(node.prev_tensors[1])
# Drop last input
node.prev_tensors.pop()
last_node = node.prev_nodes.pop()
last_node.next_nodes.remove(node)
# Remove unused forward functions iteratively
new_last_node = last_node
while new_last_node is not None:
last_node = new_last_node
if (
len(last_node.next_nodes) == 0
and len(last_node.prev_nodes) < 2
and last_node in graph.forward_nodes
):
if len(last_node.prev_nodes) == 1:
new_last_node = last_node.prev_nodes[0]
else:
new_last_node = None
graph.remove_node(last_node)
else:
new_last_node = None
elif node.module.kind == 'glu':
if hasattr(node.module, 'args_string_no_self'):
def _parse_args_dim(dim=-1, *args, **kwargs): # noqa: F811
return dim
dim = eval(f'_parse_args_dim({node.module.args_string_no_self})')
new_func = klass(dim)
else:
dim = None
new_func = klass()
else:
raise NotImplementedError(f"Don't know how to parse {klass.__name__} with argument {arguments}")
graph.module_unique_name_dict[id(new_func)] = f'rewritten_{kind}_{idx}'
graph.module_original_name_dict[id(new_func)] = f'rewritten_{kind}_{idx}'
relu_cls = type(new_func)
if inplace:
arg_str = 'inplace=True'
else:
arg_str = ''
if node.module.kind in ('elu', 'leaky_relu') and alpha is not None:
if arg_str:
arg_str = f'{alpha}, {arg_str}'
else:
arg_str = f'{alpha}'
if node.module.kind == 'prelu':
if num_parameters != 1:
arg_str = f'{num_parameters}'
else:
arg_str = ''
if node.module.kind == 'glu':
if dim is not None and dim != -1:
arg_str = f'{dim}'
else:
arg_str = ''
module_constructor_lines[id(new_func)] = f'{qualified_name(relu_cls)}({arg_str})'
graph.replace_node_module(node, new_func)
# Rewrite dropout as nn.Dropout() for models in training mode
def _is_dropout_functional_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
if cur_class == TraceFunction:
return cur_module.kind == 'dropout' and self.layerwise_config.get(node.unique_name, True)
def _dropout_args(p=0.5, training=True, inplace=False):
return p, training, inplace
dropout_nodes_to_rewrite = graph.filter_forward_nodes(_is_dropout_functional_node)
log.info(f'rewriting dropout for {[node.unique_name for node in dropout_nodes_to_rewrite]}')
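# e.g. F.dropout(x, 0.2)  =>  nn.Dropout(0.2); the training flag is dropped because the
# module follows the model's own training mode.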
for idx, node in enumerate(dropout_nodes_to_rewrite):
args = getattr(node.module, 'args_string_no_self', '')
p, _, inplace = eval(f'_dropout_args({args})')
kind = node.module.kind
inplace = node.module.func_type == f'{kind}_' or inplace
dropout_cls = nn.Dropout
new_dropout = dropout_cls(p, inplace=inplace)
graph.module_unique_name_dict[id(new_dropout)] = f'rewritten_{kind}_{idx}'
graph.module_original_name_dict[id(new_dropout)] = f'rewritten_{kind}_{idx}'
if inplace:
arg_str = f'{p}, inplace={inplace}'
else:
arg_str = f'{p}'
module_constructor_lines[id(new_dropout)] = f'{qualified_name(dropout_cls)}({arg_str})'
graph.replace_node_module(node, new_dropout)
# Add contiguous nodes for partially-supported OPs
# Some of the operations support quantization, but they only accept contiguous input tensors.
def _is_partially_quantizable(node, custom_data):
if self.layerwise_config.get(node.unique_name, True) is False:
return False
cur_module = node.module
cur_class = type(cur_module)
if cur_class == ConstantNode:
return False
elif cur_class == TraceFunction:
return cur_module.kind in ('pad',)
else:
return cur_class in (nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d, nn.ZeroPad2d)
if LooseVersion(torch.__version__) >= LooseVersion('1.7.0'):
partially_supported_nodes = graph.filter_forward_nodes(_is_partially_quantizable)
for idx, node in enumerate(partially_supported_nodes):
for n in node.prev_nodes[:1]:
shared_tensors = list(set(node.prev_tensors).intersection(set(n.next_tensors)))
if len(shared_tensors) > 1:
log.error('rewrite for partially-supported ops only supports nodes with exactly one input')
assert False
with override_current_trace_graph(graph):
trace_func = TraceFunction('torch.Tensor.contiguous', True, prefix='fuse_').parse_args(
shared_tensors[0]
)
next_tensors = [x.contiguous() for x in shared_tensors]
graph.insert_between(n, node, trace_func, next_tensors, True)
# Remove non-leaf `.data` nodes
def _is_non_leaf_data_nodes(node, custom_data):
if self.layerwise_config.get(node.unique_name, True) is False:
return False
cur_module = node.module
cur_class = type(cur_module)
if cur_class == TraceFunction:
return cur_module.kind == 'data' and cur_module.is_property and len(node.next_nodes) > 0
return False
non_leaf_data_nodes = graph.filter_forward_nodes(_is_non_leaf_data_nodes)
for idx, node in enumerate(non_leaf_data_nodes):
graph.remove_node(node)
# Handle PoolNd with kernel_size=1
def _is_pool_nd_with_one_kernel_size(node, custom_data):
if self.layerwise_config.get(node.unique_name, True) is False:
return False
cur_module = node.module
cur_class = type(cur_module)
kernel_size, stride = None, None
if cur_class == TraceFunction:
if cur_module.kind in ('avg_pool1d', 'avg_pool2d', 'max_pool1d', 'max_pool2d'):
def _avgpool_kernel_size_and_stride(kernel_size, stride=None, *args, **kwargs):
return kernel_size, stride
kernel_size, stride = eval(f'_avgpool_kernel_size_and_stride({cur_module.args_string_no_self})')
else:
if cur_class in (nn.AvgPool1d, nn.AvgPool2d, nn.MaxPool1d, nn.MaxPool2d):
kernel_size = cur_module.kernel_size
stride = cur_module.stride
if kernel_size is not None:
if isinstance(kernel_size, (tuple, list)):
is_match = all((ks == 1 for ks in kernel_size))
else:
is_match = kernel_size == 1
else:
is_match = False
if is_match:
custom_data.append((node, kernel_size, stride))
return True
else:
return False
pool_one_kernel_size_nodes = []
graph.filter_forward_nodes(_is_pool_nd_with_one_kernel_size, pool_one_kernel_size_nodes)
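# e.g. nn.MaxPool2d(kernel_size=1, stride=2) on an NCHW tensor  =>  t[:, :, ::2, ::2]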
for idx, (node, kernel_size, stride) in enumerate(pool_one_kernel_size_nodes):
slices = [slice(None)] * 2
t = node.prev_tensors[0]
dim = len(t.shape)
if not isinstance(stride, (list, tuple)):
stride = [stride] * (dim - 2)
for s in stride:
if s == 1 or s is None:
slices.append(slice(None))
else:
slices.append(slice(None, None, s))
with override_current_trace_graph(graph):
new_func = TraceFunction('torch.Tensor.__getitem__', True).parse_args(t, slices)
graph.module_unique_name_dict[id(new_func)] = f'rewritten_pool_{idx}'
graph.module_original_name_dict[id(new_func)] = f'rewritten_pool_{idx}'
graph.replace_node_module(node, new_func)
# Rewrite Linear-BatchNorm1d structure to Conv2d-BatchNorm2d
is_rewrite_to_fuse = functools.partial(
self.is_fusable,
current_rules=load_processed_rewrite_to_fuse_rules(),
check_node_quantized=False,
graph=graph,
layerwise_config_default=True,
use_original_name=False,
)
custom_data = ([], set())
graph.filter_forward_nodes(is_rewrite_to_fuse, custom_data, reverse=True)
rewrite_fuse_names_list = custom_data[0]
log.debug(f'found names_list that need to rewrite for fusing: {rewrite_fuse_names_list}')
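# e.g. x -> nn.Linear -> nn.BatchNorm1d  =>  x[..., None, None] -> nn.Conv2d(1x1) ->
# nn.BatchNorm2d -> torch.flatten(x, 1)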
for idx, names in enumerate(reversed(rewrite_fuse_names_list)):
# case fc-bn1d
assert len(names) == 2, 'the rewrite nodes list length != 2'
node_fc = graph.nodes_map[names[0]]
node_bn1d = graph.nodes_map[names[1]]
mod_fc = node_fc.module
mod_bn = node_bn1d.module
assert type(mod_fc) is nn.Linear and type(mod_bn) is nn.BatchNorm1d, "the rewritten structure isn't [fc-bn1d]"
if len(node_fc.prev_tensors[0].shape) != 2:
log.debug('the [fc-bn]\'s input dimension != 2')
continue
# for fc-bn1d, rewrite [fc-bn1d] to [conv2d-bn2d]
new_conv2d = torch.nn.Conv2d(
in_channels=mod_fc.in_features,
out_channels=mod_fc.out_features,
kernel_size=[1, 1],
bias=mod_fc.bias is not None,
)
fc_weight = mod_fc.weight
new_conv2d.weight = nn.Parameter(torch.reshape(fc_weight, [fc_weight.shape[0], fc_weight.shape[1], 1, 1]))
if mod_fc.bias is not None:
new_conv2d.bias = mod_fc.bias
graph.module_unique_name_dict[id(new_conv2d)] = f'rewritten_conv2d_bn2d_conv2d_{idx}'
graph.module_original_name_dict[id(new_conv2d)] = f'rewritten_conv2d_bn2d_conv2d_{idx}'
new_bn2d = torch.nn.BatchNorm2d(
mod_bn.num_features,
mod_bn.eps,
mod_bn.momentum,
affine=mod_bn.affine,
track_running_stats=mod_bn.track_running_stats,
)
new_bn2d.load_state_dict(mod_bn.state_dict())
graph.module_unique_name_dict[id(new_bn2d)] = f'rewritten_conv2d_bn2d_bn2d_{idx}'
graph.module_original_name_dict[id(new_bn2d)] = f'rewritten_conv2d_bn2d_bn2d_{idx}'
# replace new node, then insert reshape before new_conv2d and after new_bn2d
with override_current_trace_graph(graph):
graph.replace_node_module(node_fc, new_conv2d)
graph.replace_node_module(node_bn1d, new_bn2d)
prev_func = TraceFunction('torch.Tensor.__getitem__', prefix='rewritten_conv2d_bn2d_').parse_args(
node_fc.prev_tensors[0], (Ellipsis, None, None)
)
next_func = TraceFunction('torch.flatten', prefix='rewritten_conv2d_bn2d_').parse_args(
node_bn1d.next_tensors[0], 1
)
# Expand the tensor shapes flowing through new_conv2d and new_bn2d
node_fc.next_tensors[0].unsqueeze_(2).unsqueeze_(2)
node_bn1d.prev_tensors[0].unsqueeze_(2).unsqueeze_(2)
node_bn1d.next_tensors[0].unsqueeze_(2).unsqueeze_(2)
prev_out = node_fc.prev_tensors[0][..., None, None]
graph.insert_between(node_fc.prev_nodes[0], node_fc, prev_func, [prev_out], True)
next_out = torch.flatten(node_bn1d.next_tensors[0], 1)
graph.insert_after(node_bn1d, next_func, [next_out])
# Rewrite BatchNorm1d to BatchNorm2d
def _is_batch_norm_1d(node, custom_data):
if self.layerwise_config.get(node.unique_name, True) is False:
return False
cur_module = node.module
cur_class = type(cur_module)
if len(node.prev_nodes) != 1:
return False
if node.prev_nodes[0].kind() in ('conv1d', nn.Conv1d):
return False
if cur_class == TraceFunction:
return cur_module.kind == 'batch_norm' and node.prev_tensors[0].ndim == 3
else:
return cur_class == nn.BatchNorm1d
batch_norm_1d_nodes = graph.filter_forward_nodes(_is_batch_norm_1d)
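# e.g. nn.BatchNorm1d on an (N, C, L) input  =>  unsqueeze(2) -> nn.BatchNorm2d -> squeeze(2)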
for idx, node in enumerate(batch_norm_1d_nodes):
mod = node.module
if type(mod) is nn.BatchNorm1d:
new_bn = torch.nn.BatchNorm2d(
mod.num_features,
mod.eps,
mod.momentum,
affine=mod.affine,
track_running_stats=mod.track_running_stats,
)
new_bn.load_state_dict(mod.state_dict())
graph.module_unique_name_dict[id(new_bn)] = f'rewritten_bn2d_{idx}'
graph.module_original_name_dict[id(new_bn)] = f'rewritten_bn2d_{idx}'
with override_current_trace_graph(graph):
graph.replace_node_module(node, new_bn)
prev_func = TraceFunction('torch.unsqueeze', prefix='rewritten_bn2d_').parse_args(
node.prev_tensors[0], 2
)
next_func = TraceFunction('torch.squeeze', prefix='rewritten_bn2d_').parse_args(
node.next_tensors[0], 2
)
node.next_tensors[0].unsqueeze_(2)
prev_out = torch.unsqueeze(node.prev_tensors[0], 2)
graph.insert_between(node.prev_nodes[0], node, prev_func, [prev_out], True)
next_out = torch.squeeze(node.next_tensors[0], 2)
graph.insert_after(node, next_func, [next_out])
type_dict = {}
for n in graph.forward_nodes:
if n.type() not in (torch_q.QuantStub, torch_q.DeQuantStub):
self.layerwise_config.setdefault(n.unique_name, True)
kind = n.kind()
type_str = kind if isinstance(kind, str) else kind.__name__
type_dict[n.unique_name] = type_str
fuse_mapping = {}
if self.fused_layerwise_config:
# Find all fusable nodes
if type(self).__name__ == 'QATQuantizer':
processed_all_rules = load_processed_all_qat_rules()
else:
processed_all_rules = load_processed_all_ptq_rules()
is_fusable = functools.partial(
self.is_fusable,
current_rules=processed_all_rules,
check_node_quantized=False,
use_original_name=False,
)
custom_data = ([], set())
graph.filter_forward_nodes(is_fusable, custom_data, reverse=True)
activ_names = custom_data[0]
for idx, names in enumerate(reversed(activ_names)):
types = [graph.nodes_map[n].kind() for n in names]
types = [x.__name__ if not isinstance(x, str) else x for x in types]
types_str = ', '.join(types)
types_str = f'({types_str})'
type_dict[names[0]] = types_str
for name in names[1:]:
fuse_mapping[name] = names[0]
self.layerwise_config[name] = self.layerwise_config.get(names[0], True)
for n, t in type_dict.items():
self.layerwise_config.yaml_add_eol_comment(f'type: {t}', n)
skip_types = set(k[0] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) == 1)
for module_cls, action in self.quantize_op_action.items():
if action == 'rewrite':
skip_types.add(module_cls)
if self.set_quantizable_op_stats:
skip_types |= set(KNOWN_QSTATS.keys())
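# skip_types_prev/next additionally cover the tail/head kinds of multi-op rewrite rules;
# they guard the consecutive dequant-quant stub removal below.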
skip_types_prev = skip_types | set(k[-1] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) > 1)
skip_types_next = skip_types | set(k[0] for k in REWRITE_QUANTIZABLE_RULE_LIST if len(k) > 1)
# Add quant/dequant nodes for non-quantizable OPs
disable_quantize_op_list = UNSUPPORTED_PYTORCH_QUANTIZATION_OP_LIST.copy()
for module_cls, action in self.quantize_op_action.items():
if action in ('disable', 'rewrite'):
disable_quantize_op_list[module_cls] = None
def _is_rewritable_lstm_node(node, custom_data):
cur_module = node.module
cur_class = type(cur_module)
return cur_class == nn.LSTM
if self.quantize_op_action.get(nn.LSTM, 'enable') == 'rewrite':
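# A DeQuantStub is attached to the LSTM cell state, followed by a size() node whose result
# drives an expand() of the LSTM output along its last dimension.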
rewritable_lstm_nodes = graph.filter_forward_nodes(_is_rewritable_lstm_node)
fake_dequant_cls = torch_q.DeQuantStub
for idx, node in enumerate(rewritable_lstm_nodes):
assert node.module.num_layers == 1, "Quantization rewrite for multi-layer LSTM is not yet supported"
assert not node.module.bidirectional, "Quantization rewrite for bidirectional LSTM is not yet supported"
cell_state = node.next_tensors[1][1]
fake_dequant = fake_dequant_cls()
fake_dequant_name = f'fake_dequant_rewrite_{idx}'
graph.module_unique_name_dict[id(fake_dequant)] = fake_dequant_name
graph.module_original_name_dict[id(fake_dequant)] = fake_dequant_name
module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'
new_node = graph.insert_new_after(
node, fake_dequant, [cell_state], [[1, 1]], before_node=node.next_nodes[0]
)
with override_current_trace_graph(graph):
size_func = TraceFunction(
'torch.Tensor.size', is_class=True, prefix='fake_dequant_rewrite_'
).parse_args(new_node.next_tensors[0], -1)
size_node = graph.insert_new_after(
new_node,
size_func,
[new_node.next_tensors[0]],
[None],
next_tensors=[torch.tensor(new_node.next_tensors[0].size(-1))],
before_node=node.next_nodes[0],
)
size_len = len(node.next_tensors[0].shape)
if node.module.bidirectional:
with override_current_trace_graph(graph):
size_func = TraceFunction(
'torch.Tensor.__mul__', is_class=True, prefix='fake_dequant_rewrite_'
).parse_args(size_node.next_tensors[0], 2)
size_node = graph.insert_new_after(
size_node,
size_func,
[size_node.next_tensors[0]],
[None],
next_tensors=[size_node.next_tensors[0] * 2],
before_node=node.next_nodes[0],
)
with override_current_trace_graph(graph):
expand_func = TraceFunction(
'torch.Tensor.expand', is_class=True, prefix='fake_dequant_rewrite_'
).parse_args(node.next_tensors[0], *((-1,) * (size_len - 1)), size_node.next_tensors[0])
graph.insert_between(
node, node.next_nodes[0], expand_func, tensor_ptrs=[id(node.next_tensors[0])], move_idx=True
)
expand_node = graph.nodes_map[expand_func.unique_name]
size_node.next_nodes.append(expand_node)
expand_node.prev_nodes.append(node)
expand_node.prev_tensors.append(size_node.next_tensors[0])
expand_node.prev_indices.append(None)
def _is_not_quantizable(node, custom_data):
cur_module = node.module
cur_class = type(cur_module)
if cur_class == ConstantNode:
return False
elif cur_class == TraceFunction:
if node.type() in ('__truediv__', '__itruediv__', 'div', 'div_'):
if node.prev_nodes[0].kind() in ('shape', 'size'):
return False
if node.type() in ('shape', 'device', 'size', 'dtype'):
if node.unique_name in self.layerwise_config:
del self.layerwise_config[node.unique_name]
return False
if self.layerwise_config.get(node.unique_name, True) is False:
return True
supported_version = disable_quantize_op_list.get(cur_module.kind, torch.__version__)
return supported_version is None or LooseVersion(torch.__version__) < supported_version
else:
if isinstance(cur_module, (torch_q.QuantStub, torch_q.DeQuantStub)):
return False
if self.layerwise_config.get(node.unique_name, True) is False:
return True
unsupported_types = tuple(
k
for k, v in disable_quantize_op_list.items()
if type(k) is not str
and k not in Q_MODULES_MAPPING
and (v is None or LooseVersion(torch.__version__) < v)
)
return isinstance(cur_module, unsupported_types)
unsupported_nodes = graph.filter_forward_nodes(_is_not_quantizable)
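# Insert QuantStub nodes after each output of the unsupported ops so that downstream
# quantizable ops still receive quantized tensors; node_map deduplicates stubs per tensor.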
for idx, node in enumerate(reversed(unsupported_nodes)):
if node.unique_name in self.layerwise_config:
if node.kind() not in skip_types and self.layerwise_config[node.unique_name]:
del self.layerwise_config[node.unique_name]
node_map = dict()
next_nodes = {n.unique_name: n for n in node.next_nodes}.values()
for inner_idx, next_node in enumerate(next_nodes):
prev_tensor_ptrs = []
if type(next_node.module) is TraceFunction and next_node.module.is_property:
continue
for pt in next_node.prev_tensors:
for nt in node.next_tensors:
if isinstance(nt, (list, tuple)):
for k, ntt in enumerate(nt):
if id(pt) == id(ntt):
if id(pt) not in prev_tensor_ptrs:
prev_tensor_ptrs.append(id(pt))
break
elif id(pt) == id(nt):
if id(pt) not in prev_tensor_ptrs:
prev_tensor_ptrs.append(id(pt))
break
for ptr_idx, ptr in enumerate(prev_tensor_ptrs):
if ptr in node_map:
fake_quant = node_map[ptr]
graph.insert_between(node, next_node, fake_quant, move_idx=True, tensor_ptrs=[ptr])
else:
fake_quant = torch_q.QuantStub()
fake_quant_name = f'fake_quant_inner_{idx}_{inner_idx}_{ptr_idx}'
graph.module_unique_name_dict[id(fake_quant)] = fake_quant_name
graph.module_original_name_dict[id(fake_quant)] = fake_quant_name
fake_quant_cls = type(fake_quant)
module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()'
graph.insert_between(node, next_node, fake_quant, move_idx=True, tensor_ptrs=[ptr])
node_map[ptr] = graph.nodes_map[fake_quant_name]
# Insert the DeQuantStub nodes before every input node of the unsupported ops
for idx, node in enumerate(unsupported_nodes):
fake_dequant_cls = torch_q.DeQuantStub
assert node.rev_index is False
node_map = dict()
prev_nodes = {n.unique_name: n for n in node.prev_nodes}.values()
for inner_idx, prev_node in enumerate(prev_nodes):
if prev_node.kind() in ('shape', 'device', 'size', 'dtype'):
continue
prev_tensor_ptrs = []
for pt in node.prev_tensors:
for nt in prev_node.next_tensors:
if isinstance(nt, (list, tuple)):
for k, ntt in enumerate(nt):
if id(pt) == id(ntt):
if id(pt) not in prev_tensor_ptrs:
prev_tensor_ptrs.append(id(pt))
break
elif id(pt) == id(nt):
if id(pt) not in prev_tensor_ptrs:
prev_tensor_ptrs.append(id(pt))
break
for ptr_idx, ptr in enumerate(prev_tensor_ptrs):
if ptr in node_map:
fake_dequant = node_map[ptr]
graph.insert_between(prev_node, node, fake_dequant, move_idx=True, tensor_ptrs=set([ptr]))
else:
fake_dequant = fake_dequant_cls()
fake_dequant_name = f'fake_dequant_inner_{idx}_{inner_idx}_{ptr_idx}'
graph.module_unique_name_dict[id(fake_dequant)] = fake_dequant_name
graph.module_original_name_dict[id(fake_dequant)] = fake_dequant_name
module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'
graph.insert_between(prev_node, node, fake_dequant, move_idx=True, tensor_ptrs=set([ptr]))
node_map[ptr] = graph.nodes_map[fake_dequant_name]
# Remove consecutive dequant quant nodes
def _is_consecutive_dequant_quant_nodes(node, custom_data):
cur_type = node.type()
if cur_type in (torch_q.QuantStub, torch_q.DeQuantStub):
for next_node in node.next_nodes:
next_type = next_node.type()
if next_type in (torch_q.QuantStub, torch_q.DeQuantStub):
if cur_type != next_type:
if cur_type == torch_q.QuantStub:
target_node = None
if len(node.prev_nodes) == 1 and node.prev_nodes[0].kind() in skip_types_prev:
target_node = node.prev_nodes[0]
elif (
len(next_node.next_nodes) == 1 and next_node.next_nodes[0].kind() in skip_types_next
):
target_node = next_node.next_nodes[0]
if target_node is not None:
self.layerwise_config.setdefault(target_node.unique_name, True)
if self.layerwise_config[target_node.unique_name]:
return False
custom_data.append((node, next_node))
return True
return False
consecutive_dequant_quant_nodes = []
consecutive_branch_dequant_quant_nodes = []
graph.filter_forward_nodes(_is_consecutive_dequant_quant_nodes, consecutive_dequant_quant_nodes)
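# e.g. ... -> DeQuantStub -> QuantStub -> ...  collapses into a direct edge when both stubs
# are single-use; branched cases are handled below.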
for node, next_node in consecutive_dequant_quant_nodes:
if len(node.next_nodes) == 1 and len(next_node.prev_nodes) == 1:
graph.remove_node(next_node)
graph.remove_node(node)
else:
consecutive_branch_dequant_quant_nodes.append((node, next_node))
# TODO: Support complex cases
for node, next_node in consecutive_branch_dequant_quant_nodes:
is_removable = all(
(
n.type() in (torch_q.QuantStub, torch_q.DeQuantStub, 'shape', 'device', 'size', 'dtype')
and n.type() != node.type()
for n in node.next_nodes
)
)
if is_removable:
graph.remove_node(node)
for n in node.next_nodes:
if n.type() in (torch_q.QuantStub, torch_q.DeQuantStub):
graph.remove_node(n)
# Process additional fusable nodes
processed_extra_q_rules = load_processed_extra_q_rules()
is_extra_fusable = functools.partial(
self.is_fusable,
current_rules=processed_extra_q_rules,
check_node_quantized=False,
use_original_name=False,
)
custom_data = ([], set())
graph.filter_forward_nodes(is_extra_fusable, custom_data, reverse=True)
activ_names = custom_data[0]
log.debug(f'found extra fusable nodes: {activ_names}')
for idx, names in enumerate(reversed(activ_names)):
name = names[-1]
node = graph.nodes_map[name]
fake_quant = torch_q.QuantStub()
graph.module_unique_name_dict[id(fake_quant)] = f'fake_activ_quant_{idx}'
graph.module_original_name_dict[id(fake_quant)] = f'fake_activ_quant_{idx}'
fake_quant_cls = type(fake_quant)
module_constructor_lines[id(fake_quant)] = f'{qualified_name(fake_quant_cls)}()'
graph.insert_after(node, fake_quant)
fake_dequant = torch_q.DeQuantStub()
graph.module_unique_name_dict[id(fake_dequant)] = f'fake_activ_dequant_{idx}'
graph.module_original_name_dict[id(fake_dequant)] = f'fake_activ_dequant_{idx}'
fake_dequant_cls = type(fake_dequant)
module_constructor_lines[id(fake_dequant)] = f'{qualified_name(fake_dequant_cls)}()'
graph.insert_after(node, fake_dequant)
fused_clamps = []
for names in activ_names:
for name in names:
node = graph.nodes_map[name]
if node.kind() == 'clamp_with_fusion':
fused_clamps.append(name)
# Rewrite unfused clamp_with_fusion to torch.clamp
def _is_unfused_clamp_node(node: TraceNode, custom_data):
cur_module = node.module
cur_class = type(cur_module)
if cur_class == TraceFunction:
return (
cur_module.kind == 'clamp_with_fusion'
and torch.is_floating_point(node.next_tensors[0])
and node.unique_name not in fused_clamps
)
return False
unfused_clamp_nodes = graph.filter_forward_nodes(_is_unfused_clamp_node)
log.info(f'rewriting unfused_clamp for {[node.unique_name for node in unfused_clamp_nodes]}')
for node in unfused_clamp_nodes:
node.module.kind = 'clamp'
node.module.func_type = node.module.func_type.replace('clamp_with_fusion', 'clamp')
node.module.full_name = f'torch.{node.module.func_type}'
# Optional tensor shape broadcasting for quantized binary ops
def _is_broadcastable_binary_quantized_op_node(node: TraceNode, custom_data) -> bool:
cur_module = node.module
cur_class = type(cur_module)
if cur_class != TraceFunction:
return False
return (
cur_module.full_name.startswith('self.float_functional_simple_')
and cur_module.func_type in ('add', 'mul', 'add_relu')
and node.prev_tensors[0].shape != node.prev_tensors[1].shape
)
broadcastable_binary_quantized_op_nodes = graph.filter_forward_nodes(_is_broadcastable_binary_quantized_op_node)
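# e.g. (N, C, H, W) + (C, 1, 1): the smaller operand is expanded via expand_as to match the
# larger one; when neither side dominates, torch.broadcast_tensors is inserted for both.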
for node in broadcastable_binary_quantized_op_nodes:
assert len(node.prev_nodes) == 2
assert len(node.prev_tensors) == 2
l_shape = list(node.prev_tensors[0].shape)
r_shape = list(node.prev_tensors[1].shape)
ref_index = None
if len(l_shape) > len(r_shape):
ref_index = 0
r_shape = [1] * (len(l_shape) - len(r_shape)) + r_shape
elif len(l_shape) < len(r_shape):
ref_index = 1
l_shape = [1] * (len(r_shape) - len(l_shape)) + l_shape
for l_dim, r_dim in zip(l_shape, r_shape):
if l_dim > r_dim:
if ref_index in (None, 0) and r_dim == 1:
ref_index = 0
else:
ref_index = -1
break
elif l_dim < r_dim:
if ref_index in (None, 1) and l_dim == 1:
ref_index = 1
else:
ref_index = -1
break
if ref_index >= 0:
src_index = 1 - ref_index
with override_current_trace_graph(graph):
trace_func = TraceFunction('torch.Tensor.expand_as', True, prefix='fuse_').parse_args(
node.prev_tensors[src_index], node.prev_tensors[ref_index]
)
next_tensors = [node.prev_tensors[src_index].expand_as(node.prev_tensors[ref_index])]
graph.insert_between(node.prev_nodes[src_index], node, trace_func, next_tensors, True)
new_node = graph.nodes_map[trace_func.unique_name]
new_node.prev_nodes.append(node.prev_nodes[ref_index])
new_node.prev_tensors.append(node.prev_tensors[ref_index])
new_node.prev_indices.append(node.prev_indices[ref_index])
node.prev_nodes[ref_index].next_nodes.append(new_node)
else:
new_indices = []
for i in range(2):
if node.prev_indices[i] is None:
new_indices.append(i)
elif isinstance(node.prev_indices[i], (tuple, list)):
new_indices.append(list(node.prev_indices[i]) + [i])
else:
new_indices.append([node.prev_indices[i], i])
with override_current_trace_graph(graph):
trace_func = TraceFunction('torch.broadcast_tensors', False, prefix='fuse_').parse_args(
node.prev_tensors[0], node.prev_tensors[1]
)
next_tensors = torch.broadcast_tensors(node.prev_tensors[0], node.prev_tensors[1])
graph.insert_before(node, trace_func, next_tensors, False, new_indices)
for name in fuse_mapping:
if name in self.layerwise_config:
del self.layerwise_config[name]
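# Collect the layers that count as effective for layerwise quantization: stateful modules
# (excluding pass-through ones such as Dropout, pooling and ReLU) plus add/mul/cat and
# functional-rewritten kinds.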
for name in self.layerwise_config:
node = graph.nodes_map[name]
if isinstance(node.module, nn.Module) and not isinstance(
node.module,
(nn.Dropout, nn.AdaptiveAvgPool2d, nn.AvgPool2d, nn.MaxPool2d, nn.ReLU, nn.Upsample, nn.ConstantPad2d),
):
self.effective_layers.append(name)
elif node.kind() in ('add', 'mul', 'cat') or node.kind() in FUNCTIONAL_MODULE_MAPPING:
self.effective_layers.append(name)
graph.quantized = True
graph.recompute_forward_order()