in python/flexflow/torch/fx.py [0:0]
def _function_node_to_str(op_str, node):
    """Finish the serialized line for a ``FunctionNode``.

    The operator is selected by substring search on ``str(node.function)``,
    so the order of the checks below is significant (for example 'matmul'
    must be tested before 'mul', which it contains).

    Args:
        op_str: The partially built line (already holds the node name).
        node: The function node; ``function``, ``inedges`` and ``outedges``
            are read from it.

    Returns:
        Whatever the operator-specific ``parse_*`` helper returns after
        ``parse_inoutedge`` has processed ``op_str``.

    Raises:
        AssertionError: If the function is not recognized, or if it is a
            tensor-tensor subtraction, true division or floor division.
    """
    # NOTE(review): if str(node.function) embeds an object address, hex
    # digits could spuriously match short tokens such as 'add' -- confirm
    # what node.function holds.
    function_name = str(node.function)

    if 'add' in function_name:
        if type(node.inedges[1]) is float:
            # Scalar operand: only the first input is a graph edge.
            op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
            return parse_scalaradd(op_str, node)
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_add(op_str, node)

    if 'sub' in function_name:
        if type(node.inedges[1]) is float:
            op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
            return parse_scalarsub(op_str, node)
        # Raised explicitly (not via `assert`) so that `python -O` cannot
        # strip the check and let an unsupported op through.
        raise AssertionError("Unknown binary subtraction operator")

    if 'truediv' in function_name:
        if type(node.inedges[1]) is float:
            op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
            return parse_scalartruediv(op_str, node)
        raise AssertionError("Unknown binary true division operator")

    if 'cat' in function_name:
        # The first argument is itself the collection of inputs.
        op_str = parse_inoutedge(op_str, node.inedges[0], node.outedges)
        return parse_concat(op_str, node)

    if 'split' in function_name:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        return parse_split(op_str, node)

    if 'flatten' in function_name:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        return parse_flat(op_str, node)

    if 'relu' in function_name:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_relu(op_str, node)

    if 'getitem' in function_name:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        return parse_getitem(op_str, node)

    # 'matmul' must be checked before 'mul' because it contains it.
    if 'matmul' in function_name:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_batchmatmul(op_str, node)

    if 'mul' in function_name:
        if type(node.inedges[1]) is float:
            op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
            return parse_scalarmul(op_str, node)
        # NOTE(review): this passes node.inedges[0] rather than node.inedges,
        # unlike the tensor 'add' branch. Kept as-is; confirm it is intended.
        op_str = parse_inoutedge(op_str, node.inedges[0], node.outedges)
        return parse_mul(op_str, node)

    if 'getattr' in function_name:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        return parse_getattr(op_str, node)

    if 'transpose' in function_name:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        return parse_transpose(op_str, node)

    if 'expand' in function_name:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        return parse_expand(op_str, node)

    if 'floordiv' in function_name or 'floor_divide' in function_name:
        if type(node.inedges[1]) in (float, int):
            op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
            return parse_scalarfloordiv(op_str, node)
        raise AssertionError("Tensor floor division is not supported.")

    if 'reshape' in function_name:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        return parse_reshape(op_str, node)

    if 'permute' in function_name:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        return parse_permute(op_str, node)

    if 'softmax' in function_name:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_softmax(op_str, node)

    # Unrecognized function.
    raise AssertionError(
        "Unrecogonized built-in function: {}".format(function_name))


