def train_batch_sl()

in src/train.py [0:0]


def train_batch_sl(model, optimizer, epoch, batch_id, 
                   step, batch, tb_logger, opts):
    # Move input tensors to the device specified in opts (GPU if available)
    x = move_to(batch['nodes'], opts.device)
    graph = move_to(batch['graph'], opts.device)
    
    if opts.model == 'nar':
        targets = move_to(batch['tour_edges'], opts.device)
        # Compute class weights for NAR decoder
        _targets = batch['tour_edges'].numpy().flatten()
        class_weights = compute_class_weight("balanced", classes=np.unique(_targets), y=_targets)
        class_weights = move_to(torch.FloatTensor(class_weights), opts.device)
    else:
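        # Autoregressive (AR) decoder: targets are the ground-truth node ordering,
        # so no edge-level class weighting is needed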
        class_weights = None
        targets = move_to(batch['tour_nodes'], opts.device)
    
    # Evaluate model, get costs and loss
    cost, loss = model(x, graph, supervised=True, targets=targets, class_weights=class_weights)
    
    # Normalize loss for gradient accumulation
    loss = loss / opts.accumulation_steps
    
    # Perform backward pass
    loss.backward()
    
    # Clip gradient norms and keep the (unclipped and clipped) norms for logging
    grad_norms = clip_grad_norms(optimizer.param_groups, opts.max_grad_norm)
    
    # Perform optimization step after accumulating gradients
    if step % opts.accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

    # Logging
    if step % int(opts.log_step) == 0:
        log_values_sl(cost, grad_norms, epoch, batch_id, 
                      step, loss, tb_logger, opts)
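
A note on helpers: move_to and clip_grad_norms are defined elsewhere in the repository, alongside the file-level imports this excerpt relies on (e.g. numpy, torch, and sklearn's compute_class_weight). The sketch below shows one common way such helpers are written; the exact bodies are assumptions, not the repo's code:

import math
import torch

def move_to(var, device):
    # Assumed sketch: recursively move tensors (possibly nested in dicts/lists)
    # onto the target device; bare tensors are moved directly.
    if isinstance(var, dict):
        return {k: move_to(v, device) for k, v in var.items()}
    if isinstance(var, list):
        return [move_to(v, device) for v in var]
    return var.to(device)

def clip_grad_norms(param_groups, max_norm=math.inf):
    # Assumed sketch: clip gradients per optimizer parameter group and return
    # the pre-clip and post-clip norms so both can be logged.
    grad_norms = [
        torch.nn.utils.clip_grad_norm_(
            group['params'],
            max_norm if max_norm > 0 else math.inf,
            norm_type=2,
        )
        for group in param_groups
    ]
    grad_norms_clipped = [min(g, max_norm) for g in grad_norms] if max_norm > 0 else grad_norms
    return grad_norms, grad_norms_clipped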
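
For context, a hypothetical driver loop showing how train_batch_sl could be called once per batch; train_epoch_sl and train_dataloader are illustrative names, not part of the excerpt:

def train_epoch_sl(model, optimizer, train_dataloader, epoch, tb_logger, opts):
    # Hypothetical epoch loop: step is a global batch counter, so the
    # step % opts.accumulation_steps and step % opts.log_step checks inside
    # train_batch_sl align with the gradient accumulation and logging intervals.
    model.train()
    optimizer.zero_grad()
    for batch_id, batch in enumerate(train_dataloader):
        step = epoch * len(train_dataloader) + batch_id
        train_batch_sl(model, optimizer, epoch, batch_id,
                       step, batch, tb_logger, opts)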