def main()

in egg/zoo/language_bottleneck/guess_number/train.py [0:0]


def _build_symbol_game(opts):
    """Build the agents and game for the non-variable-length setup.

    Uses the ``core.Symbol*`` wrappers and games.

    Args:
        opts: Parsed options. Reads ``mode``, ``n_bits``, ``sender_hidden``,
            ``receiver_hidden``, ``vocab_size`` and, depending on the mode,
            ``temperature``, ``sender_entropy_coeff`` and
            ``receiver_entropy_coeff``.

    Returns:
        A ``(sender, receiver, game)`` tuple: the wrapped agents and the game
        built from them.

    Raises:
        ValueError: If ``opts.mode`` is not ``"gs"``, ``"rf"`` or
            ``"non_diff"``.
    """
    mode = opts.mode
    if mode not in ("gs", "rf", "non_diff"):
        # Fail here with a clear message instead of letting the caller hit an
        # unbound `receiver` / `game` further down.
        raise ValueError(
            f"Unsupported mode {mode!r}; expected 'gs', 'rf' or 'non_diff'"
        )

    sender = Sender(
        n_bits=opts.n_bits, n_hidden=opts.sender_hidden, vocab_size=opts.vocab_size
    )
    if mode == "gs":
        sender = core.GumbelSoftmaxWrapper(agent=sender, temperature=opts.temperature)
    else:
        sender = core.ReinforceWrapper(agent=sender)

    # Only the "non_diff" mode uses the ReinforcedReceiver agent.
    receiver_cls = ReinforcedReceiver if mode == "non_diff" else Receiver
    receiver = receiver_cls(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
    receiver = core.SymbolReceiverWrapper(
        receiver,
        vocab_size=opts.vocab_size,
        agent_input_size=opts.receiver_hidden,
    )

    if mode == "gs":
        game = core.SymbolGameGS(sender, receiver, diff_loss)
    elif mode == "rf":
        receiver = core.ReinforceDeterministicWrapper(agent=receiver)
        game = core.SymbolGameReinforce(
            sender,
            receiver,
            diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
        )
    else:  # mode == "non_diff"
        game = core.SymbolGameReinforce(
            sender,
            receiver,
            non_diff_loss,
            sender_entropy_coeff=opts.sender_entropy_coeff,
            receiver_entropy_coeff=opts.receiver_entropy_coeff,
        )
    return sender, receiver, game


def _build_sequence_game(opts):
    """Build the agents and Reinforce game for the variable-length setup.

    Args:
        opts: Parsed options. Reads ``n_bits``, ``vocab_size``, ``max_len``,
            ``sender_cell``, ``sender_hidden``, ``sender_emb``,
            ``receiver_cell``, ``receiver_hidden``, ``receiver_emb`` and both
            entropy coefficients.

    Returns:
        A ``(sender, receiver, game)`` tuple: the wrapped agents and the
        ``core.SenderReceiverRnnReinforce`` game built from them.
    """
    # NOTE(review): this throwaway Receiver is never used; the real one is
    # built further down. It is still constructed, before the sender, because
    # doing so presumably draws from the torch RNG for weight initialisation,
    # so dropping it would shift the results of seeded runs. Confirm and
    # remove if exact parity with earlier runs is not required.
    Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)

    sender = Sender(
        n_bits=opts.n_bits,
        n_hidden=opts.sender_hidden,
        vocab_size=opts.sender_hidden,
    )  # TODO: not really vocab
    if opts.sender_cell == "transformer":
        sender = core.TransformerSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            max_len=opts.max_len,
            num_layers=1,
            num_heads=1,
            hidden_size=opts.sender_hidden,
        )
    else:
        sender = core.RnnSenderReinforce(
            agent=sender,
            vocab_size=opts.vocab_size,
            embed_dim=opts.sender_emb,
            hidden_size=opts.sender_hidden,
            max_len=opts.max_len,
            cell=opts.sender_cell,
        )

    if opts.receiver_cell == "transformer":
        # The inner agent is sized by the embedding width in this branch.
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_emb)
        receiver = core.TransformerReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.max_len,
            opts.receiver_emb,
            num_heads=1,
            hidden_size=opts.receiver_hidden,
            num_layers=1,
        )
    else:
        receiver = Receiver(n_bits=opts.n_bits, n_hidden=opts.receiver_hidden)
        receiver = core.RnnReceiverDeterministic(
            receiver,
            opts.vocab_size,
            opts.receiver_emb,
            opts.receiver_hidden,
            cell=opts.receiver_cell,
        )

    game = core.SenderReceiverRnnReinforce(
        sender,
        receiver,
        diff_loss,
        sender_entropy_coeff=opts.sender_entropy_coeff,
        receiver_entropy_coeff=opts.receiver_entropy_coeff,
    )
    return sender, receiver, game


def main(params):
    """Parse options, build the sender/receiver game and train it.

    Args:
        params: Passed unchanged to ``get_params`` (presumably the
            command-line argument list -- confirm against ``get_params``).

    Raises:
        ValueError: If the non-variable-length setup is requested with a mode
            other than ``"gs"``, ``"rf"`` or ``"non_diff"``.
    """
    opts = get_params(params)
    print(opts)

    device = opts.device

    train_loader = OneHotLoader(
        n_bits=opts.n_bits,
        bits_s=opts.bits_s,
        bits_r=opts.bits_r,
        batch_size=opts.batch_size,
        # NOTE(review): true division, so this is a float even when the
        # operands divide evenly; confirm OneHotLoader accepts a float.
        batches_per_epoch=opts.n_examples_per_epoch / opts.batch_size,
    )

    test_loader = UniformLoader(
        n_bits=opts.n_bits, bits_s=opts.bits_s, bits_r=opts.bits_r
    )
    # Move the test loader's batch to the target device once, up front.
    test_loader.batch = [x.to(device) for x in test_loader.batch]

    if opts.variable_length:
        if opts.mode != "rf":
            print("Only mode=rf is supported atm")
            # Overwrite the option so the `is_gs` flag below reflects the
            # mode that is actually trained.
            opts.mode = "rf"
        sender, receiver, game = _build_sequence_game(opts)
    else:
        sender, receiver, game = _build_symbol_game(opts)

    # Separate parameter groups so each agent gets its own learning rate.
    optimizer = torch.optim.Adam(
        [
            dict(params=sender.parameters(), lr=opts.sender_lr),
            dict(params=receiver.parameters(), lr=opts.receiver_lr),
        ]
    )

    loss = game.loss

    intervention = CallbackEvaluator(
        test_loader,
        device=device,
        is_gs=opts.mode == "gs",
        loss=loss,
        var_length=opts.variable_length,
        input_intervention=True,
    )

    trainer = core.Trainer(
        game=game,
        optimizer=optimizer,
        train_data=train_loader,
        validation_data=test_loader,
        callbacks=[
            core.ConsoleLogger(as_json=True),
            EarlyStopperAccuracy(opts.early_stopping_thr),
            intervention,
        ],
    )

    trainer.train(n_epochs=opts.n_epochs)

    core.close()