in src/controlnet_aux/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py [0:0]
def get_lr_params(self, lr):
    """
    Learning rate configuration for different layers of the model.

    When ``self.train_midas`` is set, the MiDaS core contributes three groups,
    in this order:
      1. parameters of ``self.core.core.pretrained`` whose name does not
         contain "relative_position" (lr / encoder_lr_factor),
      2. parameters of ``self.core.core.pretrained`` whose name contains
         "relative_position" (lr / pos_enc_lr_factor),
      3. parameters of ``self.core.core.scratch`` (lr / midas_lr_factor when
         ``self.is_midas_pretrained``, otherwise the base lr).
    Parameters of every direct child module other than ``core`` always form
    a final group trained at the base ``lr``.

    Each group's ``params`` is a concrete list rather than a one-shot
    generator. ``param_conf`` can therefore be inspected (e.g. to count or
    log parameters) before it is passed to an optimizer without emptying
    the groups.

    Note: each factor divides ``lr``, so a factor of 0 is not supported.

    Args:
        lr (float) : Base learning rate
    Returns:
        list : list of parameters to optimize and their learning rates, in the format required by torch optimizers.
    """
    param_conf = []
    if self.train_midas:
        # Single pass over the backbone. Relative-position parameters get their
        # own learning rate; everything else uses the encoder rate.
        encoder_params, rel_pos_params = [], []
        for name, p in self.core.core.pretrained.named_parameters():
            if "relative_position" in name:
                rel_pos_params.append(p)
            else:
                encoder_params.append(p)
        midas_params = list(self.core.core.scratch.parameters())
        # The reduction factor is only applied when MiDaS is pretrained;
        # otherwise the scratch group uses the base learning rate.
        midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0
        param_conf.extend([
            {'params': encoder_params, 'lr': lr / self.encoder_lr_factor},
            {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor},
            {'params': midas_params, 'lr': lr / midas_lr_factor},
        ])
    # Every direct child except the MiDaS core trains at the base rate.
    remaining_params = [
        p
        for name, child in self.named_children()
        if name != 'core'
        for p in child.parameters()
    ]
    param_conf.append({'params': remaining_params, 'lr': lr})
    return param_conf