in src/train.py [0:0]
def train_batch(model, optimizer, baseline, epoch,
                batch_id, step, batch, tb_logger, opts):
    """Process one batch: forward pass, REINFORCE loss, backward pass, update.

    The loss is divided by ``opts.accumulation_steps`` before the backward
    pass, and the optimizer is stepped (and its gradients zeroed) only when
    ``step`` is a multiple of ``opts.accumulation_steps``. Values are sent to
    ``log_values`` when ``step`` is a multiple of ``int(opts.log_step)``.

    Args:
        model: Callable invoked as ``model(nodes, graph)``; returns a
            ``(cost, log_likelihood)`` pair.
        optimizer: Optimizer exposing ``param_groups``, ``step()`` and
            ``zero_grad()``.
        baseline: Object providing ``unwrap_batch(batch)`` and
            ``eval(nodes, graph, cost)``.
        epoch: Forwarded to ``log_values`` unchanged.
        batch_id: Forwarded to ``log_values`` unchanged.
        step: Step counter used for the accumulation and logging schedules.
        batch: Raw batch handed to ``baseline.unwrap_batch``; the unwrapped
            result is indexed with the keys ``'nodes'`` and ``'graph'``.
        tb_logger: Forwarded to ``log_values`` unchanged.
        opts: Options object; reads ``device``, ``accumulation_steps``,
            ``max_grad_norm`` and ``log_step``.

    Returns:
        None.
    """
    # The baseline splits the batch into model inputs and an optional
    # precomputed baseline value (None when it has to be evaluated below).
    inputs, baseline_value = baseline.unwrap_batch(batch)

    # Place the inputs (and the baseline value, if present) on opts.device.
    nodes = move_to(inputs['nodes'], opts.device)
    graph = move_to(inputs['graph'], opts.device)
    if baseline_value is not None:
        baseline_value = move_to(baseline_value, opts.device)

    # Forward pass: per-instance costs and log-probabilities.
    cost, log_likelihood = model(nodes, graph)

    # Without a precomputed value the baseline is evaluated here and may
    # contribute its own loss term; otherwise that term is 0.
    if baseline_value is None:
        baseline_value, baseline_loss = baseline.eval(nodes, graph, cost)
    else:
        baseline_loss = 0

    # REINFORCE objective: (cost - baseline) weighted by log-likelihood.
    advantage = cost - baseline_value
    reinforce_loss = (advantage * log_likelihood).mean()

    # Scale by the number of accumulation steps before backpropagating.
    scaled_loss = (reinforce_loss + baseline_loss) / opts.accumulation_steps
    scaled_loss.backward()

    # Gradient norms are clipped and collected for logging on every call.
    # NOTE(review): this also runs on calls that do not step the optimizer;
    # if clip_grad_norms rescales gradients in place, partially accumulated
    # gradients are affected -- confirm this is intended.
    grad_norms = clip_grad_norms(optimizer.param_groups, opts.max_grad_norm)

    # Apply the update only once every opts.accumulation_steps steps.
    update_due = step % opts.accumulation_steps == 0
    if update_due:
        optimizer.step()
        optimizer.zero_grad()

    # Nothing more to do unless this step is on the logging schedule.
    if step % int(opts.log_step) != 0:
        return

    log_values(cost, grad_norms, epoch, batch_id, step, log_likelihood,
               reinforce_loss, baseline_loss, tb_logger, opts)