def _get_trunk()

in models/base_ssl3d_model.py [0:0]


    def _get_trunk(self):
        """Build the backbone trunks described by ``self.config``.

        For each modality present in the config (``'arch_point'`` and/or
        ``'arch_vox'``), two trunks of the same architecture are built from
        ``models.trunks.TRUNKS`` with the matching ``'args_point'`` /
        ``'args_vox'`` keyword arguments. Within each pair, the first trunk
        (index ``2k``) is the trainable "query" network. The second (index
        ``2k + 1``) is a "key" copy: its parameters are initialised from the
        query trunk and frozen (``requires_grad = False``). Presumably the key
        trunk is updated elsewhere (e.g. by a momentum rule); confirm against
        the training loop.

        When ``self.logger`` is not None, each trunk's ``str()`` and its
        ``parameter_description`` are written with ``logger.add_line``.

        Returns:
            torch.nn.ModuleList: ``[point_q, point_k, vox_q, vox_k]``. The
            point pair comes first, and each pair is present only if its
            ``'arch_*'`` key is configured. The list is empty if neither is.

        Raises:
            ValueError: if a configured architecture name is not a key of
                ``models.trunks.TRUNKS``.
            KeyError: if ``'arch_<modality>'`` is configured but the matching
                ``'args_<modality>'`` entry is missing.
        """
        import models.trunks as models

        trunks = torch.nn.ModuleList()
        # Keep the original ordering (point pair before vox pair); callers may
        # index the returned list positionally.
        for modality in ('point', 'vox'):
            arch_key = 'arch_' + modality
            if arch_key not in self.config:
                continue
            arch = self.config[arch_key]
            # Explicit check rather than ``assert`` so validation survives
            # ``python -O``; the message names the offending architecture.
            if arch not in models.TRUNKS:
                raise ValueError('Unknown model architecture: {!r}'.format(arch))
            trunk_cls = models.TRUNKS[arch]
            trunk_args = self.config['args_' + modality]

            query_trunk = trunk_cls(**trunk_args)
            key_trunk = trunk_cls(**trunk_args)
            # Start the key trunk from the query trunk's weights and exclude
            # it from gradient updates.
            for param_q, param_k in zip(query_trunk.parameters(), key_trunk.parameters()):
                param_k.data.copy_(param_q.data)
                param_k.requires_grad = False

            trunks.append(query_trunk)
            trunks.append(key_trunk)

        logger = self.logger
        if logger is not None:
            # A ModuleList only holds nn.Module entries, so each one is logged
            # directly.
            for model in trunks:
                logger.add_line("=" * 30 + "   Model   " + "=" * 30)
                logger.add_line(str(model))
                logger.add_line("=" * 30 + "   Parameters   " + "=" * 30)
                logger.add_line(parameter_description(model))
        return trunks