in main.py [0:0]
def get_callbacks(args, datasets, is_trnsf):
    """Assemble the list of training callbacks for the net.

    Parameters
    ----------
    args :
        Experiment configuration. ``args.train.freezer``,
        ``args.train.unfreezer`` and ``args.train.ce_threshold`` are read
        here; ``args`` is also forwarded unchanged to ``get_lr_schedulers``.
    datasets :
        Mapping of split name to dataset. Only ``datasets["train"]`` is
        read, for its ``map_target_position`` attribute.
    is_trnsf :
        Truthy when the net is a (sklearn-style) transformer rather than a
        classifier; validation accuracy and log-likelihood scorers are then
        registered explicitly.

    Returns
    -------
    list
        Callbacks in registration order: ``(name, EpochScoring)`` pairs for
        the epoch scorers, then the items from ``get_lr_schedulers``, then
        the optional ``Freezer``, ``Unfreezer`` and ``StopAtThreshold``
        callbacks (added without a name).
    """

    def named_scorer(name, scoring, **extra):
        # Every scorer registered here is a "higher is better" metric that is
        # logged under the same name it is registered with.
        scorer = EpochScoring(
            scoring,
            name=name,
            lower_is_better=False,
            **extra,
        )
        return (name, scorer)

    callbacks = []

    if is_trnsf:
        # The "accuracy" scoring string cannot be used because the net is a
        # transformer rather than a classifier, so pass the function itself.
        callbacks.append(
            named_scorer("valid_acc", accuracy, target_extractor=target_extractor)
        )
        # Plain log-likelihood: the actual loss also contains all
        # regularization terms.
        callbacks.append(
            named_scorer("valid_loglike", loglike, target_extractor=target_extractor)
        )

    train_scoring = partial(
        accuracy_filter_train,
        map_target_position=datasets["train"].map_target_position,
    )
    callbacks.append(
        named_scorer(
            "train_acc",
            train_scoring,
            on_train=True,
            target_extractor=partial(target_extractor, is_multi_target=True),
        )
    )

    callbacks.extend(get_lr_schedulers(args, datasets, is_trnsf=is_trnsf))
    # Gradient clipping is disabled; to enable it, add
    # skorch.callbacks.GradientNormClipping(gradient_clip_value=0.1) here.

    train_cfg = args.train

    freezer_cfg = train_cfg.freezer
    if freezer_cfg.patterns is not None:
        freeze_at = freezer_cfg.at
        if freeze_at is None:
            # No `at` configured: pass `return_True` instead of None.
            freeze_at = return_True
        callbacks.append(Freezer(freezer_cfg.patterns, at=freeze_at))

    unfreezer_cfg = train_cfg.unfreezer
    if unfreezer_cfg.patterns is not None:
        # NOTE(review): unlike the freezer, `at` is forwarded even when None.
        callbacks.append(Unfreezer(unfreezer_cfg.patterns, at=unfreezer_cfg.at))

    ce_threshold = train_cfg.ce_threshold
    if ce_threshold is not None:
        callbacks.append(StopAtThreshold(threshold=ce_threshold))

    return callbacks