in ubteacher/engine/trainer.py [0:0]
def run_step_full_semisup(self):
    """Run one training iteration of the semi-supervised (teacher/student) schedule.

    Pulls a single batch from ``self._trainer._data_loader_iter``. The batch
    unpacks into four lists: labeled and unlabeled data, each in two
    augmentations. Per the original comments, ``*_q`` is the strong view and
    ``*_k`` the weak view. Labels are stripped from both unlabeled lists via
    ``self.remove_label``.

    Burn-in (``self.iter < SEMISUPNET.BURN_UP_STEP``):
        The student ``self.model`` is trained on the labeled q and k views
        together. Every ``loss*`` entry it returns has weight 1.

    Semi-supervised (afterwards):
        1. The teacher is refreshed via ``self._update_teacher_model``. It uses
           ``keep_rate=0.0`` exactly at ``BURN_UP_STEP``, and
           ``keep_rate=EMA_KEEP_RATE`` every ``TEACHER_UPDATE_ITER`` iterations
           after that.
        2. ``self.model_teacher`` predicts on the weak unlabeled view under
           ``torch.no_grad()``.
        3. Its ROI-head proposals are thresholded at ``BBOX_THRESHOLD`` into
           pseudo-labels.
        4. The student is trained on the labeled data, and separately on the
           pseudo-labeled strong unlabeled view.
        5. Losses from the unlabeled pass get a ``_pseudo`` suffix and are
           weighted by ``UNSUP_LOSS_WEIGHT``. The exceptions are
           ``loss_rpn_loc_pseudo`` and ``loss_box_reg_pseudo``, which are
           weighted by 0.

    In both stages, the record dict returned by the student (plus a
    ``data_time`` entry) is passed to ``self._write_metrics``. Then the
    optimizer performs ``zero_grad`` / ``backward`` / ``step`` on the sum of the
    weighted losses.
    """
    self._trainer.iter = self.iter
    assert self.model.training, "[UBTeacherTrainer] model was changed to eval mode!"
    semi_cfg = self.cfg.SEMISUPNET

    start = time.perf_counter()
    data = next(self._trainer._data_loader_iter)
    # data_q and data_k come from different augmentations (q: strong, k: weak):
    # label_strong, label_weak, unlabeled_strong, unlabeled_weak
    label_data_q, label_data_k, unlabel_data_q, unlabel_data_k = data
    data_time = time.perf_counter() - start

    # Remove the labels carried by the unlabeled data.
    unlabel_data_q = self.remove_label(unlabel_data_q)
    unlabel_data_k = self.remove_label(unlabel_data_k)

    if self.iter < semi_cfg.BURN_UP_STEP:
        # Burn-in stage: supervised training on both views of the labeled data.
        # Build a new list rather than extend() label_data_q in place, so the
        # batch produced by the data loader is left unmodified.
        record_dict, _, _, _ = self.model(
            label_data_q + label_data_k, branch="supervised"
        )
        loss_dict = {
            key: value * 1
            for key, value in record_dict.items()
            if key.startswith("loss")
        }
    else:
        if self.iter == semi_cfg.BURN_UP_STEP:
            # First semi-supervised step: copy the whole model into the teacher.
            self._update_teacher_model(keep_rate=0.00)
        elif (self.iter - semi_cfg.BURN_UP_STEP) % semi_cfg.TEACHER_UPDATE_ITER == 0:
            self._update_teacher_model(keep_rate=semi_cfg.EMA_KEEP_RATE)

        # Generate pseudo-labels with the teacher. It is deliberately not
        # switched to eval mode, for two reasons:
        # 1) no gradient is computed for the teacher;
        # 2) per the original authors, its batch-norm layers are not updated
        #    either.
        with torch.no_grad():
            _, proposals_rpn_unsup_k, proposals_roih_unsup_k, _ = self.model_teacher(
                unlabel_data_k, branch="unsup_data_weak"
            )

        cur_threshold = semi_cfg.BBOX_THRESHOLD
        # RPN-level pseudo-labels are not consumed by this step.
        # NOTE(review): the call is kept in case process_pseudo_label has side
        # effects or is overridden; drop it once it is confirmed to be pure.
        self.process_pseudo_label(
            proposals_rpn_unsup_k, cur_threshold, "rpn", "thresholding"
        )
        # ROI-head pseudo-labels (box location / objectness) supervise the student.
        pseudo_proposals_roih_unsup_k, _ = self.process_pseudo_label(
            proposals_roih_unsup_k, cur_threshold, "roih", "thresholding"
        )

        # Attach the pseudo-labels to both unlabeled views. Only the strong
        # view's return value is used below; the weak-view call is kept as in
        # the original.
        unlabel_data_q = self.add_label(unlabel_data_q, pseudo_proposals_roih_unsup_k)
        self.add_label(unlabel_data_k, pseudo_proposals_roih_unsup_k)

        record_dict = {}
        record_all_label_data, _, _, _ = self.model(
            label_data_q + label_data_k, branch="supervised"
        )
        record_dict.update(record_all_label_data)
        record_all_unlabel_data, _, _, _ = self.model(
            unlabel_data_q, branch="supervised"
        )
        # Suffix unlabeled-pass entries so they do not overwrite supervised ones.
        record_dict.update(
            {key + "_pseudo": value for key, value in record_all_unlabel_data.items()}
        )

        # Pseudo box-regression losses are multiplied by 0 instead of being
        # dropped (presumably to keep them in the autograd graph — confirm).
        zeroed_keys = ("loss_rpn_loc_pseudo", "loss_box_reg_pseudo")
        loss_dict = {}
        for key, value in record_dict.items():
            if not key.startswith("loss"):
                continue
            if key in zeroed_keys:
                weight = 0
            elif key.endswith("pseudo"):  # unsupervised loss
                weight = semi_cfg.UNSUP_LOSS_WEIGHT
            else:  # supervised loss
                weight = 1
            loss_dict[key] = value * weight

    losses = sum(loss_dict.values())

    # Metrics are every record entry (not only the losses) plus data_time.
    record_dict["data_time"] = data_time
    self._write_metrics(record_dict)

    self.optimizer.zero_grad()
    losses.backward()
    self.optimizer.step()