in flexflow/torch/model.py [0:0]
def apply(self, ffmodel, input_tensors):
    """Replay the serialized graph in ``self.lines`` onto ``ffmodel``.

    Each line is a comma-separated record::

        op_name, input_ops(':'-separated), output_ops(':'-separated), op_type, args...

    Every operator's result is wrapped in :class:`FXTensor` and registered in
    ``self.tensor_dict`` under keys from ``self._get_output_key`` so that later
    nodes can resolve their inputs via ``self._get_input_key``.

    Args:
        ffmodel: FlexFlow model that the operators are added to.
        input_tensors: tensors consumed, in order, by the graph's INPUT nodes.

    Returns:
        List of FlexFlow tensors collected by the graph's OUTPUT node(s).
    """
    output_tensors = []
    input_idx = 0  # next element of input_tensors to bind to an INPUT node

    for line in self.lines:
        items = [field.strip() for field in line.strip().split(",")]
        # items[3] (the op type) is always read below, so require at least
        # 4 fields; the original ">= 3" check allowed an IndexError.
        assert len(items) >= 4, "wrong format"

        op_name = items[0]
        # Producer / consumer op names. Filter empty fields with a
        # comprehension: removing from a list while iterating it (the old
        # pattern) skips the element that follows each removal.
        self.input_ops_list = [s.strip() for s in items[1].split(":") if s.strip()]
        self.output_ops_list = [s.strip() for s in items[2].split(":") if s.strip()]

        op_type = str_to_enum(OpType, items[3])

        def in_tensor(idx=0):
            # FlexFlow tensor produced by this node's idx-th input op.
            return self.tensor_dict[self._get_input_key(op_name, idx)].fftensor

        if op_type == OpType.INPUT:
            assert len(self.input_ops_list) == 0, "wrong format"
            output = FXTensor(input_tensors[input_idx])
            input_idx += 1

        elif op_type == OpType.LINEAR:
            assert len(items) == 7, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            od = int(items[4])
            activ = int_to_enum(ActiMode, int(items[5]))
            bias = bool(int(items[6]))
            output = FXTensor(ffmodel.dense(
                input=in_tensor(), out_dim=od, activation=activ,
                use_bias=bias, name=op_name))

        elif op_type == OpType.CONV2D:
            assert len(items) == 14, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            oc = int(items[4])
            kh, kw = int(items[5]), int(items[6])
            sh, sw = int(items[7]), int(items[8])
            ph, pw = int(items[9]), int(items[10])
            activ = int_to_enum(ActiMode, int(items[11]))
            group = int(items[12])
            bias = bool(int(items[13]))
            output = FXTensor(ffmodel.conv2d(
                input=in_tensor(), out_channels=oc,
                kernel_h=kh, kernel_w=kw, stride_h=sh, stride_w=sw,
                padding_h=ph, padding_w=pw, activation=activ,
                groups=group, use_bias=bias, name=op_name))

        elif op_type == OpType.POOL2D:
            assert len(items) == 9, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            kh = int(items[4])
            sh = int(items[5])
            ph = int(items[6])
            pt = int_to_enum(PoolType, int(items[7]))
            activ = int_to_enum(ActiMode, int(items[8]))
            # Only square kernel/stride/padding are encoded (h == w).
            output = FXTensor(ffmodel.pool2d(
                input=in_tensor(), kernel_h=kh, kernel_w=kh,
                stride_h=sh, stride_w=sh, padding_h=ph, padding_w=ph,
                pool_type=pt, activation=activ, name=op_name))

        elif op_type == OpType.DROPOUT:
            assert len(items) == 5, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            output = FXTensor(ffmodel.dropout(
                input=in_tensor(), rate=float(items[4]), seed=0, name=op_name))

        elif op_type == OpType.FLAT:
            assert len(items) == 4, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            output = FXTensor(ffmodel.flat(input=in_tensor(), name=op_name))

        elif op_type in (OpType.SCALAR_MULTIPLY, OpType.SCALAR_ADD,
                         OpType.SCALAR_SUB, OpType.SCALAR_TRUEDIV):
            assert len(items) == 5, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            scalar_fn = {
                OpType.SCALAR_MULTIPLY: ffmodel.scalar_multiply,
                OpType.SCALAR_ADD: ffmodel.scalar_add,
                OpType.SCALAR_SUB: ffmodel.scalar_sub,
                OpType.SCALAR_TRUEDIV: ffmodel.scalar_true_divide,
            }[op_type]
            output = FXTensor(scalar_fn(
                input=in_tensor(), scalar=float(items[4]), name=op_name))

        elif op_type == OpType.SCALAR_FLOORDIV:
            assert len(items) == 5, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            value = in_tensor()
            # Floor division is only supported on plain Python numbers
            # (e.g. a value produced by GETATTR/GETITEM), not on tensors.
            assert isinstance(value, (int, float)), \
                "Tensor floor division is not supported."
            output = FXTensor(value // float(items[4]))

        elif op_type in (OpType.RELU, OpType.GELU, OpType.SIGMOID,
                         OpType.TANH, OpType.ELU, OpType.SOFTMAX,
                         OpType.BATCH_NORM, OpType.IDENTITY,
                         OpType.LAYER_NORM, OpType.EXPAND):
            # NOTE(review): LAYER_NORM and EXPAND are lowered to identity —
            # this mirrors the original converter's placeholder behavior.
            if op_type == OpType.EXPAND:
                assert len(items) >= 4, "wrong format"  # extra args ignored
            else:
                assert len(items) == 4, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            unary_fn = {
                OpType.RELU: ffmodel.relu,
                OpType.GELU: ffmodel.gelu,
                OpType.SIGMOID: ffmodel.sigmoid,
                OpType.TANH: ffmodel.tanh,
                OpType.ELU: ffmodel.elu,
                OpType.SOFTMAX: ffmodel.softmax,
                OpType.BATCH_NORM: ffmodel.batch_norm,
                OpType.IDENTITY: ffmodel.identity,
                OpType.LAYER_NORM: ffmodel.identity,
                OpType.EXPAND: ffmodel.identity,
            }[op_type]
            output = FXTensor(unary_fn(input=in_tensor(), name=op_name))

        elif op_type == OpType.TRANSPOSE:
            assert len(items) >= 6
            assert len(self.input_ops_list) == 1, "wrong format"
            value = in_tensor()
            # Identity permutation over the dims (1-based, matching the
            # encoding), then swap the two requested axes.
            perm = list(range(1, len(value.dims) + 1))
            a, b = int(items[4]), int(items[5])
            perm[a - 1], perm[b - 1] = perm[b - 1], perm[a - 1]
            output = FXTensor(ffmodel.transpose(
                input=value, perm=perm, name=op_name))

        elif op_type == OpType.PERMUTE:
            assert len(items) > 4
            assert len(self.input_ops_list) == 1, "wrong format"
            perm = [int(dim) for dim in items[4:]]
            output = FXTensor(ffmodel.transpose(
                input=in_tensor(), perm=perm, name=op_name))

        elif op_type == OpType.RESHAPE:
            assert len(items) >= 5
            assert len(self.input_ops_list) == 1, "wrong format"
            shape = []
            for dim in items[4:]:
                try:
                    shape.append(int(dim))
                except ValueError:
                    # Non-numeric dims name a previously computed value
                    # stored under the key "<dim><op_name>".
                    shape.append(self.tensor_dict[dim + op_name].fftensor)
            output = FXTensor(ffmodel.reshape(
                input=in_tensor(), shape=shape, name=op_name))

        elif op_type == OpType.BATCH_MATMUL:
            assert len(items) == 4, "wrong format"
            assert len(self.input_ops_list) == 2, "wrong format"
            output = FXTensor(ffmodel.batch_matmul(
                A=in_tensor(0), B=in_tensor(1), name=op_name))

        elif op_type == OpType.CONCAT:
            assert len(items) == 5, "wrong format"
            assert len(self.input_ops_list) >= 2, "wrong format"
            tensors = [in_tensor(i) for i in range(len(self.input_ops_list))]
            output = FXTensor(ffmodel.concat(
                tensors=tensors, axis=int(items[4]), name=op_name))

        elif op_type == OpType.SPLIT:
            assert len(items) == 5, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            size = len(self.output_ops_list)
            assert size >= 2, "wrong format"
            pieces = ffmodel.split(
                input=in_tensor(), sizes=size, axis=int(items[4]), name=op_name)
            assert isinstance(pieces, list)
            # The whole list of sub-tensors is stored; GETITEM selects one.
            output = FXTensor(pieces)

        elif op_type == OpType.GETITEM:
            assert len(items) == 5, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            seq = in_tensor()
            assert isinstance(seq, (list, tuple))
            output = FXTensor(seq[int(items[4])])

        elif op_type == OpType.GETATTR:
            assert len(items) == 5, "wrong format"
            assert len(self.input_ops_list) == 1, "wrong format"
            value = in_tensor()
            # "shape" maps to the FlexFlow tensor's dims tuple.
            if items[4] == "shape":
                output = FXTensor(value.dims)
            else:
                output = FXTensor(getattr(value, items[4]))

        elif op_type in (OpType.ADD, OpType.MULTIPLY):
            assert len(items) == 4, "wrong format"
            assert len(self.input_ops_list) == 2, "wrong format"
            binary_fn = ffmodel.add if op_type == OpType.ADD else ffmodel.multiply
            output = FXTensor(binary_fn(
                x=in_tensor(0), y=in_tensor(1), name=op_name))

        elif op_type == OpType.OUTPUT:
            assert len(self.input_ops_list) >= 1, "wrong format"
            output_tensors.extend(
                in_tensor(i) for i in range(len(self.input_ops_list)))
            output = None  # OUTPUT nodes register nothing in tensor_dict

        else:
            assert 0, "unknown op: {}".format(op_type)

        if isinstance(output, FXTensor):
            # Register the result once per declared consumer.
            for i in range(len(self.output_ops_list)):
                self.tensor_dict[self._get_output_key(op_name, i)] = output
        else:
            assert output is None

    return output_tensors