in flexflow/core/flexflow_cffi.py [0:0]
def convert_op_handle_to_op(op_type, handle, idx=None, name=None):
    """Wrap a raw operator handle in the Python wrapper class for *op_type*.

    Args:
        op_type: An ``OpType`` member identifying the kind of layer; it is
            compared for equality against each supported ``OpType`` value.
        handle: The underlying operator handle, passed unchanged to the
            wrapper's constructor.
        idx: Optional layer index forwarded to the wrapper's constructor.
        name: Optional layer name forwarded to the wrapper's constructor.

    Returns:
        A new wrapper instance built as ``Cls(handle, idx, name)``, e.g.
        ``Conv2D(handle, idx, name)`` for ``OpType.CONV2D``.

    Raises:
        AssertionError: if *op_type* is not one of the supported types.
    """
    # NOTE: an if/elif chain (rather than a dict built up front) means only the
    # wrapper class of the matching branch is looked up by name on each call.
    if op_type == OpType.CONV2D:
        return Conv2D(handle, idx, name)
    elif op_type == OpType.POOL2D:
        return Pool2D(handle, idx, name)
    elif op_type == OpType.LINEAR:
        return Linear(handle, idx, name)
    elif op_type == OpType.EMBEDDING:
        return Embedding(handle, idx, name)
    elif op_type == OpType.FLAT:
        return Flat(handle, idx, name)
    elif op_type == OpType.CONCAT:
        return Concat(handle, idx, name)
    elif op_type == OpType.SOFTMAX:
        return Softmax(handle, idx, name)
    elif op_type == OpType.EXP:
        return Exp(handle, idx, name)
    elif op_type == OpType.ADD:
        return Add(handle, idx, name)
    elif op_type == OpType.SUBTRACT:
        return Subtract(handle, idx, name)
    elif op_type == OpType.MULTIPLY:
        return Multiply(handle, idx, name)
    elif op_type == OpType.DIVIDE:
        return Divide(handle, idx, name)
    elif op_type == OpType.MSELOSS:
        return MSELoss(handle, idx, name)
    elif op_type == OpType.SCALAR_MULTIPLY:
        return ScalarMultiply(handle, idx, name)
    elif op_type == OpType.SCALAR_ADD:
        return ScalarAdd(handle, idx, name)
    elif op_type == OpType.SCALAR_SUB:
        return ScalarSub(handle, idx, name)
    elif op_type == OpType.SCALAR_FLOORDIV:
        return ScalarFloorDiv(handle, idx, name)
    elif op_type == OpType.SCALAR_TRUEDIV:
        return ScalarTrueDiv(handle, idx, name)
    elif op_type == OpType.GELU:
        return Gelu(handle, idx, name)
    elif op_type == OpType.RELU:
        return Relu(handle, idx, name)
    elif op_type == OpType.SIGMOID:
        return Sigmoid(handle, idx, name)
    elif op_type == OpType.TANH:
        return Tanh(handle, idx, name)
    elif op_type == OpType.ELU:
        return Elu(handle, idx, name)
    elif op_type == OpType.DROPOUT:
        return Dropout(handle, idx, name)
    elif op_type == OpType.BATCH_NORM:
        return Batch_Norm(handle, idx, name)
    elif op_type == OpType.BATCH_MATMUL:
        return Batch_Matmul(handle, idx, name)
    elif op_type == OpType.SPLIT:
        return Split(handle, idx, name)
    elif op_type == OpType.RESHAPE:
        return Reshape(handle, idx, name)
    elif op_type == OpType.IDENTITY:
        return Identity(handle, idx, name)
    elif op_type == OpType.TRANSPOSE:
        return Transpose(handle, idx, name)
    elif op_type == OpType.REVERSE:
        return Reverse(handle, idx, name)
    elif op_type == OpType.MULTIHEAD_ATTENTION:
        # Previously returned Reverse(...) by copy-paste mistake, so attention
        # layers were wrapped as the wrong op type.
        return MultiHeadAttention(handle, idx, name)
    # Raise explicitly instead of `assert 0`: assert statements are removed
    # under `python -O`, which would make this silently return None.
    raise AssertionError("unknown layer type {}".format(op_type))