Dassl.pytorch/dassl/engine/ssl/fixmatch.py (83 lines of code) (raw):

import torch
from torch.nn import functional as F

from dassl.data import DataManager
from dassl.engine import TRAINER_REGISTRY, TrainerXU
from dassl.metrics import compute_accuracy
from dassl.data.transforms import build_transform


@TRAINER_REGISTRY.register()
class FixMatch(TrainerXU):
    """FixMatch: Simplifying Semi-Supervised Learning with
    Consistency and Confidence.

    https://arxiv.org/abs/2001.07685.

    Each training step combines a supervised cross-entropy loss on the
    labeled batch with a cross-entropy loss against confidence-thresholded
    pseudo labels, weighted by ``cfg.TRAINER.FIXMATCH.WEIGHT_U``.
    """

    def __init__(self, cfg):
        """Initialize the base trainer and cache FixMatch hyper-parameters.

        Args:
            cfg: config node providing ``TRAINER.FIXMATCH.WEIGHT_U`` (weight
                of the unsupervised loss) and ``TRAINER.FIXMATCH.CONF_THRE``
                (minimum max-probability for a pseudo label to be kept).
        """
        super().__init__(cfg)
        self.weight_u = cfg.TRAINER.FIXMATCH.WEIGHT_U
        self.conf_thre = cfg.TRAINER.FIXMATCH.CONF_THRE

    def check_cfg(self, cfg):
        """Check that at least one strong transform is configured.

        Raises:
            AssertionError: if ``cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS``
                is empty.
        """
        assert len(cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS) > 0, (
            "cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS must not be empty"
        )

    def build_data_loader(self):
        """Create the data manager with two training transforms.

        The first transform is the default training transform; the second
        is built from ``cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS``. Both are
        handed to ``DataManager`` via ``custom_tfm_train``, and the
        resulting loaders and class count are stored on ``self``.
        """
        cfg = self.cfg

        tfm_train = build_transform(cfg, is_train=True)
        choices = cfg.TRAINER.FIXMATCH.STRONG_TRANSFORMS
        tfm_train_strong = build_transform(cfg, is_train=True, choices=choices)
        custom_tfm_train = [tfm_train, tfm_train_strong]

        self.dm = DataManager(self.cfg, custom_tfm_train=custom_tfm_train)
        self.train_loader_x = self.dm.train_loader_x
        self.train_loader_u = self.dm.train_loader_u
        self.val_loader = self.dm.val_loader
        self.test_loader = self.dm.test_loader
        self.num_classes = self.dm.num_classes

    def assess_y_pred_quality(self, y_pred, y_true, mask):
        """Compute quality statistics of pseudo labels.

        Args:
            y_pred: tensor of predicted (pseudo) labels.
            y_true: tensor of reference labels, compared element-wise with
                ``y_pred``.
            mask: float tensor that is multiplied element-wise with the
                correctness vector (callers in this file pass 0/1 values).

        Returns:
            dict with tensor values:
                ``acc_thre``: masked-correct count divided by the mask sum.
                ``acc_raw``: fraction of all predictions that are correct.
                ``keep_rate``: mask sum divided by the number of mask
                    elements.
        """
        # Cast to float once. Dividing the integer result of ``.sum()`` on a
        # bool/long tensor by an int is version dependent in PyTorch (floor
        # division or an error on older releases), so work in float.
        correct = y_pred.eq(y_true).float()

        n_masked_correct = (correct * mask).sum()
        # The epsilon keeps the ratio finite when the mask sums to zero.
        acc_thre = n_masked_correct / (mask.sum() + 1e-5)
        acc_raw = correct.sum() / y_pred.numel()  # raw accuracy
        keep_rate = mask.sum() / mask.numel()

        output = {
            "acc_thre": acc_thre,
            "acc_raw": acc_raw,
            "keep_rate": keep_rate
        }
        return output

    def forward_backward(self, batch_x, batch_u):
        """Run one FixMatch training step and return a summary of it.

        Args:
            batch_x: labeled batch, a mapping with keys ``"img"``,
                ``"img2"`` and ``"label"``.
            batch_u: unlabeled batch with the same keys; its labels are
                used only to evaluate pseudo-label accuracy.

        Returns:
            dict with the supervised loss and accuracy, the unsupervised
            loss, and the pseudo-label statistics produced by
            ``assess_y_pred_quality``.
        """
        parsed_data = self.parse_batch_train(batch_x, batch_u)
        input_x, input_x2, label_x, input_u, input_u2, label_u = parsed_data

        # Labeled images are prepended to the unlabeled ones, so pseudo
        # labels are generated (and the unsupervised loss is computed) for
        # both sets. The first n_x rows belong to the labeled batch.
        input_u = torch.cat([input_x, input_u], 0)
        input_u2 = torch.cat([input_x2, input_u2], 0)
        n_x = input_x.size(0)

        # Generate pseudo labels
        with torch.no_grad():
            output_u = F.softmax(self.model(input_u), 1)
            max_prob, label_u_pred = output_u.max(1)
            # 1.0 where the prediction is confident enough, else 0.0
            mask_u = (max_prob >= self.conf_thre).float()

            # Evaluate pseudo labels' accuracy
            # (only on the rows that come from the unlabeled batch)
            y_u_pred_stats = self.assess_y_pred_quality(
                label_u_pred[n_x:], label_u, mask_u[n_x:]
            )

        # Supervised loss
        output_x = self.model(input_x)
        loss_x = F.cross_entropy(output_x, label_x)

        # Unsupervised loss
        output_u = self.model(input_u2)
        loss_u = F.cross_entropy(output_u, label_u_pred, reduction="none")
        # Masked-out samples contribute zero but still count in the mean.
        loss_u = (loss_u * mask_u).mean()

        loss = loss_x + loss_u * self.weight_u
        self.model_backward_and_update(loss)

        loss_summary = {
            "loss_x": loss_x.item(),
            "acc_x": compute_accuracy(output_x, label_x)[0].item(),
            "loss_u": loss_u.item(),
            "y_u_pred_acc_raw": y_u_pred_stats["acc_raw"],
            "y_u_pred_acc_thre": y_u_pred_stats["acc_thre"],
            "y_u_pred_keep": y_u_pred_stats["keep_rate"],
        }

        # Step the learning rate once the last batch has been processed.
        if (self.batch_idx + 1) == self.num_batches:
            self.update_lr()

        return loss_summary

    def parse_batch_train(self, batch_x, batch_u):
        """Extract tensors from both batches and move them to the device.

        Args:
            batch_x: labeled batch with keys ``"img"``, ``"img2"`` and
                ``"label"``.
            batch_u: unlabeled batch with the same keys.

        Returns:
            tuple ``(input_x, input_x2, label_x, input_u, input_u2,
            label_u)``, every element moved to ``self.device``.
        """
        # NOTE(review): "img"/"img2" presumably correspond to the two
        # transforms set up in build_data_loader -- confirm in DataManager.
        input_x = batch_x["img"]
        input_x2 = batch_x["img2"]
        label_x = batch_x["label"]
        input_u = batch_u["img"]
        input_u2 = batch_u["img2"]
        # label_u is used only for evaluating pseudo labels' accuracy
        label_u = batch_u["label"]

        input_x = input_x.to(self.device)
        input_x2 = input_x2.to(self.device)
        label_x = label_x.to(self.device)
        input_u = input_u.to(self.device)
        input_u2 = input_u2.to(self.device)
        label_u = label_u.to(self.device)

        return input_x, input_x2, label_x, input_u, input_u2, label_u