def __init__()

in curiosity/baseline_models.py [0:0]


    def __init__(self,
                 vocab: Vocabulary,
                 use_glove: bool,
                 use_bert: bool,
                 bert_trainable: bool,
                 bert_name: str,
                 mention_embedder: TextFieldEmbedder,
                 dialog_context: FeedForward,
                 fact_ranker: FactRanker,
                 dropout_prob: float,
                 sender_emb_size: int,
                 act_emb_size: int,
                 fact_loss_weight: float,
                 fact_pos_weight: float,
                 utter_embedder: TextFieldEmbedder = None,
                 utter_context: Seq2VecEncoder = None,
                 disable_known_entities: bool = False,
                 disable_dialog_acts: bool = False,
                 disable_likes: bool = False,
                 disable_facts: bool = False):
        super().__init__(vocab)
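        # Ablation switches: each Clamp wrapper is active only when the
        # corresponding disable_* flag is set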
        self._disable_known_entities = disable_known_entities
        self._disable_dialog_acts = disable_dialog_acts
        self._clamp_dialog_acts = Clamp(should_clamp=disable_dialog_acts)
        self._disable_likes = disable_likes
        self._clamp_likes = Clamp(should_clamp=disable_likes)
        self._disable_facts = disable_facts
        self._clamp_facts = Clamp(should_clamp=disable_facts)

        self._fact_loss_weight = fact_loss_weight
        self._fact_pos_weight = fact_pos_weight

        if int(use_glove) + int(use_bert) != 1:
            raise ValueError('Must use exactly one of glove or bert')

        self._use_glove = use_glove
        self._use_bert = use_bert
        self._bert_trainable = bert_trainable
        self._bert_name = bert_name
        self._utter_embedder = utter_embedder
        self._utter_context = utter_context
        # Bert encoder is embedder + context
        if use_bert:
            # Trainability is controlled by bert_trainable
            print('Using BERT encoder ...')
            self._bert_encoder = BertEncoder(
                self._bert_name, requires_grad=bert_trainable
            )
            self._dist_utter_context = None
            self._utter_dim = self._bert_encoder.get_output_dim()
        else:
            print('Using LSTM encoder ...')
            self._bert_encoder = None
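            # Apply the utterance encoder to each message in the dialog independently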
            self._dist_utter_context = TimeDistributed(self._utter_context)
            self._utter_dim = self._utter_context.get_output_dim()
        self._dialog_context = dialog_context
        self._fact_ranker = fact_ranker
        # Easier to code as cross entropy with two classes
        # Likes are predicted per message, only for assistant messages
        self._like_classifier = nn.Linear(
            self._dialog_context.get_output_dim(), 2
        )
        self._like_accuracy = CategoricalAccuracy()
        self._like_loss_metric = Average()

        # Dialog acts are per message, for all messages
        # This network predicts the dialog act of the current message
        # for both student and teacher
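        # Input: the current utterance representation concatenated with the dialog context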
        self._da_classifier = nn.Sequential(
            nn.Linear(
                self._utter_dim + self._dialog_context.get_output_dim(),
                self._dialog_context.get_output_dim()
            ),
            GeLU(),
            nn.Linear(
                self._dialog_context.get_output_dim(),
                vocab.get_vocab_size(DIALOG_ACT_LABELS)
            )
        )
        self._da_bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        self._da_f1_metric = MultilabelMicroF1()
        self._da_loss_metric = Average()

        # This network predicts what the next dialog act should be
        # It predicts for both user and assistant since there is no real
        # reason to restrict it
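        # Unlike the dialog act classifier, this conditions on the dialog context alone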
        self._policy_classifier = nn.Sequential(
            nn.Linear(
                self._dialog_context.get_output_dim(),
                self._dialog_context.get_output_dim()
            ),
            GeLU(),
            nn.Linear(
                self._dialog_context.get_output_dim(),
                vocab.get_vocab_size(DIALOG_ACT_LABELS)
            )
        )
        self._policy_bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        self._policy_f1_metric = MultilabelMicroF1()
        self._policy_loss_metric = Average()

        self._fact_mrr = MeanReciprocalRank()
        self._fact_loss_metric = Average()
        self._dropout_prob = dropout_prob
        self._dropout = nn.Dropout(dropout_prob)
        # Fact use is much less prevalent (roughly 9 times rarer), so weight positive examples accordingly
        self._fact_bce_loss = torch.nn.BCEWithLogitsLoss(
            reduction='none',
            pos_weight=torch.Tensor([self._fact_pos_weight])
        )
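
For reference, here is a minimal sketch (not from the repository) of how the
pos_weight argument of BCEWithLogitsLoss upweights positive fact labels; the
logits, labels, and the weight of 9.0 below are purely illustrative.

    import torch

    logits = torch.tensor([0.2, -1.5, 0.7])
    labels = torch.tensor([1.0, 0.0, 0.0])

    plain = torch.nn.BCEWithLogitsLoss(reduction='none')
    weighted = torch.nn.BCEWithLogitsLoss(
        reduction='none', pos_weight=torch.tensor([9.0])
    )

    # The positive (label == 1) term is scaled by pos_weight, so the rare
    # positive fact labels contribute roughly 9x more to the loss.
    print(plain(logits, labels))
    print(weighted(logits, labels))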