tinynn/graph/rewriter.py
import torch
import torch.nn as nn


def gen_layernorm(mod):
    """Generate an autograd Function that exports ``mod`` (an nn.LayerNorm)
    as a single fused ``trt::LayerNorm`` node instead of the decomposed ops."""

    class CustomLayerNorm(torch.autograd.Function):
        @staticmethod
        def symbolic(g, input):
            # Bake the affine parameters into the graph as Constant inputs.
            return g.op(
                "trt::LayerNorm",
                input,
                g.op("Constant", value_t=mod.weight.data),
                g.op("Constant", value_t=mod.bias.data),
                epsilon_f=mod.eps,
                axis_i=-len(mod.normalized_shape),
            )

        @staticmethod
        def forward(ctx, x):
            # Matches the original module's computation exactly.
            return torch.nn.functional.layer_norm(
                x, mod.normalized_shape, mod.weight.data, mod.bias.data, mod.eps
            )

    return CustomLayerNorm


# Maps module classes to generators producing their export-time replacements.
MOD_DICT = {nn.LayerNorm: gen_layernorm}
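# Additional layers could be supported by registering their own generators,
# e.g. (hypothetical, not implemented here): MOD_DICT[nn.GroupNorm] = gen_groupnorm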


def gen_rewrite_hook(gen):
    def rewrite_hook(mod, inp, outp):
        # Returning a value from a forward hook replaces the module's output,
        # so tracing records the custom Function instead of the original op.
        return gen(mod).apply(inp[0])

    return rewrite_hook


def rewrite_for_tensorrt_export(model):
    """Attach rewrite hooks to all supported submodules so that a subsequent
    torch.onnx.export emits TensorRT-friendly fused nodes."""
    for m in model.modules():
        gen = MOD_DICT.get(type(m))
        if gen is not None:
            m.register_forward_hook(gen_rewrite_hook(gen))
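

# --- Usage sketch (illustrative; not part of the original file) ---
# The toy model, output filename, and opset values below are assumptions.
# Hooks must be registered before export so that tracing sees the fused op,
# and the custom "trt" domain is declared via ``custom_opsets``.
if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16)).eval()
    rewrite_for_tensorrt_export(model)
    torch.onnx.export(
        model,
        torch.randn(1, 16),
        "model_trt.onnx",
        opset_version=13,
        custom_opsets={"trt": 1},  # custom domain consumed by the TRT parser
    )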