In mmf/models/unit/unit.py:
def build(self):
    # build the base model (based on DETR)
    self.unit_base_model = UniTBaseModel(self.config.base_args)
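
    # helper: drop every parameter except the visual backbone's from a
    # checkpoint state dict (used with base_ckpt_load_backbone_only)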
    def keep_only_backbone_params(model_state_dict):
        keys = list(model_state_dict.keys())
        for k in keys:
            if "backbone" not in k:
                model_state_dict.pop(k)
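
    # optionally initialize the base model from a pretrained DETR checkpoint,
    # fetched from a URL or loaded from a local path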
    ckpt_path = self.config.base_ckpt_path
    if ckpt_path != "":
        logger.info(f"initializing base model (UniT) from {ckpt_path}")
        if ckpt_path.startswith("https"):
            base_checkpoint = torch.hub.load_state_dict_from_url(
                ckpt_path, check_hash=True
            )
        else:
            base_checkpoint = torch.load(ckpt_path)
        if self.config.base_ckpt_load_backbone_only:
            keep_only_backbone_params(base_checkpoint["model"])
            self.unit_base_model.load_state_dict(
                base_checkpoint["model"], strict=False
            )
        else:
            self.unit_base_model.load_state_dict(
                base_checkpoint["model"], strict=True
            )

    # build the text encoder (BERT)
    self.bert_model = TransformerEncoder(self.config.base_args.bert_config)
    detr_hidden_dim = self.config.base_args.decoder_hidden_dim
    bert_config = deepcopy(self.bert_model.config)
    # project BERT hidden states (and position features) into the decoder's hidden size
    self.bert_projection = nn.Linear(bert_config.hidden_size, detr_hidden_dim)
    self.bert_pos_projection = nn.Linear(bert_config.hidden_size, detr_hidden_dim)
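
    # container for the per-task classifier heads ("vl" and "glue"), filled in below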
    self.classifiers = nn.ModuleDict()

    # optional learned task embedding for the language encoder
    # (nn.Identity() is a no-op placeholder when the feature is disabled)
    self.task_embeddings_lang = nn.Identity()
    if self.config.base_args.use_task_embedding_in_lang_encoder:
        self.task_embeddings_lang = nn.Embedding(
            self.config.max_task_num, bert_config.hidden_size
        )
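
    # heads built below reuse the copied BERT config but operate at the decoder's
    # hidden size, so overwrite hidden_size on the copy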
    bert_config.hidden_size = detr_hidden_dim

    # build the task-specific output heads
    self.class_embeds = nn.ModuleDict()
    self.bbox_embeds = nn.ModuleDict()
    self.det_losses = nn.ModuleDict()
    for dataset_name in self.config.base_args.num_queries.get("detection", []):
        num_cls = self.config.heads["detection"][dataset_name]["num_classes"]
        # DETR-style heads: a classifier with an extra "no object" class and a
        # 3-layer MLP that regresses the 4 box coordinates
        self.class_embeds[dataset_name] = nn.Linear(detr_hidden_dim, num_cls + 1)
        self.bbox_embeds[dataset_name] = MLP(detr_hidden_dim, detr_hidden_dim, 4, 3)
        attr_head = None
        if self.config.heads["detection"][dataset_name]["use_attr"]:
            attr_head = AttributeHead(
                num_cls, self.config.base_args.attribute_class_num, detr_hidden_dim
            )
        self.det_losses[dataset_name] = build_detection_loss(
            self.config.base_args, num_cls, attr_head
        )
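
    # one classifier head per vision-and-language dataset (e.g. VQA)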
    vl_classifiers = nn.ModuleDict()
    for dataset_name in self.config.base_args.num_queries.get("vl", []):
        vl_classifiers[dataset_name] = nn.Sequential(
            BertPredictionHeadTransform(bert_config),
            nn.Linear(
                bert_config.hidden_size,
                self.config.heads["vl"][dataset_name]["num_labels"],
            ),
        )
    self.classifiers["vl"] = vl_classifiers

    self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)
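
    # one classifier head per GLUE (language-only) task, mirroring the vl heads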
    glue_classifiers = nn.ModuleDict()
    for dataset_name in self.config.base_args.num_queries.get("glue", []):
        glue_classifiers[dataset_name] = nn.Sequential(
            BertPredictionHeadTransform(bert_config),
            nn.Linear(
                bert_config.hidden_size,
                self.config.heads["glue"][dataset_name]["num_labels"],
            ),
        )
    self.classifiers["glue"] = glue_classifiers
    self.loss_calculation_fn = {}
    self.loss_calculation_fn["detection"] = self.detection_loss_calculation
    self.loss_calculation_fn["vl"] = self.classifier_loss_calculation
    self.loss_calculation_fn["glue"] = self.classifier_loss_calculation

    self.losses_dict = {}
    self.losses_dict["vl"] = {
        name: self.get_loss_fn(self.config.heads["vl"][name]["loss_type"])
        for name in self.config.heads["vl"]
    }
    self.losses_dict["glue"] = {
        name: self.get_loss_fn(self.config.heads["glue"][name]["loss_type"])
        for name in self.config.heads["glue"]
    }
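
For orientation, here is a minimal sketch of the config structure this build() reads.
The key names mirror the attribute accesses above; MMF configs are OmegaConf objects.
The dataset names, dimensions, and loss type strings are illustrative assumptions,
not values taken from the repo's configs.

from omegaconf import OmegaConf

config = OmegaConf.create(
    {
        "base_ckpt_path": "",  # empty string skips checkpoint loading entirely
        "base_ckpt_load_backbone_only": False,
        "max_task_num": 8,  # illustrative
        "base_args": {
            "decoder_hidden_dim": 256,  # illustrative
            "use_task_embedding_in_lang_encoder": False,
            "attribute_class_num": 400,  # illustrative
            # num_queries maps each task family to {dataset_name: query_count};
            # build() iterates over its keys to create the per-dataset heads
            "num_queries": {
                "detection": {"detection_coco": 100},
                "vl": {"vqa2": 25},
                "glue": {"glue_qnli": 25},
            },
            "bert_config": {"bert_model_name": "bert-base-uncased"},  # assumed fields
        },
        "heads": {
            "detection": {"detection_coco": {"num_classes": 91, "use_attr": False}},
            "vl": {
                "vqa2": {
                    "num_labels": 3129,  # illustrative
                    "loss_type": "binary_cross_entropy_with_logits",
                }
            },
            "glue": {"glue_qnli": {"num_labels": 2, "loss_type": "cross_entropy"}},
        },
    }
)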