def build()

in mmf/models/unit/unit.py [0:0]


    def build(self):
        """Construct every sub-module of the UniT model.

        Builds, in this order: the DETR-based ``UniTBaseModel`` (optionally
        initialised from ``config.base_ckpt_path``), the BERT text encoder and
        its two projections to ``decoder_hidden_dim``, an optional task
        embedding for the language encoder, the per-dataset detection heads
        and losses, the per-dataset ``vl`` / ``glue`` classifier heads, and the
        per-task loss lookup tables.

        NOTE: sub-modules are created in a fixed order on purpose. Reordering
        them changes parameter registration order (which e.g. saved optimizer
        state indices depend on) and the RNG stream consumed by weight init.
        """
        # build the base model (based on DETR)
        self.unit_base_model = UniTBaseModel(self.config.base_args)

        def keep_only_backbone_params(model_state_dict):
            # Filter in place rather than building a new dict, so attributes
            # attached to the loaded mapping (e.g. torch's `_metadata`, read by
            # `load_state_dict`) are preserved.
            for k in list(model_state_dict.keys()):
                if "backbone" not in k:
                    model_state_dict.pop(k)

        ckpt_path = self.config.base_ckpt_path
        # Truthiness also treats an unset (None) path as "no checkpoint"
        # instead of crashing on `None.startswith` below.
        if ckpt_path:
            logger.info(f"initializing base model (UniT) from {ckpt_path}")
            # Load onto CPU: `load_state_dict` copies the values into the
            # existing parameters anyway, and without `map_location` tensors are
            # restored onto the device they were saved from (fails on CPU-only
            # hosts; piles every process's copy onto that one GPU).
            # NOTE(review): both loaders unpickle the file -- only point
            # `base_ckpt_path` at trusted checkpoints.
            if ckpt_path.startswith(("http://", "https://")):
                base_checkpoint = torch.hub.load_state_dict_from_url(
                    ckpt_path, map_location="cpu", check_hash=True
                )
            else:
                base_checkpoint = torch.load(ckpt_path, map_location="cpu")

            load_backbone_only = self.config.base_ckpt_load_backbone_only
            if load_backbone_only:
                keep_only_backbone_params(base_checkpoint["model"])
            # Non-backbone keys were dropped above, so strict key matching is
            # only possible when the full model state is loaded.
            self.unit_base_model.load_state_dict(
                base_checkpoint["model"], strict=not load_backbone_only
            )

        # build the text encoder (BERT)
        self.bert_model = TransformerEncoder(self.config.base_args.bert_config)
        detr_hidden_dim = self.config.base_args.decoder_hidden_dim
        # Work on a copy: its `hidden_size` is overridden further down and the
        # encoder's own config must stay untouched.
        bert_config = deepcopy(self.bert_model.config)
        self.bert_projection = nn.Linear(bert_config.hidden_size, detr_hidden_dim)
        self.bert_pos_projection = nn.Linear(bert_config.hidden_size, detr_hidden_dim)

        self.classifiers = nn.ModuleDict()

        self.task_embeddings_lang = nn.Identity()
        if self.config.base_args.use_task_embedding_in_lang_encoder:
            # Sized with the encoder's own width, so this must run before
            # `bert_config.hidden_size` is overridden below.
            self.task_embeddings_lang = nn.Embedding(
                self.config.max_task_num, bert_config.hidden_size
            )

        # From here on the classifier heads are built from `bert_config`, so
        # they take `decoder_hidden_dim`-wide inputs instead of the BERT width.
        bert_config.hidden_size = detr_hidden_dim

        # build the task-specific output heads
        self.class_embeds = nn.ModuleDict()
        self.bbox_embeds = nn.ModuleDict()
        self.det_losses = nn.ModuleDict()
        for dataset_name in self.config.base_args.num_queries.get("detection", []):
            det_head_cfg = self.config.heads["detection"][dataset_name]
            num_cls = det_head_cfg["num_classes"]
            # One extra logit -- presumably DETR's "no object" class; confirm
            # against the detection loss.
            self.class_embeds[dataset_name] = nn.Linear(detr_hidden_dim, num_cls + 1)
            self.bbox_embeds[dataset_name] = MLP(detr_hidden_dim, detr_hidden_dim, 4, 3)
            attr_head = None
            if det_head_cfg["use_attr"]:
                attr_head = AttributeHead(
                    num_cls, self.config.base_args.attribute_class_num, detr_hidden_dim
                )
            self.det_losses[dataset_name] = build_detection_loss(
                self.config.base_args, num_cls, attr_head
            )

        # vl and glue share one head architecture. Keep "vl" first so module
        # construction (and therefore init / registration order) is unchanged.
        for task in ("vl", "glue"):
            task_classifiers = nn.ModuleDict()
            for dataset_name in self.config.base_args.num_queries.get(task, []):
                task_classifiers[dataset_name] = nn.Sequential(
                    BertPredictionHeadTransform(bert_config),
                    nn.Linear(
                        bert_config.hidden_size,
                        self.config.heads[task][dataset_name]["num_labels"],
                    ),
                )
            self.classifiers[task] = task_classifiers

        self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)

        self.loss_calculation_fn = {
            "detection": self.detection_loss_calculation,
            "vl": self.classifier_loss_calculation,
            "glue": self.classifier_loss_calculation,
        }

        # A task with no `heads` entry gets an empty loss table instead of
        # raising, mirroring the `num_queries.get(...)` handling above.
        self.losses_dict = {}
        for task in ("vl", "glue"):
            task_heads = self.config.heads.get(task) or {}
            self.losses_dict[task] = {
                name: self.get_loss_fn(task_heads[name]["loss_type"])
                for name in task_heads
            }