agents/neural_agent.py [228:312]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    last_checkpoint = get_latest_checkpoint(agent_folder)
    assert last_checkpoint is not None, agent_folder
    logging.info('Loading a model from: %s', last_checkpoint)
    last_checkpoint = torch.load(last_checkpoint)
    model = build_model(**last_checkpoint['model_kwargs'])
    try:
        model.load_state_dict(last_checkpoint['model'])
    except RuntimeError:
        # Checkpoints saved from a DataParallel-wrapped model prefix every
        # key with "module."; retry with the wrapper, then unwrap.
        model = nn.DataParallel(model)
        model.load_state_dict(last_checkpoint['model'])
        model = model.module
    model.to(nets.DEVICE)
    return model


def finetune(model: NeuralModel, data: Sequence[Tuple[int,
                                                      phyre.SimulationStatus,
                                                      Sequence[float]]],
             simulator: phyre.ActionSimulator, learning_rate: float,
             num_updates: int) -> None:
    """Finetunes a model on a small new batch of data.

    Args:
        model: DQN network, e.g., built with build_model().
        data: a list of tuples (task_index, status, action).
        learning_rate: learning rate for Adam.
        num_updates: number updates to do. All data is used for every update.
    """

    # Drop rollouts whose action was invalid (could not be placed in the
    # scene); they carry no learning signal.
    data = [x for x in data if not x[1].is_invalid()]
    if not data:
        return
    task_indices, statuses, actions = zip(*data)
    if len(set(task_indices)) == 1:
        # All tuples come from the same task: fetch its scene once.
        observations = np.expand_dims(simulator.initial_scenes[task_indices[0]],
                                      0)
    else:
        observations = simulator.initial_scenes[task_indices]

    device = model.module.device if isinstance(
        model, nn.DataParallel) else model.device
    is_solved = torch.tensor(statuses, device=device) > 0
    observations = torch.tensor(observations, device=device)
    actions = torch.tensor(actions, device=device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    for _ in range(num_updates):
        optimizer.zero_grad()
        model.ce_loss(model(observations, actions), is_solved).backward()
        optimizer.step()


def refine_actions(model, actions, single_observation, learning_rate,
                   num_updates, batch_size, refine_loss):
    """Refines candidate actions by gradient steps; model weights stay fixed."""
    device = model.module.device if isinstance(
        model, nn.DataParallel) else model.device
    observations = torch.tensor(single_observation, device=device).unsqueeze(0)
    actions = torch.tensor(actions)

    refined_actions = []
    model.eval()
    # Encode the single observation once; detach so the cached features are
    # reused as constants across all refinement steps.
    preprocessed = model.preprocess(observations)
    preprocessed = {k: v.detach() for k, v in preprocessed.items()}
    for start in range(0, len(actions), batch_size):
        # Treat each batch of actions as the trainable parameter.
        action_batch = actions[start:start + batch_size].to(device)
        action_batch = torch.nn.Parameter(action_batch)
        optimizer = torch.optim.Adam([action_batch], lr=learning_rate)
        losses = []  # collected for debugging; not returned
        for _ in range(num_updates):
            optimizer.zero_grad()
            logits = model(None, action_batch, preprocessed=preprocessed)
            if refine_loss == 'ce':
                # Push logits toward the "solved" label. Build the targets
                # from action_batch so they share its device.
                loss = model.ce_loss(logits,
                                     action_batch.new_ones(len(logits)))
            elif refine_loss == 'linear':
                # Maximize the logits directly.
                loss = -logits.sum()
            else:
                raise ValueError(f'Unknown loss: {refine_loss}')
            loss.backward()
            losses.append(loss.item())
            optimizer.step()
        # Project the refined actions back into the valid [0, 1] range.
        action_batch = torch.clamp_(action_batch.data, 0, 1)
        refined_actions.append(action_batch.cpu().numpy())
    refined_actions = np.concatenate(refined_actions, 0).tolist()
    return refined_actions
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
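
The loader above expects a checkpoint dict with two keys: 'model_kwargs' (the
arguments needed to rebuild the architecture via build_model()) and 'model'
(the state_dict). A minimal sketch of writing a compatible checkpoint; the
kwargs, import path, and file name below are assumptions, not the repo's
actual conventions:

    import torch
    from agents.neural_agent import build_model  # import path assumed

    model_kwargs = dict(network_type='resnet18', action_space_dim=3)  # hypothetical
    model = build_model(**model_kwargs)
    torch.save(
        {
            'model_kwargs': model_kwargs,   # used to re-instantiate the net
            'model': model.state_dict(),    # restored via load_state_dict()
        },
        'agent_folder/ckpt.last')  # real naming is whatever get_latest_checkpoint() expects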

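A usage sketch for finetune(). The task ids and hyperparameters are made up,
and the rollout-collection calls follow the public phyre API
(initialize_simulator / sample / simulate_action); exact signatures may
differ across phyre versions:

    import phyre
    from agents.neural_agent import finetune  # import path assumed

    tasks = ['00000:000', '00000:001']  # hypothetical task ids
    simulator = phyre.initialize_simulator(tasks, 'ball')
    model = ...  # model returned by the checkpoint loader above

    # Gather (task_index, status, action) tuples from random rollouts;
    # finetune() itself filters out invalid actions.
    data = []
    for task_index in range(len(tasks)):
        action = simulator.sample()
        status = simulator.simulate_action(task_index, action,
                                           need_images=False).status
        data.append((task_index, status, action))

    finetune(model, data, simulator, learning_rate=1e-4, num_updates=10)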

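refine_actions() runs gradient-based search in action space: the scene is
encoded once, each batch of candidate actions becomes the optimizer's
parameter, and after num_updates Adam steps the actions are clamped back into
[0, 1]. A sketch of a call, assuming the 3-dimensional single-ball action
tier and hypothetical hyperparameters:

    import numpy as np

    # Seed actions to refine, e.g., the top-scoring candidates from ranking.
    candidates = np.random.uniform(0, 1, size=(128, 3)).astype(np.float32)
    scene = simulator.initial_scenes[0]  # one initial scene, as in finetune()

    refined = refine_actions(model, candidates, scene,
                             learning_rate=1e-2, num_updates=20,
                             batch_size=64, refine_loss='ce')
    # `refined` is a list of [x, y, radius] actions clamped to [0, 1].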

agents/neural_agent_contrastive.py [249:334]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        last_checkpoint = get_latest_checkpoint(agent_folder)

    assert last_checkpoint is not None, agent_folder
    logging.info('Loading a model from: %s', last_checkpoint)
    last_checkpoint = torch.load(last_checkpoint)
    model = build_model(**last_checkpoint['model_kwargs'])
    try:
        model.load_state_dict(last_checkpoint['model'])
    except RuntimeError:
        # Checkpoints saved from a DataParallel-wrapped model prefix every
        # key with "module."; retry with the wrapper, then unwrap.
        model = nn.DataParallel(model)
        model.load_state_dict(last_checkpoint['model'])
        model = model.module
    model.to(nets.DEVICE)
    return model


def finetune(model: NeuralModel, data: Sequence[Tuple[int,
                                                      phyre.SimulationStatus,
                                                      Sequence[float]]],
             simulator: phyre.ActionSimulator, learning_rate: float,
             num_updates: int) -> None:
    """Finetunes a model on a small new batch of data.

    Args:
        model: DQN network, e.g., built with build_model().
        data: a list of tuples (task_index, status, action).
        learning_rate: learning rate for Adam.
        num_updates: number updates to do. All data is used for every update.
    """

    # Drop rollouts whose action was invalid (could not be placed in the
    # scene); they carry no learning signal.
    data = [x for x in data if not x[1].is_invalid()]
    if not data:
        return
    task_indices, statuses, actions = zip(*data)
    if len(set(task_indices)) == 1:
        # All tuples come from the same task: fetch its scene once.
        observations = np.expand_dims(simulator.initial_scenes[task_indices[0]],
                                      0)
    else:
        observations = simulator.initial_scenes[task_indices]

    device = model.module.device if isinstance(
        model, nn.DataParallel) else model.device
    is_solved = torch.tensor(statuses, device=device) > 0
    observations = torch.tensor(observations, device=device)
    actions = torch.tensor(actions, device=device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    for _ in range(num_updates):
        optimizer.zero_grad()
        model.ce_loss(model(observations, actions), is_solved).backward()
        optimizer.step()


def refine_actions(model, actions, single_observation, learning_rate,
                   num_updates, batch_size, refine_loss):
    """Refines candidate actions by gradient steps; model weights stay fixed."""
    device = model.module.device if isinstance(
        model, nn.DataParallel) else model.device
    observations = torch.tensor(single_observation, device=device).unsqueeze(0)
    actions = torch.tensor(actions)

    refined_actions = []
    model.eval()
    # Encode the single observation once; detach so the cached features are
    # reused as constants across all refinement steps.
    preprocessed = model.preprocess(observations)
    preprocessed = {k: v.detach() for k, v in preprocessed.items()}
    for start in range(0, len(actions), batch_size):
        # Treat each batch of actions as the trainable parameter.
        action_batch = actions[start:start + batch_size].to(device)
        action_batch = torch.nn.Parameter(action_batch)
        optimizer = torch.optim.Adam([action_batch], lr=learning_rate)
        losses = []  # collected for debugging; not returned
        for _ in range(num_updates):
            optimizer.zero_grad()
            logits = model(None, action_batch, preprocessed=preprocessed)
            if refine_loss == 'ce':
                # Push logits toward the "solved" label. Build the targets
                # from action_batch so they share its device.
                loss = model.ce_loss(logits,
                                     action_batch.new_ones(len(logits)))
            elif refine_loss == 'linear':
                # Maximize the logits directly.
                loss = -logits.sum()
            else:
                raise ValueError(f'Unknown loss: {refine_loss}')
            loss.backward()
            losses.append(loss.item())
            optimizer.step()
        # Project the refined actions back into the valid [0, 1] range.
        action_batch = torch.clamp_(action_batch.data, 0, 1)
        refined_actions.append(action_batch.cpu().numpy())
    refined_actions = np.concatenate(refined_actions, 0).tolist()
    return refined_actions
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



