in pytext/trainers/trainer.py [0:0]
def __init__(self, config: Config, model: torch.nn.Module):
    """Build the optimizer, scheduler and sparsifier described by ``config``.

    When ``config.freeze_params_pattern`` or ``config.discriminative_lr`` is
    set, the model's parameters are split into explicit optimizer parameter
    groups. Parameters are matched by substring on their
    ``named_parameters()`` name, and each parameter lands in at most one
    group:

    1. names containing any ``freeze_params_pattern`` entry, with lr 0.0;
    2. remaining names containing any ``discriminative_lr_params_pattern``
       entry, with lr ``config.discriminative_lr``;
    3. every other parameter, in a group with no explicit ``"lr"``.
       Presumably this group falls back to the optimizer's configured lr;
       confirm in ``create_optimizer``.

    Otherwise ``create_optimizer`` is called without a groups argument.

    Args:
        config: Trainer configuration. The fields read here are
            ``early_stop_after``, ``do_eval``, ``discriminative_lr``,
            ``discriminative_lr_params_pattern``, ``freeze_params_pattern``,
            ``fp16_args``, ``optimizer``, ``num_accumulated_batches``,
            ``scheduler`` and ``sparsifier``.
        model: Module whose parameters are handed to the optimizer.

    Raises:
        AssertionError: if ``early_stop_after > 0`` while ``do_eval`` is
            false, or if ``discriminative_lr`` is set without
            ``discriminative_lr_params_pattern``. These are ``assert``
            statements, so they are skipped under ``python -O``.
    """
    if config.early_stop_after > 0:
        assert config.do_eval, "can't do early stopping when not running evaluation"

    # None means "let create_optimizer build its default parameter list".
    optimizer_grouped_parameters = None
    if (
        config.discriminative_lr is not None
        or config.freeze_params_pattern is not None
    ):
        # (name patterns, learning rate) for each special group, in priority
        # order: a parameter matched by an earlier spec is not reconsidered.
        group_specs = []
        if config.freeze_params_pattern is not None:
            group_specs.append((config.freeze_params_pattern, 0.0))
        if config.discriminative_lr is not None:
            assert (
                config.discriminative_lr_params_pattern is not None
            ), "Missing discriminative_lr_params_pattern"
            group_specs.append(
                (config.discriminative_lr_params_pattern, config.discriminative_lr)
            )

        optimizer_grouped_parameters = []
        # A set keeps the per-parameter membership tests O(1); the original
        # list made group construction quadratic in the parameter count.
        covered_names = set()
        for patterns, lr in group_specs:
            matched = {
                name: param
                for name, param in model.named_parameters()
                if any(pattern in name for pattern in patterns)
                and name not in covered_names
            }
            if matched:
                covered_names.update(matched)
                optimizer_grouped_parameters.append(
                    {"params": list(matched.values()), "lr": lr}
                )

        # Catch-all group for everything not claimed above; no explicit lr.
        optimizer_grouped_parameters.append(
            {
                "params": [
                    param
                    for name, param in model.named_parameters()
                    if name not in covered_names
                ]
            }
        )

    # Pass the groups only when they were built, so each create_optimizer
    # call has exactly the same positional arguments as before.
    extra_args = (
        () if optimizer_grouped_parameters is None else (optimizer_grouped_parameters,)
    )
    if precision.FP16_ENABLED:
        self.optimizer: torch.optim.Optimizer = create_optimizer(
            config.fp16_args,
            model,
            config.optimizer,
            config.num_accumulated_batches,
            *extra_args,
        )
    else:
        self.optimizer: torch.optim.Optimizer = create_optimizer(
            config.optimizer, model, *extra_args
        )

    # Fall back to the no-op base classes when nothing is configured.
    self.scheduler: torch.optim.lr_scheduler = (
        create_scheduler(config.scheduler, self.optimizer)
        if config.scheduler
        else Scheduler()
    )
    self.sparsifier: Sparsifier = (
        create_sparsifier(config.sparsifier) if config.sparsifier else Sparsifier()
    )
    self.config = config