def convert_op_handle_to_op()

in flexflow/core/flexflow_cffi.py [0:0]


def convert_op_handle_to_op(op_type, handle, idx=None, name=None):
  """Wrap an operator handle in the Python wrapper class matching its type.

  Args:
    op_type: an ``OpType`` member identifying the kind of operator.
    handle: the operator handle to wrap; forwarded unchanged to the
      wrapper class constructor.
    idx: optional index, forwarded to the wrapper constructor.
    name: optional layer name, forwarded to the wrapper constructor.

  Returns:
    A new wrapper instance for ``op_type``, built as
    ``WrapperClass(handle, idx, name)``.

  Raises:
    AssertionError: if ``op_type`` is not a supported operator type.
  """
  if op_type == OpType.CONV2D:
    return Conv2D(handle, idx, name)
  elif op_type == OpType.POOL2D:
    return Pool2D(handle, idx, name)
  elif op_type == OpType.LINEAR:
    return Linear(handle, idx, name)
  elif op_type == OpType.EMBEDDING:
    return Embedding(handle, idx, name)
  elif op_type == OpType.FLAT:
    return Flat(handle, idx, name)
  elif op_type == OpType.CONCAT:
    return Concat(handle, idx, name)
  elif op_type == OpType.SOFTMAX:
    return Softmax(handle, idx, name)
  elif op_type == OpType.EXP:
    return Exp(handle, idx, name)
  elif op_type == OpType.ADD:
    return Add(handle, idx, name)
  elif op_type == OpType.SUBTRACT:
    return Subtract(handle, idx, name)
  elif op_type == OpType.MULTIPLY:
    return Multiply(handle, idx, name)
  elif op_type == OpType.DIVIDE:
    return Divide(handle, idx, name)
  elif op_type == OpType.MSELOSS:
    return MSELoss(handle, idx, name)
  elif op_type == OpType.SCALAR_MULTIPLY:
    return ScalarMultiply(handle, idx, name)
  elif op_type == OpType.SCALAR_ADD:
    return ScalarAdd(handle, idx, name)
  elif op_type == OpType.SCALAR_SUB:
    return ScalarSub(handle, idx, name)
  elif op_type == OpType.SCALAR_FLOORDIV:
    return ScalarFloorDiv(handle, idx, name)
  elif op_type == OpType.SCALAR_TRUEDIV:
    return ScalarTrueDiv(handle, idx, name)
  elif op_type == OpType.GELU:
    return Gelu(handle, idx, name)
  elif op_type == OpType.RELU:
    return Relu(handle, idx, name)
  elif op_type == OpType.SIGMOID:
    return Sigmoid(handle, idx, name)
  elif op_type == OpType.TANH:
    return Tanh(handle, idx, name)
  elif op_type == OpType.ELU:
    return Elu(handle, idx, name)
  elif op_type == OpType.DROPOUT:
    return Dropout(handle, idx, name)
  elif op_type == OpType.BATCH_NORM:
    return Batch_Norm(handle, idx, name)
  elif op_type == OpType.BATCH_MATMUL:
    return Batch_Matmul(handle, idx, name)
  elif op_type == OpType.SPLIT:
    return Split(handle, idx, name)
  elif op_type == OpType.RESHAPE:
    return Reshape(handle, idx, name)
  elif op_type == OpType.IDENTITY:
    return Identity(handle, idx, name)
  elif op_type == OpType.TRANSPOSE:
    return Transpose(handle, idx, name)
  elif op_type == OpType.REVERSE:
    return Reverse(handle, idx, name)
  elif op_type == OpType.MULTIHEAD_ATTENTION:
    # NOTE(review): assumes the MultiHeadAttention wrapper class is defined
    # in this module alongside the other op wrappers -- confirm.
    return MultiHeadAttention(handle, idx, name)
  # Raise explicitly instead of `assert 0`: asserts are stripped under
  # `python -O`, which would make this silently return None.
  raise AssertionError("unknown layer type {}".format(op_type))