in stylegan2_ada_pytorch/training/augment.py [0:0]
def forward(self, images, debug_percentile=None):
    """Apply the configured augmentations to a batch of images.

    Stages run in a fixed order:

    1. Pixel blitting (x-flip, 90-degree rotation, integer translation) and
       general geometric transforms (isotropic scaling, pre-rotation,
       anisotropic scaling, post-rotation, fractional translation) are
       composed into one inverse 3x3 homogeneous matrix per sample, then
       executed with a single pad / upsample / grid-sample / downsample pass.
    2. Color transforms (brightness, contrast, luma flip, hue rotation,
       saturation) are composed into one 4x4 homogeneous matrix per sample
       and applied as a single matrix multiply over the channels.
    3. Image-space filtering: random per-band gains applied to the filter
       bank ``self.Hz_fbank``.
    4. Image-space corruptions: additive Gaussian noise and cutout.

    Each augmentation is enabled when its multiplier attribute
    (``self.xflip``, ``self.scale``, ``self.noise``, ...) is > 0. It then
    fires independently per sample with probability ``multiplier * self.p``.
    Rotation is the exception: it is split into a pre- and a post-rotation,
    each gated by ``p_rot``.

    NOTE(review): results for a fixed seed depend on the exact sequence of
    ``torch.rand`` / ``torch.randn`` calls below. Reordering statements or
    skipping a draw changes which random numbers each augmentation uses.

    Args:
        images: 4-D tensor, unpacked as ``[batch, channels, height, width]``.
            When any color transform is active, ``channels`` must be 3 or 1,
            otherwise ``ValueError`` is raised.
        debug_percentile: Optional scalar, presumably in [0, 1]. At the
            endpoints the ``erfinv``-based branches produce +/-inf. When
            given, each enabled augmentation's random parameter is replaced
            for every sample by a deterministic value derived from this
            percentile, which effectively bypasses the per-sample
            probability gate. The random draws are still made first, so the
            RNG advances exactly as in the non-debug path.

    Returns:
        The augmented image tensor. The reshapes to ``[batch, channels,
        height, width]`` in the color and filtering stages assume the
        geometric stage preserves the spatial size.
    """
    assert isinstance(images, torch.Tensor) and images.ndim == 4
    batch_size, num_channels, height, width = images.shape
    device = images.device
    if debug_percentile is not None:
        # Convert to a float32 tensor on the image device so it can feed
        # torch.floor / torch.erfinv / torch.full_like below.
        debug_percentile = torch.as_tensor(
            debug_percentile, dtype=torch.float32, device=device
        )
    # -------------------------------------
    # Select parameters for pixel blitting.
    # -------------------------------------
    # Initialize inverse homogeneous 2D transform: G_inv @ pixel_out ==> pixel_in
    # G_inv starts as the identity *object* I_3. Any enabled geometric
    # augmentation rebinds it, and the `G_inv is not I_3` identity check
    # below relies on that to decide whether to run the resampling pass.
    I_3 = torch.eye(3, device=device)
    G_inv = I_3
    # Apply x-flip with probability (xflip * strength).
    if self.xflip > 0:
        # i is 0 or 1. The x scale factor 1 - 2 * i is +1 (keep) or -1 (flip).
        i = torch.floor(torch.rand([batch_size], device=device) * 2)
        i = torch.where(
            torch.rand([batch_size], device=device) < self.xflip * self.p,
            i,
            torch.zeros_like(i),
        )
        if debug_percentile is not None:
            i = torch.full_like(i, torch.floor(debug_percentile * 2))
        G_inv = G_inv @ scale2d_inv(1 - 2 * i, 1)
    # Apply 90 degree rotations with probability (rotate90 * strength).
    if self.rotate90 > 0:
        # i in {0, 1, 2, 3}: number of quarter turns.
        i = torch.floor(torch.rand([batch_size], device=device) * 4)
        i = torch.where(
            torch.rand([batch_size], device=device) < self.rotate90 * self.p,
            i,
            torch.zeros_like(i),
        )
        if debug_percentile is not None:
            i = torch.full_like(i, torch.floor(debug_percentile * 4))
        G_inv = G_inv @ rotate2d_inv(-np.pi / 2 * i)
    # Apply integer translation with probability (xint * strength).
    if self.xint > 0:
        # t is uniform in [-xint_max, xint_max] as a fraction of image size.
        # It is scaled by width/height and rounded to whole pixels below.
        # The [batch, 1] gate enables/disables the x and y shifts together.
        t = (torch.rand([batch_size, 2], device=device) * 2 - 1) * self.xint_max
        t = torch.where(
            torch.rand([batch_size, 1], device=device) < self.xint * self.p,
            t,
            torch.zeros_like(t),
        )
        if debug_percentile is not None:
            t = torch.full_like(t, (debug_percentile * 2 - 1) * self.xint_max)
        G_inv = G_inv @ translate2d_inv(
            torch.round(t[:, 0] * width), torch.round(t[:, 1] * height)
        )
    # --------------------------------------------------------
    # Select parameters for general geometric transformations.
    # --------------------------------------------------------
    # Apply isotropic scaling with probability (scale * strength).
    if self.scale > 0:
        # s = 2 ** N(0, scale_std^2), i.e. log2-normal. Gated-off samples use 1.
        s = torch.exp2(torch.randn([batch_size], device=device) * self.scale_std)
        s = torch.where(
            torch.rand([batch_size], device=device) < self.scale * self.p,
            s,
            torch.ones_like(s),
        )
        if debug_percentile is not None:
            # NOTE(review): the standard-normal quantile is
            # sqrt(2) * erfinv(2q - 1). This and the other erfinv-based debug
            # branches omit sqrt(2), so the debug values are a scaled
            # stand-in for the true percentile.
            s = torch.full_like(
                s,
                torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.scale_std),
            )
        G_inv = G_inv @ scale2d_inv(s, s)
    # Apply pre-rotation with probability p_rot.
    # Rotation is split into a pre- and a post-rotation, each applied with
    # probability p_rot. 1 - (1 - p_rot) ** 2 == rotate * p, so at least one
    # of the two fires with probability rotate * p.
    # `.clamp` implies `self.rotate * self.p` is a tensor (presumably self.p
    # is a tensor buffer -- confirm in __init__).
    p_rot = 1 - torch.sqrt(
        (1 - self.rotate * self.p).clamp(0, 1)
    )  # P(pre OR post) = p
    if self.rotate > 0:
        # theta is uniform in [-pi * rotate_max, pi * rotate_max].
        theta = (
            (torch.rand([batch_size], device=device) * 2 - 1)
            * np.pi
            * self.rotate_max
        )
        theta = torch.where(
            torch.rand([batch_size], device=device) < p_rot,
            theta,
            torch.zeros_like(theta),
        )
        if debug_percentile is not None:
            theta = torch.full_like(
                theta, (debug_percentile * 2 - 1) * np.pi * self.rotate_max
            )
        G_inv = G_inv @ rotate2d_inv(-theta)  # Before anisotropic scaling.
    # Apply anisotropic scaling with probability (aniso * strength).
    if self.aniso > 0:
        # Scale x by s and y by 1 / s (area-preserving). Because it sits
        # between two random rotations, the stretch axis is effectively random.
        s = torch.exp2(torch.randn([batch_size], device=device) * self.aniso_std)
        s = torch.where(
            torch.rand([batch_size], device=device) < self.aniso * self.p,
            s,
            torch.ones_like(s),
        )
        if debug_percentile is not None:
            s = torch.full_like(
                s,
                torch.exp2(torch.erfinv(debug_percentile * 2 - 1) * self.aniso_std),
            )
        G_inv = G_inv @ scale2d_inv(s, 1 / s)
    # Apply post-rotation with probability p_rot.
    if self.rotate > 0:
        theta = (
            (torch.rand([batch_size], device=device) * 2 - 1)
            * np.pi
            * self.rotate_max
        )
        theta = torch.where(
            torch.rand([batch_size], device=device) < p_rot,
            theta,
            torch.zeros_like(theta),
        )
        if debug_percentile is not None:
            # In debug mode only the pre-rotation carries the debug angle.
            theta = torch.zeros_like(theta)
        G_inv = G_inv @ rotate2d_inv(-theta)  # After anisotropic scaling.
    # Apply fractional translation with probability (xfrac * strength).
    if self.xfrac > 0:
        # t ~ N(0, xfrac_std^2) as a fraction of image size. It is not
        # rounded, so the shift is sub-pixel.
        t = torch.randn([batch_size, 2], device=device) * self.xfrac_std
        t = torch.where(
            torch.rand([batch_size, 1], device=device) < self.xfrac * self.p,
            t,
            torch.zeros_like(t),
        )
        if debug_percentile is not None:
            t = torch.full_like(
                t, torch.erfinv(debug_percentile * 2 - 1) * self.xfrac_std
            )
        G_inv = G_inv @ translate2d_inv(t[:, 0] * width, t[:, 1] * height)
    # ----------------------------------
    # Execute geometric transformations.
    # ----------------------------------
    # Execute if the transform is not identity.
    # This is an object-identity test: it is true whenever any geometric
    # multiplier above is > 0, even if every sample was gated off.
    if G_inv is not I_3:
        # Calculate padding.
        # Coordinates are centered on the image: output corners sit at
        # (+/-cx, +/-cy). Map the four corners through G_inv to find how far
        # outside the input image the sampled region reaches.
        cx = (width - 1) / 2
        cy = (height - 1) / 2
        cp = matrix(
            [-cx, -cy, 1], [cx, -cy, 1], [cx, cy, 1], [-cx, cy, 1], device=device
        )  # [idx, xyz]
        cp = G_inv @ cp.t()  # [batch, xyz, idx]
        Hz_pad = self.Hz_geom.shape[0] // 4
        # The max is taken over the whole batch, so every sample shares one
        # padding amount.
        margin = cp[:, :2, :].permute(1, 0, 2).flatten(1)  # [xy, batch * idx]
        margin = torch.cat([-margin, margin]).max(dim=1).values  # [x0, y0, x1, y1]
        # Convert extents to distance beyond the image edge, plus a margin
        # for the geometry filter taps.
        margin = margin + misc.constant(
            [Hz_pad * 2 - cx, Hz_pad * 2 - cy] * 2, device=device
        )
        # Clamp to [0, size - 1]: PyTorch's reflect padding requires each
        # pad to be smaller than the corresponding input dimension.
        margin = margin.max(misc.constant([0, 0] * 2, device=device))
        margin = margin.min(
            misc.constant([width - 1, height - 1] * 2, device=device)
        )
        mx0, my0, mx1, my1 = margin.ceil().to(torch.int32)
        # Pad image and adjust origin.
        # Asymmetric padding moves the image center; shift G_inv to compensate.
        images = torch.nn.functional.pad(
            input=images, pad=[mx0, mx1, my0, my1], mode="reflect"
        )
        G_inv = translate2d((mx0 - mx1) / 2, (my0 - my1) / 2) @ G_inv
        # Upsample.
        # 2x upsample with Hz_geom, then conjugate G_inv by the 2x scale and
        # a half-pixel shift so it operates on the upsampled pixel grid.
        # The half-pixel shift presumably accounts for pixel-center
        # alignment -- confirm.
        images = upfirdn2d.upsample2d(x=images, f=self.Hz_geom, up=2)
        G_inv = (
            scale2d(2, 2, device=device) @ G_inv @ scale2d_inv(2, 2, device=device)
        )
        G_inv = (
            translate2d(-0.5, -0.5, device=device)
            @ G_inv
            @ translate2d_inv(-0.5, -0.5, device=device)
        )
        # Execute transformation.
        # The output grid is 2x the image size padded by Hz_pad per side.
        # affine_grid works in normalized [-1, 1] coordinates, so G_inv is
        # wrapped as: output normalized -> output pixels (scale2d_inv),
        # G_inv, then input pixels -> input normalized (scale2d).
        shape = [
            batch_size,
            num_channels,
            (height + Hz_pad * 2) * 2,
            (width + Hz_pad * 2) * 2,
        ]
        G_inv = (
            scale2d(2 / images.shape[3], 2 / images.shape[2], device=device)
            @ G_inv
            @ scale2d_inv(2 / shape[3], 2 / shape[2], device=device)
        )
        grid = torch.nn.functional.affine_grid(
            theta=G_inv[:, :2, :], size=shape, align_corners=False
        )
        images = grid_sample_gradfix.grid_sample(images, grid)
        # Downsample and crop.
        # Filter and 2x downsample with Hz_geom. The negative padding trims
        # the extra Hz_pad border added to `shape` above.
        images = upfirdn2d.downsample2d(
            x=images, f=self.Hz_geom, down=2, padding=-Hz_pad * 2, flip_filter=True
        )
    # --------------------------------------------
    # Select parameters for color transformations.
    # --------------------------------------------
    # Initialize homogeneous 3D transformation matrix: C @ color_in ==> color_out
    # Colors are treated as homogeneous [r, g, b, 1] vectors. Each transform
    # is left-multiplied, so later ones apply after earlier ones.
    I_4 = torch.eye(4, device=device)
    C = I_4
    # Apply brightness with probability (brightness * strength).
    if self.brightness > 0:
        # Additive offset b ~ N(0, brightness_std^2) on all three channels.
        b = torch.randn([batch_size], device=device) * self.brightness_std
        b = torch.where(
            torch.rand([batch_size], device=device) < self.brightness * self.p,
            b,
            torch.zeros_like(b),
        )
        if debug_percentile is not None:
            b = torch.full_like(
                b, torch.erfinv(debug_percentile * 2 - 1) * self.brightness_std
            )
        C = translate3d(b, b, b) @ C
    # Apply contrast with probability (contrast * strength).
    if self.contrast > 0:
        # Multiplicative factor c = 2 ** N(0, contrast_std^2).
        c = torch.exp2(torch.randn([batch_size], device=device) * self.contrast_std)
        c = torch.where(
            torch.rand([batch_size], device=device) < self.contrast * self.p,
            c,
            torch.ones_like(c),
        )
        if debug_percentile is not None:
            c = torch.full_like(
                c,
                torch.exp2(
                    torch.erfinv(debug_percentile * 2 - 1) * self.contrast_std
                ),
            )
        C = scale3d(c, c, c) @ C
    # Apply luma flip with probability (lumaflip * strength).
    # v is the unit vector along the gray diagonal r = g = b. The trailing 0
    # leaves the homogeneous component untouched.
    v = misc.constant(
        np.asarray([1, 1, 1, 0]) / np.sqrt(3), device=device
    )  # Luma axis.
    if self.lumaflip > 0:
        # i is 0 or 1. With i = 1, I - 2 v v^T reflects colors across the
        # plane orthogonal to v, which negates the luma component.
        i = torch.floor(torch.rand([batch_size, 1, 1], device=device) * 2)
        i = torch.where(
            torch.rand([batch_size, 1, 1], device=device) < self.lumaflip * self.p,
            i,
            torch.zeros_like(i),
        )
        if debug_percentile is not None:
            i = torch.full_like(i, torch.floor(debug_percentile * 2))
        C = (I_4 - 2 * v.ger(v) * i) @ C  # Householder reflection.
    # Apply hue rotation with probability (hue * strength).
    # Skipped for single-channel images.
    if self.hue > 0 and num_channels > 1:
        # theta is uniform in [-pi * hue_max, pi * hue_max].
        theta = (
            (torch.rand([batch_size], device=device) * 2 - 1) * np.pi * self.hue_max
        )
        theta = torch.where(
            torch.rand([batch_size], device=device) < self.hue * self.p,
            theta,
            torch.zeros_like(theta),
        )
        if debug_percentile is not None:
            theta = torch.full_like(
                theta, (debug_percentile * 2 - 1) * np.pi * self.hue_max
            )
        C = rotate3d(v, theta) @ C  # Rotate around v.
    # Apply saturation with probability (saturation * strength).
    # Skipped for single-channel images.
    if self.saturation > 0 and num_channels > 1:
        # s = 2 ** N(0, saturation_std^2).
        s = torch.exp2(
            torch.randn([batch_size, 1, 1], device=device) * self.saturation_std
        )
        s = torch.where(
            torch.rand([batch_size, 1, 1], device=device)
            < self.saturation * self.p,
            s,
            torch.ones_like(s),
        )
        if debug_percentile is not None:
            s = torch.full_like(
                s,
                torch.exp2(
                    torch.erfinv(debug_percentile * 2 - 1) * self.saturation_std
                ),
            )
        # Keep the component along v (luma) and scale the orthogonal
        # (chroma) component by s.
        C = (v.ger(v) + (I_4 - v.ger(v)) * s) @ C
    # ------------------------------
    # Execute color transformations.
    # ------------------------------
    # Execute if the transform is not identity.
    # This is an object-identity test, like the `G_inv is not I_3` check.
    if C is not I_4:
        # Flatten spatial dims: [batch, channels, height * width].
        images = images.reshape([batch_size, num_channels, height * width])
        if num_channels == 3:
            # Linear 3x3 part plus the translation column.
            images = C[:, :3, :3] @ images + C[:, :3, 3:]
        elif num_channels == 1:
            # Treat the gray input as r = g = b = x and output the mean of
            # the three transformed channels. Average C's first three rows,
            # then sum the input coefficients.
            C = C[:, :3, :].mean(dim=1, keepdims=True)
            images = images * C[:, :, :3].sum(dim=2, keepdims=True) + C[:, :, 3:]
        else:
            raise ValueError("Image must be RGB (3 channels) or L (1 channel)")
        images = images.reshape([batch_size, num_channels, height, width])
    # ----------------------
    # Image-space filtering.
    # ----------------------
    if self.imgfilter > 0:
        num_bands = self.Hz_fbank.shape[0]
        assert len(self.imgfilter_bands) == num_bands
        # NOTE(review): expected_power has exactly 4 entries, so it only
        # broadcasts against t ([batch, num_bands]) when num_bands == 4.
        expected_power = misc.constant(
            np.array([10, 1, 1, 1]) / 13, device=device
        )  # Expected power spectrum (1/f).
        # Apply amplification for each band with probability (imgfilter * strength * band_strength).
        g = torch.ones(
            [batch_size, num_bands], device=device
        )  # Global gain vector (identity).
        for i, band_strength in enumerate(self.imgfilter_bands):
            # Gain for band i: 2 ** N(0, imgfilter_std^2), or 1 when gated off.
            t_i = torch.exp2(
                torch.randn([batch_size], device=device) * self.imgfilter_std
            )
            t_i = torch.where(
                torch.rand([batch_size], device=device)
                < self.imgfilter * self.p * band_strength,
                t_i,
                torch.ones_like(t_i),
            )
            if debug_percentile is not None:
                # Bands with zero strength stay at gain 1 in debug mode.
                t_i = (
                    torch.full_like(
                        t_i,
                        torch.exp2(
                            torch.erfinv(debug_percentile * 2 - 1)
                            * self.imgfilter_std
                        ),
                    )
                    if band_strength > 0
                    else torch.ones_like(t_i)
                )
            t = torch.ones(
                [batch_size, num_bands], device=device
            )  # Temporary gain vector.
            t[:, i] = t_i  # Replace i'th element.
            # After normalization, sum(expected_power * t**2) == 1, so the
            # expected total power is unchanged by this band's gain.
            t = (
                t / (expected_power * t.square()).sum(dim=-1, keepdims=True).sqrt()
            )  # Normalize power.
            g = g * t  # Accumulate into global gain.
        # Construct combined amplification filter.
        # A per-sample 1D filter formed as a gain-weighted sum of the band filters.
        Hz_prime = g @ self.Hz_fbank  # [batch, tap]
        Hz_prime = Hz_prime.unsqueeze(1).repeat(
            [1, num_channels, 1]
        )  # [batch, channels, tap]
        Hz_prime = Hz_prime.reshape(
            [batch_size * num_channels, 1, -1]
        )  # [batch * channels, 1, tap]
        # Apply filter.
        # Fold batch into channels so a grouped conv gives every
        # (sample, channel) its own filter. Apply it separably: a
        # [1, tap] kernel (horizontal), then a [tap, 1] kernel (vertical).
        # Reflect-pad by p per side; the reshape back to height x width
        # assumes an odd tap count (tap == 2 * p + 1).
        p = self.Hz_fbank.shape[1] // 2
        images = images.reshape([1, batch_size * num_channels, height, width])
        images = torch.nn.functional.pad(
            input=images, pad=[p, p, p, p], mode="reflect"
        )
        images = conv2d_gradfix.conv2d(
            input=images,
            weight=Hz_prime.unsqueeze(2),
            groups=batch_size * num_channels,
        )
        images = conv2d_gradfix.conv2d(
            input=images,
            weight=Hz_prime.unsqueeze(3),
            groups=batch_size * num_channels,
        )
        images = images.reshape([batch_size, num_channels, height, width])
    # ------------------------
    # Image-space corruptions.
    # ------------------------
    # Apply additive RGB noise with probability (noise * strength).
    if self.noise > 0:
        # Per-sample noise level sigma = |N(0, 1)| * noise_std (half-normal).
        sigma = (
            torch.randn([batch_size, 1, 1, 1], device=device).abs() * self.noise_std
        )
        sigma = torch.where(
            torch.rand([batch_size, 1, 1, 1], device=device) < self.noise * self.p,
            sigma,
            torch.zeros_like(sigma),
        )
        if debug_percentile is not None:
            # erfinv(q) is the half-normal analogue of the other debug
            # branches. NOTE(review): it likewise omits the sqrt(2) factor.
            sigma = torch.full_like(
                sigma, torch.erfinv(debug_percentile) * self.noise_std
            )
        # Independent Gaussian noise per pixel and channel, scaled per sample.
        images = (
            images
            + torch.randn([batch_size, num_channels, height, width], device=device)
            * sigma
        )
    # Apply cutout with probability (cutout * strength).
    if self.cutout > 0:
        # size holds separate x/y extents (both cutout_size here) as
        # fractions of the image size. Gated-off samples get size 0, which
        # makes the mask all ones (|d| >= 0 always holds).
        size = torch.full([batch_size, 2, 1, 1, 1], self.cutout_size, device=device)
        size = torch.where(
            torch.rand([batch_size, 1, 1, 1, 1], device=device)
            < self.cutout * self.p,
            size,
            torch.zeros_like(size),
        )
        # Rectangle center, uniform in [0, 1)^2 in normalized image coordinates.
        center = torch.rand([batch_size, 2, 1, 1, 1], device=device)
        if debug_percentile is not None:
            size = torch.full_like(size, self.cutout_size)
            center = torch.full_like(center, debug_percentile)
        coord_x = torch.arange(width, device=device).reshape([1, 1, 1, -1])
        coord_y = torch.arange(height, device=device).reshape([1, 1, -1, 1])
        # Compare each pixel center's normalized coordinate with the
        # rectangle. mask_x / mask_y are True where the pixel lies outside
        # along that axis. A pixel is zeroed only when it is inside along
        # both axes. The [batch, 1, H, W] mask broadcasts over channels.
        mask_x = ((coord_x + 0.5) / width - center[:, 0]).abs() >= size[:, 0] / 2
        mask_y = ((coord_y + 0.5) / height - center[:, 1]).abs() >= size[:, 1] / 2
        mask = torch.logical_or(mask_x, mask_y).to(torch.float32)
        images = images * mask
    return images