in model/transformer_gan.py [0:0]
def forward(self, data, target, reset_mems, train_loss, mems=None, status_vec=None,
update_D0=False):
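# Forward pass for joint MLE / GAN training. Depending on train_loss this
#   - computes the teacher-forced MLE loss with the generator ("mle"),
#   - samples a sequence chunk-wise with Gumbel-softmax and scores it with the discriminator ("gen" / "dis"),
#   - trains the auxiliary classifier used for the PPO density ratio ("classifier").
# Returns a dict holding the individual losses and the updated generator memory.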
# train_loss can contain "mle", "gen_loss", "dis_loss", "classifier", or combinations such as "mle_and_gen_loss"
return_dict = {"mle": None, "gen_loss": None, "dis_loss": None, "mems": None}
if "mle" in train_loss:
ret = self.generator(data, target, reset_mems, mems, status_vec=status_vec)
return_dict["mle"], return_dict["mems"] = ret
# Sample a sequence
if "gen" in train_loss or "dis" in train_loss or "classifier" in train_loss:
# TODO (low priority): could forward_generate be made a static function?
# Cache params
cache_tgt_len, cache_mem_len = (
self.generator.tgt_len, self.generator.mem_len
)
# Reset params for sampling
self.generator.reset_length(
1, self.cfg.DISCRIMINATOR.mem_len
) # Use mem_len=bert_len
# Generate samples
# First token has to be expanded into one-hot if data is in index form
# sample_mems can be greater than dis mem_len but only for one forward generate pass
def process_for_sequence(inp):
if len(inp.shape) == 1:
# Expand indices to one-hot; equivalent to F.one_hot(inp, self.ntokens).float()
return (
inp.new_zeros(
(*inp.shape, self.ntokens), dtype=torch.float32
).scatter_(-1, inp[..., None], 1)
)
elif len(inp.shape) == 2:
return inp
else:
raise NotImplementedError
seq = []
sample_mems = None
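# seq collects one distribution over the vocabulary per time step: one-hot vectors for
# real/context tokens and relaxed Gumbel-softmax outputs for generated tokens; sample_mems
# carries the generator memory across sampling steps.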
# TODO: When training gen do not pass only context into dis (since no grads anyway)
# TODO: do not loop over context
status_vec = None
with torch.no_grad():
# Last token in context is used to feed forward for sequential generation
if self.cfg.DISCRIMINATOR.context_len > 1:
context = data[:self.cfg.DISCRIMINATOR.context_len - 1]
if self.cfg.TRAIN.append_note_status:
bptt, batch_size = context.shape
status_vec = context.new_zeros((bptt, batch_size, self.vec_len), dtype=torch.bool)
self.vocab.update_status_vec(context, status_vec)
ret = self.generator.forward_generate(context, sample_mems,
status_vec=status_vec)
_, sample_mems = ret
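# sample_mems now holds the memory built from the context, so generation below continues from
# the last context token.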
sample_len = self.cfg.DISCRIMINATOR.tgt_len // self.cfg.DISCRIMINATOR.sample_chunks_mem
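# The discriminator tgt_len is generated in sample_chunks_mem chunks of sample_len tokens each,
# so the computation graph and memory can be truncated between chunks.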
# Do not auto-detach gradients w.r.t. the memory; we detach them manually after each chunk
self.generator.detach_mems_grad = False
gen_loss, dis_loss, gp_loss = 0, 0, 0
# Split real and fake samples according to sample_chunks_mem
for chunk_start in range(0, self.cfg.DISCRIMINATOR.tgt_len, sample_len):
chunk_end = min(chunk_start + sample_len, self.cfg.DISCRIMINATOR.tgt_len)
# TODO: Can we retain the subgraph after calling backward()?
for ind in range(chunk_start, chunk_end):
if ind < self.cfg.DISCRIMINATOR.context_len:
seq.append(process_for_sequence(data[ind]))
continue
# Detach here since the start token is already chosen and the BERT tgt_len is fixed
elif self.cfg.DISCRIMINATOR.truncate_backprop or (ind == chunk_start) or "classifier" in train_loss:
# Noticed I do not gain much memory if dis is frozen, since requires_grad=False for dis params
# Also do not gain memory if dis is freed (assuming PyTorch already optimizes and shares comp graphs)
# Saved some time in backward()
# Stop gradient propagation
# Gumbel max so gradients do not propagate
inp = torch.argmax(seq[-1], dim=-1)[None, :].detach()
cont = inp
else:
inp = seq[-1][None, :]
cont = torch.argmax(seq[-1], dim=-1)[None, :].detach()
if self.cfg.TRAIN.append_note_status:
bptt, batch_size = cont.shape
if status_vec is None:
status_vec = inp.new_zeros((bptt, batch_size, self.vec_len), dtype=torch.bool)
else:
status_vec = status_vec[-1:, :, :]
self.vocab.update_status_vec(cont, status_vec)
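# One sampling step: feed the previous token (hard index or relaxed one-hot) through the
# generator with Gumbel sampling at self.temperature; logits[0] becomes the next element of seq.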
ret = self.generator.forward_generate_gumbel(inp, self.temperature, sample_mems, status_vec=status_vec)
logits, sample_mems = ret
seq.append(logits[0])
# Prepare fake data
# Ignore the first token: it is the seed carried over from the previous chunk
if len(seq) == sample_len + 1:
seq = seq[1:]
fake_chunk = torch.cat(
[i[None, :, :] for i in seq], 0
) # seq_len, bsz, vocab
if 'dis' in train_loss:
fake_chunk = fake_chunk.detach()
data_chunk = data[chunk_start:chunk_end]
if "classifier" in train_loss:
if self.P0 is None:
with torch.no_grad():
D0 = torch.sigmoid(self.dis_D_forward(fake_chunk))
self.P0 = (1. - D0) / torch.clamp(D0, min=1e-7)
real_label = self.P0.new_full((self.P0.shape[0],), 1.)
fake_label = self.P0.new_full((self.P0.shape[0],), 0.)
criterion = nn.BCELoss()
errDD_real = criterion(torch.sigmoid(self.dis_D_forward(data_chunk)), real_label)
errDD_fake = criterion(torch.sigmoid(self.dis_D_forward(fake_chunk.detach())), fake_label)
errDD_loss = errDD_real + errDD_fake
((errDD_loss.float().mean()) / (
self.cfg.DISCRIMINATOR.batch_chunk * self.cfg.DISCRIMINATOR.sample_chunks_mem)).backward()
# Reset params for next chunk
sample_mems = sample_mems.detach()
seq = [seq[-1]]
continue
# Generator update with a PPO-style surrogate (uses the cached classifier output via P0)
if 'gen' in train_loss and (
'ppo' in self.cfg.DISCRIMINATOR.BERT.loss_type or 'ppo' in self.cfg.DISCRIMINATOR.CNN.loss_type):
if self.P0 is None or update_D0:
with torch.no_grad():
D0 = torch.sigmoid(self.dis_D_forward(fake_chunk))
self.P0 = (1. - D0) / torch.clamp(D0, min=1e-7)
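# PPO-style importance ratio: with P0 = (1 - D0) / D0 cached above and P1 = 1 - D1 below,
# ratio = P1 / (D1 * P0) = [(1 - D1) / D1] / [(1 - D0) / D0], then clipped to
# [1 - clip_param, 1 + clip_param] as in PPO.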
D1 = torch.sigmoid(self.dis_D_forward(fake_chunk))
P1 = (1. - D1)
ratio = (P1 / torch.clamp(D1 * self.P0, min=1e-7))
ratio_clipped = torch.clamp(ratio, 1.0 - self.cfg.PPO.clip_param, 1.0 +
self.cfg.PPO.clip_param)
if self.cfg.DISCRIMINATOR.type == "bert": \
# bert_vocab_size = 311
data_chunk = torch.transpose(data_chunk, 0, 1)
fake_chunk = torch.transpose(fake_chunk, 0, 1)
# Pad zeros corresponding to MASK token
fake_chunk = torch.cat(
[fake_chunk, fake_chunk.new_zeros((*fake_chunk.shape[:-1], 1))], -1
)
embedding_matrix = self.discriminator.bert.embeddings.word_embeddings.weight
emb_real = embedding_matrix[data_chunk]
emb_fake = torch.einsum(
"ve,bcv -> bce",
embedding_matrix,
fake_chunk,
)
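# emb_fake is a "soft" embedding lookup: each position gets the probability-weighted mix of
# embedding rows, keeping the BERT discriminator differentiable w.r.t. the relaxed samples.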
bert_emb_real = self.discriminator(inputs_embeds=emb_real)
bert_emb_fake = self.discriminator(inputs_embeds=emb_fake)
# 1 is real and 0 is fake in bce_loss, so we extract the real-class index from the output vector
d_out_real, d_out_fake = bert_emb_real[0][:, 0], bert_emb_fake[0][:, 0]
if 'ppo' in self.cfg.DISCRIMINATOR.BERT.loss_type and 'gen' in train_loss:
surr1 = ratio * d_out_fake
surr2 = ratio_clipped * d_out_fake
target = torch.where(d_out_fake > 0, torch.min(surr1, surr2), torch.max(surr1, surr2))
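# Pessimistic clipped surrogate: take the smaller value when d_out_fake is positive and the
# larger one otherwise, mirroring PPO's clipped objective.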
temp_gen_loss, temp_dis_loss = get_losses(d_out_real, target,
self.cfg.DISCRIMINATOR.BERT.loss_type)
else:
temp_gen_loss, temp_dis_loss = get_losses(d_out_real, d_out_fake,
self.cfg.DISCRIMINATOR.BERT.loss_type)
# Regularize discriminator with gradient penalty
if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
data_chunk = (
data_chunk.new_zeros((*data_chunk.shape, self.ntokens + 1), dtype=torch.float32)
.scatter_(-1, data_chunk[..., None], 1)
)
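# Gradient penalty between the one-hot real chunk (padded with the extra MASK dimension) and the
# relaxed fake chunk, presumably interpolated WGAN-GP style inside calc_gradient_penalty.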
temp_gp_loss = self.calc_gradient_penalty(data_chunk, fake_chunk)
if self.cfg.DISCRIMINATOR.backprop_outside:
gen_loss += temp_gen_loss.detach()
dis_loss += temp_dis_loss.detach()
if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
gp_loss += temp_gp_loss.detach()
else:
gen_loss += temp_gen_loss
dis_loss += temp_dis_loss
if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
gp_loss += temp_gp_loss
elif self.cfg.DISCRIMINATOR.type == "cnn":
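# CNN discriminator branch: real and relaxed fake samples are scored directly as batch-first
# (bsz, seq_len, vocab) tensors, without an explicit embedding lookup here.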
real_samples = (
data_chunk.new_zeros((*data_chunk.shape, self.ntokens), dtype=torch.float32)
.scatter_(-1, data_chunk[..., None], 1)
.transpose(0, 1)
)
gen_samples = torch.transpose(fake_chunk, 0, 1)
d_out_real = self.discriminator(real_samples)
d_out_fake = self.discriminator(gen_samples)
if 'ppo' in self.cfg.DISCRIMINATOR.CNN.loss_type and 'gen' in train_loss:
surr1 = ratio * d_out_fake
surr2 = ratio_clipped * d_out_fake
target = torch.where(d_out_fake > 0, torch.min(surr1, surr2), torch.max(surr1, surr2))
temp_gen_loss, temp_dis_loss = get_losses(d_out_real, target,
self.cfg.DISCRIMINATOR.CNN.loss_type)
else:
temp_gen_loss, temp_dis_loss = get_losses(
d_out_real, d_out_fake, self.cfg.DISCRIMINATOR.CNN.loss_type
)
# Regularize discriminator with gradient penalty
if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
temp_gp_loss = self.calc_gradient_penalty(real_samples, gen_samples)
if self.cfg.DISCRIMINATOR.backprop_outside:
gen_loss += temp_gen_loss.detach()
dis_loss += temp_dis_loss.detach()
if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
gp_loss += temp_gp_loss.detach()
else:
gen_loss += temp_gen_loss
dis_loss += temp_dis_loss
if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
gp_loss += temp_gp_loss
else:
raise NotImplementedError
if self.cfg.DISCRIMINATOR.backprop_outside:
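# In this mode, backward() is called per chunk right here, scaled by batch_chunk *
# sample_chunks_mem so the accumulated gradients average over all chunks; the losses
# accumulated above are detached.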
if "gen" in train_loss:
((temp_gen_loss.float().mean()) * self.cfg.DISCRIMINATOR.gen_loss_factor / (
self.cfg.DISCRIMINATOR.batch_chunk * self.cfg.DISCRIMINATOR.sample_chunks_mem)).backward()
if "dis" in train_loss:
((temp_dis_loss.float().mean()) * self.cfg.DISCRIMINATOR.dis_loss_factor / (
self.cfg.DISCRIMINATOR.batch_chunk * self.cfg.DISCRIMINATOR.sample_chunks_mem)).backward()
if self.cfg.DISCRIMINATOR.type == "bert":
if 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
((temp_gp_loss.float().mean()) * self.cfg.DISCRIMINATOR.dis_loss_factor / (
self.cfg.DISCRIMINATOR.batch_chunk * self.cfg.DISCRIMINATOR.sample_chunks_mem)).backward()
elif self.cfg.DISCRIMINATOR.type == "cnn":
if 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
((temp_gp_loss.float().mean()) * self.cfg.DISCRIMINATOR.dis_loss_factor / (
self.cfg.DISCRIMINATOR.batch_chunk * self.cfg.DISCRIMINATOR.sample_chunks_mem)).backward() # TODO CNN WGAN-GP
else:
raise NotImplementedError
# Truncate backprop for the next chunk: detach the memory and keep only the last sample
sample_mems = sample_mems.detach()
seq = [seq[-1]]
# Reset model parameters
self.generator.detach_mems_grad = True
self.generator.reset_length(cache_tgt_len, cache_mem_len)
# Set up values to return
if "dis" in train_loss:
dis_loss = self.cfg.DISCRIMINATOR.dis_loss_factor * dis_loss / self.cfg.DISCRIMINATOR.sample_chunks_mem
return_dict["dis_loss"] = dis_loss
if self.cfg.DISCRIMINATOR.type == "bert":
if 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
return_dict["gp_loss"] = self.cfg.DISCRIMINATOR.dis_loss_factor * gp_loss / self.cfg.DISCRIMINATOR.sample_chunks_mem
elif self.cfg.DISCRIMINATOR.type == "cnn":
if 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
return_dict["gp_loss"] = self.cfg.DISCRIMINATOR.dis_loss_factor * gp_loss / self.cfg.DISCRIMINATOR.sample_chunks_mem
else:
raise NotImplementedError
elif "gen" in train_loss:
gen_loss = self.cfg.DISCRIMINATOR.gen_loss_factor * gen_loss / self.cfg.DISCRIMINATOR.sample_chunks_mem
return_dict["gen_loss"] = gen_loss
return return_dict