def get_callbacks()

in main.py [0:0]


def get_callbacks(args, datasets, is_trnsf):
    """Return the correct callbacks for training.

    The returned list contains, in this order:

    1. Only when ``is_trnsf`` is truthy: the ``valid_acc`` and
       ``valid_loglike`` validation scorers.
    2. The ``train_acc`` training-set scorer (always present).
    3. Whatever ``get_lr_schedulers(args, datasets, is_trnsf=...)`` yields.
    4. A ``Freezer`` if ``args.train.freezer.patterns`` is not None.
    5. An ``Unfreezer`` if ``args.train.unfreezer.patterns`` is not None.
    6. A ``StopAtThreshold`` if ``args.train.ce_threshold`` is not None.

    Scorers are added as ``(name, callback)`` tuples. The freezing and
    stopping callbacks are added as bare callback objects.

    Args:
        args: Configuration object. Its ``train`` attribute is read for
            ``freezer.patterns``/``freezer.at``,
            ``unfreezer.patterns``/``unfreezer.at`` and ``ce_threshold``.
        datasets: Mapping whose ``"train"`` entry exposes
            ``map_target_position`` (forwarded to ``accuracy_filter_train``).
        is_trnsf: Whether the model is a transformer rather than a plain
            classifier; enables the validation scorers.

    Returns:
        list: The assembled callbacks, in the order described above.
    """

    def _named_scoring(name, scoring, **extra):
        # Build a ``(name, EpochScoring)`` pair where a higher score is
        # better; the tuple key and the scorer's ``name`` are kept identical.
        return (
            name,
            EpochScoring(scoring, name=name, lower_is_better=False, **extra),
        )

    callbacks = []

    if is_trnsf:
        # The ``accuracy`` function is passed instead of the "accuracy"
        # string scorer because the model is a transformer rather than a
        # classifier. ``loglike`` is tracked on its own because the actual
        # training loss also contains all regularization terms.
        callbacks.append(
            _named_scoring("valid_acc", accuracy, target_extractor=target_extractor)
        )
        callbacks.append(
            _named_scoring("valid_loglike", loglike, target_extractor=target_extractor)
        )

    train_positions = datasets["train"].map_target_position
    train_accuracy = partial(
        accuracy_filter_train, map_target_position=train_positions
    )
    multi_target_extractor = partial(target_extractor, is_multi_target=True)
    callbacks.append(
        _named_scoring(
            "train_acc",
            train_accuracy,
            on_train=True,
            target_extractor=multi_target_extractor,
        )
    )

    callbacks.extend(get_lr_schedulers(args, datasets, is_trnsf=is_trnsf))

    # NOTE: gradient-norm clipping is not enabled; a previous version had
    # skorch.callbacks.GradientNormClipping(gradient_clip_value=0.1)
    # commented out at this point.

    train_cfg = args.train

    freezer_cfg = train_cfg.freezer
    if freezer_cfg.patterns is not None:
        # With no explicit ``at``, fall back to ``return_True``.
        freeze_at = return_True if freezer_cfg.at is None else freezer_cfg.at
        callbacks.append(Freezer(freezer_cfg.patterns, at=freeze_at))

    unfreezer_cfg = train_cfg.unfreezer
    if unfreezer_cfg.patterns is not None:
        # ``at`` is forwarded as-is (no fallback, unlike the Freezer above).
        callbacks.append(Unfreezer(unfreezer_cfg.patterns, at=unfreezer_cfg.at))

    threshold = train_cfg.ce_threshold
    if threshold is not None:
        callbacks.append(StopAtThreshold(threshold=threshold))

    return callbacks