in models/base_ssl3d_model.py [0:0]
def _get_trunk(self):
    """Build the paired query/key trunk networks described by ``self.config``.

    For each modality present in the config -- ``'point'`` (``arch_point`` /
    ``args_point``) first, then ``'vox'`` (``arch_vox`` / ``args_vox``) -- two
    instances of the same architecture are constructed from
    ``models.trunks.TRUNKS`` and appended as a consecutive pair
    ``(query, key)``.

    The key network of each pair has its parameters copied from the query
    network. They are then frozen (``requires_grad = False``) so gradients do
    not update them.

    When ``self.logger`` is not ``None``, each network's ``str()`` and its
    ``parameter_description`` are written via ``logger.add_line``.

    Returns:
        torch.nn.ModuleList: ``[point_q, point_k, vox_q, vox_k]``, with the
        pair of any modality absent from the config omitted.

    Raises:
        AssertionError: if a configured architecture name is not a key of
            ``TRUNKS`` (only while assertions are enabled).
        KeyError: if ``arch_<modality>`` is set but ``args_<modality>`` is not.
    """
    import models.trunks as models

    trunks = torch.nn.ModuleList()
    # Construction order is significant (it determines the RNG draws used for
    # weight init): point before vox, query before key within each pair.
    for modality in ('point', 'vox'):
        arch_key = 'arch_' + modality
        if arch_key not in self.config:
            continue
        arch = self.config[arch_key]
        assert arch in models.TRUNKS, 'Unknown model architecture'
        trunk_cls = models.TRUNKS[arch]
        trunk_kwargs = self.config['args_' + modality]
        trunks.append(trunk_cls(**trunk_kwargs))  # query network
        trunks.append(trunk_cls(**trunk_kwargs))  # key network

    # Initialise every key network as an exact copy of its query network and
    # freeze it; it must not be updated by gradient descent.
    nets = list(trunks)
    with torch.no_grad():
        for query_net, key_net in zip(nets[0::2], nets[1::2]):
            for param_q, param_k in zip(query_net.parameters(), key_net.parameters()):
                param_k.copy_(param_q)
                param_k.requires_grad = False

    logger = self.logger
    if logger is not None:
        # ModuleList only holds nn.Module instances, so each entry is logged
        # as a single model.
        for model in trunks:
            logger.add_line("=" * 30 + " Model " + "=" * 30)
            logger.add_line(str(model))
            logger.add_line("=" * 30 + " Parameters " + "=" * 30)
            logger.add_line(parameter_description(model))
    return trunks