def _module_node_to_str(op_str, node):
    """Finish the serialized line for a ``ModuleNode``.

    The parser is chosen by the exact type of ``node.module``; subclasses of
    the listed torch modules are not matched.

    Args:
        op_str: The partially built line (already holds the node name).
        node: The module node; ``module``, ``inedges`` and ``outedges`` are
            read from it.

    Returns:
        Whatever the module-specific ``parse_*`` helper returns after
        ``parse_inoutedge`` has processed ``op_str``.

    Raises:
        AssertionError: If the node does not have exactly one input edge
            ("wrong format") or the module type is not handled
            ("unknown op").
    """
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if len(node.inedges) != 1:
        raise AssertionError("wrong format")

    module_type = type(node.module)

    if module_type == torch.nn.modules.linear.Linear:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_linear(op_str, node)
    if module_type == torch.nn.modules.conv.Conv2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_conv2d(op_str, node)
    if module_type == torch.nn.modules.pooling.MaxPool2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_pool2d(op_str, node, PoolType.POOL_MAX)
    if module_type == torch.nn.modules.pooling.AvgPool2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_pool2d(op_str, node, PoolType.POOL_AVG)
    if module_type == torch.nn.modules.pooling.AdaptiveAvgPool2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_adaptivepool2d(op_str, node, PoolType.POOL_AVG)
    if module_type == torch.nn.modules.batchnorm.BatchNorm2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_batchnorm2d(op_str, node)
    if module_type == torch.nn.modules.dropout.Dropout:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_dropout(op_str, node)
    if module_type == torch.nn.modules.flatten.Flatten:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_flat(op_str, node)
    if module_type == torch.nn.modules.activation.ReLU:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_relu(op_str, node)
    if module_type == torch.nn.modules.activation.Sigmoid:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_sigmoid(op_str, node)
    if module_type == torch.nn.modules.activation.Tanh:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_tanh(op_str, node)
    if module_type == torch.nn.modules.activation.ELU:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_elu(op_str, node)
    if module_type == torch.nn.modules.activation.Softmax:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_softmax(op_str, node)
    if module_type == torch.nn.modules.normalization.LayerNorm:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_layernorm(op_str, node)
    if module_type == torch.nn.Identity:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_identity(op_str, node)
    if module_type == torch.nn.GELU:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        return parse_gelu(op_str, node)

    # Show the offending module before failing, as a debugging aid.
    print(node.module)
    raise AssertionError("unknown op")


def torch_to_flexflow_str(model):
    """Serialize a torch model into a list of per-operator lines.

    The model is traced with ``__symbolic_trace``. One line is produced for
    every named parameter whose last name component is neither "weight" nor
    "bias", followed by one line for every node of the traced graph. Each
    line starts with the operator name and ", " and is then extended by
    ``parse_inoutedge`` and an operator-specific ``parse_*`` helper.

    Progress is printed to stdout (parameter names, node names and types,
    and every finished graph line).

    Args:
        model: Object accepted by ``__symbolic_trace`` that also provides
            ``named_parameters()``.

    Returns:
        list: The generated lines, parameters first, then graph nodes in
        iteration order.

    Raises:
        AssertionError: If the graph contains a function or module that is
            not supported, or a module node without exactly one input edge.
    """
    graph = __symbolic_trace(model)
    lines = []

    # Stand-alone parameters (anything that is not a layer weight/bias).
    for name, parameter in model.named_parameters():
        splitted_name = name.split(".")
        if splitted_name[-1] in ("weight", "bias"):
            continue
        fx_name = "_" + "_".join(splitted_name)
        print(fx_name)
        op_str = fx_name + ", "
        op_str = parse_inoutedge(op_str, (), ())
        op_str = parse_parameter(op_str, parameter)
        lines.append(op_str)

    for node in graph:
        # Operator name.
        op_str = node.name + ", "
        print(node.name, type(node))

        # Operator type; a node whose exact type is none of the four below
        # is emitted with only its name.
        node_type = type(node)
        if node_type == InputNode:
            op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
            op_str = parse_input(op_str, node)
        elif node_type == OutputNode:
            if type(node.inedges[0]) == tuple:
                # The output is a tuple: its members are the input edges.
                op_str = parse_inoutedge(op_str, node.inedges[0], node.outedges)
            else:
                op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
            op_str = parse_output(op_str, node)
        elif node_type == FunctionNode:
            op_str = _function_node_to_str(op_str, node)
        elif node_type == ModuleNode:
            op_str = _module_node_to_str(op_str, node)

        print(op_str)
        lines.append(op_str)

    return lines