def __init__()

in src/SSFN/model.py


    def __init__(self, config, args=None):
        super(SequenceAndStructureFusionNetwork, self).__init__(config)
        self.num_labels = config.num_labels
        # sequence encoder, structure encoder, structural embedding encoder
        self.has_seq_encoder = args.has_seq_encoder
        self.has_struct_encoder = args.has_struct_encoder
        self.has_embedding_encoder = args.has_embedding_encoder
        assert args.has_seq_encoder or args.has_struct_encoder or args.has_embedding_encoder
        # includes sequence encoder
        if args.has_seq_encoder:
            # sequence -> transformer(k11 layers) + pooling + dense(k12 layers)
            self.seq_encoder = BertModel(config)
            self.seq_pooler = create_pooler(pooler_type="seq", config=config, args=args)
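            # the pooler is expected to reduce the encoder output [batch_size, seq_len, hidden_size] to [batch_size, hidden_size]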
            assert isinstance(config.seq_fc_size, list)
            self.seq_linear = []
            input_size = config.hidden_size
            for idx in range(len(config.seq_fc_size)):
                linear = nn.Linear(input_size, config.seq_fc_size[idx])
                self.seq_linear.append(linear)
                self.seq_linear.append(create_activate(config.activate_func))
                input_size = config.seq_fc_size[idx]
            self.seq_linear = nn.ModuleList(self.seq_linear)
        # includes structure encoder
        if args.has_struct_encoder:
            # structure -> embedding + GAT(k21 layers) + pooling + dense(k22 layers)
            # output of the GAT stack: [batch_size, seq_len, output_dim]
            self.struct_embedder = nn.Embedding(config.struct_vocab_size, config.struct_embed_size, padding_idx=config.pad_token_id)
            self.struct_encoder = []
            assert isinstance(config.struct_hidden_size, list) and isinstance(config.struct_output_size, list)
            input_size = config.struct_embed_size
            output_size = None
            assert len(config.struct_hidden_size) == len(config.struct_output_size)
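            # the GAT layers are chained: each layer's output_size becomes the next layer's feature_size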
            for idx in range(len(config.struct_output_size)):
                layer = GAT(feature_size=input_size,
                            hidden_size=config.struct_hidden_size[idx],
                            output_size=config.struct_output_size[idx],
                            dropout=config.hidden_dropout_prob,
                            nheads=config.struct_nb_heads,
                            alpha=config.struct_alpha)
                self.struct_encoder.append(layer)
                input_size = config.struct_output_size[idx]
                output_size = config.struct_output_size[idx]
            self.struct_encoder = nn.ModuleList(self.struct_encoder)
            self.struct_pooler = create_pooler(pooler_type="struct", config=config, args=args)
            assert isinstance(config.struct_fc_size, list)
            self.struct_linear = []
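            # the first dense layer is sized as if the pooled outputs of all GAT layers were concatenated
            # (output_size * number of GAT layers); the forward pass is assumed to provide that concatenation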
            input_size = output_size * len(config.struct_output_size)
            for idx in range(len(config.struct_fc_size)):
                linear = nn.Linear(input_size, config.struct_fc_size[idx])
                self.struct_linear.append(linear)
                self.struct_linear.append(create_activate(config.activate_func))
                input_size = config.struct_fc_size[idx]
            self.struct_linear = nn.ModuleList(self.struct_linear)
        # includes embedding encoder
        if args.has_embedding_encoder:
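            # structural embedding -> pooling + dense(k31 layers)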
            self.embedding_pooler = create_pooler(pooler_type="embedding", config=config, args=args)
            assert isinstance(config.embedding_fc_size, list)
            self.embedding_linear = []
            input_size = config.embedding_input_size
            for idx in range(len(config.embedding_fc_size)):
                linear = nn.Linear(input_size, config.embedding_fc_size[idx])
                self.embedding_linear.append(linear)
                self.embedding_linear.append(create_activate(config.activate_func))
                input_size = config.embedding_fc_size[idx]
            self.embedding_linear = nn.ModuleList(self.embedding_linear)

        # weights for fusing the sequence, structure, and structural embedding representation vectors:
        # if no weights are configured, the vectors are concatenated; otherwise they are combined as a weighted sum
        if args.has_seq_encoder and args.has_struct_encoder and args.has_embedding_encoder:
            if hasattr(config, "seq_weight") and hasattr(config, "struct_weight") and hasattr(config, "embedding_weight"):
                self.seq_weight = config.seq_weight
                self.struct_weight = config.struct_weight
                self.embedding_weight = config.embedding_weight
            else:
                self.seq_weight = None
                self.struct_weight = None
                self.embedding_weight = None
            assert self.seq_weight is None or self.seq_weight + self.struct_weight + self.embedding_weight == 1.0
            if self.seq_weight is None: # concat
                output_size = config.seq_fc_size[-1] + config.struct_fc_size[-1] + config.embedding_fc_size[-1]
            else: # add
                assert config.seq_fc_size[-1] == config.struct_fc_size[-1] == config.embedding_fc_size[-1]
                output_size = config.seq_fc_size[-1]
        elif args.has_seq_encoder and args.has_struct_encoder:
            if hasattr(config, "seq_weight") and hasattr(config, "struct_weight"):
                self.seq_weight = config.seq_weight
                self.struct_weight = config.struct_weight
            else:
                self.seq_weight = None
                self.struct_weight = None
            self.embedding_weight = None
            assert self.seq_weight is None or self.seq_weight + self.struct_weight == 1.0
            if self.seq_weight is None: # concat
                output_size = config.seq_fc_size[-1] + config.struct_fc_size[-1]
            else: # add
                assert config.seq_fc_size[-1] == config.struct_fc_size[-1]
                output_size = config.seq_fc_size[-1]
        elif args.has_seq_encoder and args.has_embedding_encoder:
            if hasattr(config, "seq_weight") and hasattr(config, "embedding_weight"):
                self.seq_weight = config.seq_weight
                self.embedding_weight = config.embedding_weight
            else:
                self.seq_weight = None
                self.embedding_weight = None
            self.struct_weight = None
            assert self.seq_weight is None or self.seq_weight + self.embedding_weight == 1.0
            if self.seq_weight is None: # concat
                output_size = config.seq_fc_size[-1] + config.embedding_fc_size[-1]
            else: # add
                assert config.seq_fc_size[-1] == config.embedding_fc_size[-1]
                output_size = config.seq_fc_size[-1]
        elif args.has_struct_encoder and args.has_embedding_encoder:
            if hasattr(config, "struct_weight") and hasattr(config, "embedding_weight"):
                self.struct_weight = config.struct_weight
                self.embedding_weight = config.embedding_weight
            else:
                self.struct_weight = None
                self.embedding_weight = None
            self.seq_weight = None
            assert self.struct_weight is None or self.struct_weight + self.embedding_weight == 1.0
            if self.struct_weight is None: # concat
                output_size = config.struct_fc_size[-1] + config.embedding_fc_size[-1]
            else: # add
                assert config.struct_fc_size[-1] == config.embedding_fc_size[-1]
                output_size = config.struct_fc_size[-1]
        else: # only one encoder
            self.seq_weight = None
            self.struct_weight = None
            self.embedding_weight = None
            if args.has_seq_encoder:
                output_size = config.seq_fc_size[-1]
            elif args.has_struct_encoder:
                output_size = config.struct_fc_size[-1]
            else:
                output_size = config.embedding_fc_size[-1]
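        # output_size is now the dimension of the (fused) representation fed to the classifier head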

        # dropout layer
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # output layer
        self.output_mode = args.output_mode
        if args and args.sigmoid:
            if args.output_mode in ["binary_class", "binary-class"]:
                self.classifier = nn.Linear(output_size, 1)
            else:
                self.classifier = nn.Linear(output_size, config.num_labels)
            self.output = nn.Sigmoid()
        else:
            self.classifier = nn.Linear(output_size, config.num_labels)
            if self.num_labels > 1:
                self.output = nn.Softmax(dim=1)
            else:
                self.output = None
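        # note: the losses below operate on raw logits; self.output (sigmoid/softmax) is presumably applied outside the loss computation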

        # loss function type
        self.loss_type = args.loss_type

        # positive weight
        if hasattr(config, "pos_weight"):
            self.pos_weight = config.pos_weight
        elif hasattr(args, "pos_weight"):
            self.pos_weight = args.pos_weight
        else:
            self.pos_weight = None

        if hasattr(config, "weight"):
            self.weight = config.weight
        elif hasattr(args, "weight"):
            self.weight = args.weight
        else:
            self.weight = None
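        # pos_weight rescales the positive class inside BCEWithLogitsLoss; weight is the per-class weight for CrossEntropyLoss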

        if self.output_mode in ["regression"]:
            self.loss_fct = MSELoss()
        elif self.output_mode in ["multi_label", "multi-label"]:
            if self.loss_type == "bce":
                if self.pos_weight is not None:
                    # a 1-D tensor of per-label positive weights, e.g. [1, 1, ..., 1], length: self.num_labels
                    assert self.pos_weight.ndim == 1 and self.pos_weight.shape[0] == self.num_labels
                    self.loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
                else:
                    self.loss_fct = BCEWithLogitsLoss(reduction=config.loss_reduction if hasattr(config, "loss_reduction") else "sum")
            elif self.loss_type == "asl":
                self.loss_fct = AsymmetricLossOptimized(gamma_neg=args.asl_gamma_neg if hasattr(args, "asl_gamma_neg") else 4,
                                                        gamma_pos=args.asl_gamma_pos if hasattr(args, "asl_gamma_pos") else 1,
                                                        clip=args.clip if hasattr(args, "clip") else 0.05,
                                                        eps=args.eps if hasattr(args, "eps") else 1e-8,
                                                        disable_torch_grad_focal_loss=args.disable_torch_grad_focal_loss if hasattr(args, "disable_torch_grad_focal_loss") else False)
            elif self.loss_type == "focal_loss":
                self.loss_fct = FocalLoss(alpha=args.focal_loss_alpha if hasattr(args, "focal_loss_alpha") else 1,
                                          gamma=args.focal_loss_gamma if hasattr(args, "focal_loss_gamma") else 0.25,
                                          normalization=False,
                                          reduce=args.focal_loss_reduce if hasattr(args, "focal_loss_reduce") else False)
            elif self.loss_type == "multilabel_cce":
                self.loss_fct = MultiLabel_CCE(normalization=False)
        elif self.output_mode in ["binary_class", "binary-class"]:
            if self.loss_type == "bce":
                if self.pos_weight is not None:
                    # a scalar positive-class weight (e.g. 0.9) wrapped into a 1-element float tensor
                    if isinstance(self.pos_weight, (int, float)):
                        self.pos_weight = torch.tensor([self.pos_weight], dtype=torch.float32).to(args.device)

                    assert self.pos_weight.ndim == 1 and self.pos_weight.shape[0] == 1
                    self.loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
                else:
                    self.loss_fct = BCEWithLogitsLoss()
            elif self.loss_type == "focal_loss":
                self.loss_fct = FocalLoss(alpha=args.focal_loss_alpha if hasattr(args, "focal_loss_alpha") else 1,
                                          gamma=args.focal_loss_gamma if hasattr(args, "focal_loss_gamma") else 0.25,
                                          normalization=False,
                                          reduce=args.focal_loss_reduce if hasattr(args, "focal_loss_reduce") else False)
        elif self.output_mode in ["multi_class", "multi-class"]:
            if self.weight is not None:
                # a 1-D tensor of per-class weights, e.g. [1, 1, ..., 1], length: self.num_labels
                assert self.weight.ndim == 1 and self.weight.shape[0] == self.num_labels
                self.loss_fct = CrossEntropyLoss(weight=self.weight)
            else:
                self.loss_fct = CrossEntropyLoss()
        else:
            raise Exception("Unsupported output mode: %s." % self.output_mode)

        self.init_weights()
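
A minimal instantiation sketch for the sequence-only configuration (not part of the repository). It assumes the class is importable as SSFN.model.SequenceAndStructureFusionNetwork and that create_pooler/create_activate accept the fields shown; both are assumptions, since only this constructor is listed here.

    from types import SimpleNamespace

    from transformers import BertConfig

    from SSFN.model import SequenceAndStructureFusionNetwork  # assumed import path

    # config: a BertConfig carrying the extra fields read by __init__ above
    config = BertConfig(num_labels=2)
    config.seq_fc_size = [128]      # dense stack after the sequence pooler
    config.activate_func = "tanh"   # consumed by create_activate (value is an assumption)

    # args: an argparse-style namespace with the flags read by __init__ above;
    # create_pooler may require additional fields that are not visible in this listing
    args = SimpleNamespace(
        has_seq_encoder=True,
        has_struct_encoder=False,
        has_embedding_encoder=False,
        output_mode="binary_class",
        sigmoid=True,               # single-logit classifier + nn.Sigmoid output
        loss_type="bce",            # -> BCEWithLogitsLoss (no pos_weight configured)
        device="cpu",
    )

    model = SequenceAndStructureFusionNetwork(config, args=args)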