def torch_to_flexflow_str(model)

in flexflow/torch/fx.py

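Converts a symbolically traced PyTorch model into a list of FlexFlow operator-description strings: one line per free-standing parameter tensor (parameters whose name does not end in weight or bias) and one line per traced graph node.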

def torch_to_flexflow_str(model):
  graph = __symbolic_trace(model)
  lines = []

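  # Emit parameters that are not module weights/biases (e.g., free-standing attribute tensors) as standalone tensor ops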
  for name, parameter in model.named_parameters():
    splitted_name = name.split(".")
    if splitted_name[-1] not in ["weight", "bias"]:
      fx_name = "_" + "_".join(splitted_name)
      print(fx_name)
      op_str = fx_name + ", "
      op_str = parse_inoutedge(op_str, (), ())
      op_str = parse_parameter(op_str, parameter)
      lines.append(op_str)
  
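  # Translate each traced graph node (input, output, built-in function, or nn.Module call) into one op line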
  for node in graph:
    # op name
    op_str = node.name + ", "
    print(node.name, type(node))
    
    # op type
    if type(node) == InputNode:
      op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
      op_str = parse_input(op_str, node)
      
    if type(node) == OutputNode:
      if type(node.inedges[0]) == tuple:
        op_str = parse_inoutedge(op_str, node.inedges[0], node.outedges)
      else:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
      op_str = parse_output(op_str, node)
    
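    # Built-in function calls (torch.add, torch.cat, operator.getitem, ...) are dispatched by substring match on the function name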
    if type(node) == FunctionNode:
      function_name = str(node.function)
      if function_name.find('add') >= 0:
        if type(node.inedges[1]) is float:
          op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
          op_str = parse_scalaradd(op_str, node)
        else:
          op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
          op_str = parse_add(op_str, node)
        
      elif function_name.find('sub') >= 0:
        if type(node.inedges[1]) is float:
          op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
          op_str = parse_scalarsub(op_str, node)
        else:
          assert 0, "Unknown binary subtraction operator"
      
      elif function_name.find('truediv') >= 0:
        if type(node.inedges[1]) is float:
          op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
          op_str = parse_scalartruediv(op_str, node)
        else:
          assert 0, "Unknown binary true division operator"
      
      elif function_name.find('cat') >= 0:
        op_str = parse_inoutedge(op_str, node.inedges[0], node.outedges)
        op_str = parse_concat(op_str, node)
        
      elif function_name.find('split') >= 0:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        op_str = parse_split(op_str, node)
      
      elif function_name.find('flatten') >= 0:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        op_str = parse_flat(op_str, node)
        
      elif function_name.find('relu') >= 0:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_relu(op_str, node)
        
      elif function_name.find('getitem') >= 0:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        op_str = parse_getitem(op_str, node)

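      # 'matmul' must be tested before 'mul', since "matmul" also contains the substring "mul"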
      elif function_name.find('matmul') >= 0:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_batchmatmul(op_str, node)

      elif function_name.find('mul') >= 0:
        if type(node.inedges[1]) is float:
          op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
          op_str = parse_scalarmul(op_str, node)
        else:
          op_str = parse_inoutedge(op_str, node.inedges[0], node.outedges)
          op_str = parse_mul(op_str, node)
      
      elif function_name.find('getattr') >= 0:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        op_str = parse_getattr(op_str, node)
     
      elif function_name.find('transpose') >= 0:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        op_str = parse_transpose(op_str, node)

      elif function_name.find('expand') >= 0:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        op_str = parse_expand(op_str, node)
        
      elif function_name.find('floordiv') >= 0 or function_name.find('floor_divide') >= 0:
        if type(node.inedges[1]) is float or type(node.inedges[1]) is int:
          op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
          op_str = parse_scalarfloordiv(op_str, node)
        else:
          assert 0, "Tensor floor division is not supported."

      elif function_name.find('reshape') >= 0:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        op_str = parse_reshape(op_str, node)

      elif function_name.find('permute') >= 0:
        op_str = parse_inoutedge(op_str, (node.inedges[0],), node.outedges)
        op_str = parse_permute(op_str, node)
     
      elif function_name.find('softmax') >= 0:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_softmax(op_str, node)

      else:
        # Unrecognized function type
        assert False, "Unrecognized built-in function: {}".format(function_name)
    
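    # nn.Module calls are dispatched by the module's class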
    if type(node) == ModuleNode:
      assert len(node.inedges) == 1, "module node should have exactly one input edge"
      
      if type(node.module) == torch.nn.modules.linear.Linear:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_linear(op_str, node)
        #parse_linear_onnx(node)
      
      elif type(node.module) == torch.nn.modules.conv.Conv2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_conv2d(op_str, node)
          
      elif type(node.module) == torch.nn.modules.pooling.MaxPool2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_pool2d(op_str, node, PoolType.POOL_MAX)
        
      elif type(node.module) == torch.nn.modules.pooling.AvgPool2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_pool2d(op_str, node, PoolType.POOL_AVG)
        
      elif type(node.module) == torch.nn.modules.pooling.AdaptiveAvgPool2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_adaptivepool2d(op_str, node, PoolType.POOL_AVG)
        
      elif type(node.module) == torch.nn.modules.batchnorm.BatchNorm2d:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_batchnorm2d(op_str, node)

      elif type(node.module) == torch.nn.modules.dropout.Dropout:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_dropout(op_str, node)
        
      elif type(node.module) == torch.nn.modules.flatten.Flatten:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_flat(op_str, node)
          
      elif type(node.module) == torch.nn.modules.activation.ReLU:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_relu(op_str, node)
        
      elif type(node.module) == torch.nn.modules.activation.Sigmoid:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_sigmoid(op_str, node)
        
      elif type(node.module) == torch.nn.modules.activation.Tanh:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_tanh(op_str, node)
        
      elif type(node.module) == torch.nn.modules.activation.ELU:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_elu(op_str, node)
        
      elif type(node.module) == torch.nn.modules.activation.Softmax:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_softmax(op_str, node)
      
      elif type(node.module) == torch.nn.modules.normalization.LayerNorm:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_layernorm(op_str, node)

      elif type(node.module) == torch.nn.Identity:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_identity(op_str, node)

      elif type(node.module) == torch.nn.GELU:
        op_str = parse_inoutedge(op_str, node.inedges, node.outedges)
        op_str = parse_gelu(op_str, node)

      else:
        print(node.module)
        assert 0, "Unrecognized module type: {}".format(type(node.module))
      
    print(op_str)
    lines.append(op_str)
    
  return lines
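
A minimal usage sketch: trace a small model and dump the generated op lines to disk. The toy model, the model.ff filename, and the newline handling are illustrative assumptions; only torch_to_flexflow_str itself comes from flexflow/torch/fx.py above.

import torch.nn as nn
from flexflow.torch.fx import torch_to_flexflow_str

# Toy model built only from modules handled above (Conv2d, ReLU, Flatten, Linear)
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(16 * 30 * 30, 10),  # sized for 32x32 inputs; illustrative only
)

lines = torch_to_flexflow_str(model)

# Persist one op description per line; whether each entry already ends in a
# newline depends on the parse_* helpers, so normalize defensively here
with open("model.ff", "w") as f:
    for line in lines:
        f.write(line if line.endswith("\n") else line + "\n")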