in src/util/flip_loss.py [0:0]
# Linear sRGB -> CIE XYZ matrix (rows produce X, Y, Z), assuming the D65
# standard illuminant.
# Source: https://www.image-engineering.de/library/technotes/958-how-to-convert-between-srgb-and-ciexyz
_LINRGB2XYZ = (
    (10135552 / 24577794, 8788810 / 24577794, 4435075 / 24577794),
    (2613072 / 12288897, 8788810 / 12288897, 887015 / 12288897),
    (1425312 / 73733382, 8788810 / 73733382, 70074185 / 73733382),
)

# Multi-step transforms, each expressed as the sequence of single-step
# transforms it is built from (applied left to right).
_COMPOSITE_TRANSFORMS = {
    "srgb2xyz": ("srgb2linrgb", "linrgb2xyz"),
    "srgb2ycxcz": ("srgb2linrgb", "linrgb2xyz", "xyz2ycxcz"),
    "linrgb2ycxcz": ("linrgb2xyz", "xyz2ycxcz"),
    "srgb2lab": ("srgb2linrgb", "linrgb2xyz", "xyz2lab"),
    "linrgb2lab": ("linrgb2xyz", "xyz2lab"),
    "ycxcz2linrgb": ("ycxcz2xyz", "xyz2linrgb"),
    "lab2srgb": ("lab2xyz", "xyz2linrgb", "linrgb2srgb"),
    "ycxcz2lab": ("ycxcz2xyz", "xyz2lab"),
}


def _reference_illuminant(input_color):
    """Return the XYZ value of linear-RGB white as a (1, C, 1, 1) tensor.

    Built with ``new_ones`` so it shares ``input_color``'s device and dtype,
    and kept at spatial size 1x1 so it broadcasts over any N, H, W instead of
    transforming a full image of ones on every call.
    """
    white = input_color.new_ones(1, input_color.size(1), 1, 1)
    return color_space_transform(white, "linrgb2xyz")


def color_space_transform(input_color, fromSpace2toSpace):
    """Convert a batch of colors between color spaces.

    :param input_color: 4-D color tensor laid out as N x C x H x W with C == 3
        (channels are indexed along dim 1).
    :param fromSpace2toSpace: transform name of the form "<src>2<dst>".
        Single-step transforms: "srgb2linrgb", "linrgb2srgb", "linrgb2xyz",
        "xyz2linrgb", "xyz2ycxcz", "ycxcz2xyz", "xyz2lab", "lab2xyz".
        Chained transforms: the keys of ``_COMPOSITE_TRANSFORMS``.
    :return: transformed tensor with the same N x C x H x W shape, on the same
        device as ``input_color``. For an unknown transform name a message is
        printed and ``input_color`` is returned unchanged.
    """
    chain = _COMPOSITE_TRANSFORMS.get(fromSpace2toSpace)
    if chain is not None:
        transformed_color = input_color
        for step in chain:
            transformed_color = color_space_transform(transformed_color, step)
        return transformed_color

    if fromSpace2toSpace == "srgb2linrgb":
        input_color = torch.clamp(input_color, 0.0, 1.0)  # clamp added to stabilize training
        limit = 0.04045
        transformed_color = torch.where(
            input_color > limit,
            torch.pow((input_color + 0.055) / 1.055, 2.4),
            input_color / 12.92,
        )
    elif fromSpace2toSpace == "linrgb2srgb":
        input_color = torch.clamp(input_color, 0.0, 1.0)  # clamp added to stabilize training
        limit = 0.0031308
        # torch.where back-propagates through BOTH branches; d/dx x**(1/2.4)
        # is infinite at 0, which would give 0 * inf = NaN gradients. Clamping
        # the pow argument to `limit` leaves forward values unchanged (that
        # branch is only selected for x > limit).
        transformed_color = torch.where(
            input_color > limit,
            1.055 * torch.pow(torch.clamp(input_color, min=limit), 1.0 / 2.4) - 0.055,
            12.92 * input_color,
        )
    elif fromSpace2toSpace in ("linrgb2xyz", "xyz2linrgb"):
        A = torch.tensor(_LINRGB2XYZ)
        if fromSpace2toSpace == "xyz2linrgb":
            A = torch.inverse(A)
        # Follow the input's device and dtype instead of forcing the current
        # CUDA device.
        A = A.to(input_color)
        dim = input_color.size()
        # reshape (not view) so non-contiguous inputs are accepted.
        flat = input_color.reshape(dim[0], dim[1], dim[2] * dim[3])  # NC(HW)
        transformed_color = torch.matmul(A, flat).reshape(dim)
    elif fromSpace2toSpace == "xyz2ycxcz":
        input_color = input_color / _reference_illuminant(input_color)
        y = 116 * input_color[:, 1:2, :, :] - 16
        cx = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        cz = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
        transformed_color = torch.cat((y, cx, cz), 1)
    elif fromSpace2toSpace == "ycxcz2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        cx = input_color[:, 1:2, :, :] / 500
        cz = input_color[:, 2:3, :, :] / 200
        xyz = torch.cat((y + cx, y, y - cz), 1)
        transformed_color = xyz * _reference_illuminant(input_color)
    elif fromSpace2toSpace == "xyz2lab":
        input_color = input_color / _reference_illuminant(input_color)
        delta = 6 / 29
        limit = 0.00885
        # Same NaN-gradient guard as linrgb2srgb: the cube root has an
        # infinite derivative at 0 (and is NaN for negatives).
        input_color = torch.where(
            input_color > limit,
            torch.pow(torch.clamp(input_color, min=limit), 1 / 3),
            (input_color / (3 * delta * delta)) + (4 / 29),
        )
        l = 116 * input_color[:, 1:2, :, :] - 16
        a = 500 * (input_color[:, 0:1, :, :] - input_color[:, 1:2, :, :])
        b = 200 * (input_color[:, 1:2, :, :] - input_color[:, 2:3, :, :])
        transformed_color = torch.cat((l, a, b), 1)
    elif fromSpace2toSpace == "lab2xyz":
        y = (input_color[:, 0:1, :, :] + 16) / 116
        a = input_color[:, 1:2, :, :] / 500
        b = input_color[:, 2:3, :, :] / 200
        xyz = torch.cat((y + a, y, y - b), 1)
        delta = 6 / 29
        xyz = torch.where(xyz > delta, xyz ** 3, 3 * delta ** 2 * (xyz - 4 / 29))
        transformed_color = xyz * _reference_illuminant(input_color)
    else:
        print('The color transform is not defined!')
        transformed_color = input_color
    return transformed_color