egg/zoo/language_bottleneck/mnist_adv/train.py [20:42]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def diff_loss_symbol(
    _sender_input, _message, _receiver_input, receiver_output, labels, _aux_input
):
    """Return the batch-mean NLL loss and per-sample accuracy of the receiver.

    Only ``receiver_output`` and ``labels`` are used; the underscore-prefixed
    parameters are accepted and ignored.

    Args:
        receiver_output: Tensor of per-class scores, indexed as
            (sample, class). It is passed straight to ``F.nll_loss``, so it
            is presumably log-probabilities -- confirm against the receiver.
        labels: Tensor of target class indices, one per sample.

    Returns:
        A ``(loss, aux)`` tuple. ``loss`` is the mean of the per-sample NLL
        values; ``aux`` is ``{"acc": acc}``, where ``acc`` is a float tensor
        holding 1.0 where the argmax over dim 1 equals the label, else 0.0.
    """
    # Compute the unreduced loss first, then average it explicitly.
    per_sample_nll = F.nll_loss(receiver_output, labels, reduction="none")
    predicted = receiver_output.argmax(dim=1)
    hits = (predicted == labels).float()
    return per_sample_nll.mean(), {"acc": hits}


def get_params(params):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="GS temperature for the sender (default: 1)",
    )

    parser.add_argument(
        "--early_stopping_thr",
        type=float,
        default=1.0,
        help="Early stopping threshold on accuracy (default: 1.0)",
    )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



egg/zoo/language_bottleneck/mnist_overfit/train.py [21:43]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
def diff_loss_symbol(
    _sender_input, _message, _receiver_input, receiver_output, labels, _aux_input
):
    """Score the receiver's output against the labels.

    The loss is the mean over the batch of the unreduced ``F.nll_loss``
    values. The accuracy is kept per sample: a float tensor with 1.0 where
    the highest-scoring class along dim 1 matches the label and 0.0 elsewhere.

    The sender input, message, receiver input and aux input arguments are
    part of the signature but are not read here.

    NOTE(review): ``receiver_output`` is handed directly to ``F.nll_loss``,
    so it presumably holds log-probabilities -- confirm against the receiver.

    Returns:
        ``(loss, {"acc": acc})`` with ``loss`` a scalar tensor and ``acc`` a
        per-sample float tensor.
    """
    acc = receiver_output.argmax(dim=1).eq(labels).float()
    unreduced = F.nll_loss(receiver_output, labels, reduction="none")
    loss = unreduced.mean()
    return loss, {"acc": acc}


def get_params(params):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="GS temperature for the sender (default: 1)",
    )

    parser.add_argument(
        "--early_stopping_thr",
        type=float,
        default=1.0,
        help="Early stopping threshold on accuracy (default: 1.0)",
    )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



