in curiosity/baseline_models.py
def __init__(self,
vocab: Vocabulary,
use_glove: bool,
use_bert: bool,
bert_trainable: bool,
bert_name: str,
mention_embedder: TextFieldEmbedder,
dialog_context: FeedForward,
fact_ranker: FactRanker,
dropout_prob: float,
sender_emb_size: int,
act_emb_size: int,
fact_loss_weight: float,
fact_pos_weight: float,
utter_embedder: TextFieldEmbedder = None,
utter_context: Seq2VecEncoder = None,
disable_known_entities: bool = False,
disable_dialog_acts: bool = False,
disable_likes: bool = False,
disable_facts: bool = False):
super().__init__(vocab)
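# Ablation flags: each Clamp gates its component when the corresponding disable_* flag is set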
self._disable_known_entities = disable_known_entities
self._disable_dialog_acts = disable_dialog_acts
self._clamp_dialog_acts = Clamp(should_clamp=disable_dialog_acts)
self._disable_likes = disable_likes
self._clamp_likes = Clamp(should_clamp=disable_likes)
self._disable_facts = disable_facts
self._clamp_facts = Clamp(should_clamp=disable_facts)
self._fact_loss_weight = fact_loss_weight
self._fact_pos_weight = fact_pos_weight
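# Exactly one utterance encoder may be used: GloVe (with the LSTM context below) or BERT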
if int(use_glove) + int(use_bert) != 1:
raise ValueError('Exactly one of use_glove or use_bert must be True')
self._use_glove = use_glove
self._use_bert = use_bert
self._bert_trainable = bert_trainable
self._bert_name = bert_name
self._utter_embedder = utter_embedder
self._utter_context = utter_context
# Bert encoder is embedder + context
if use_bert:
# Whether BERT weights are fine-tuned is controlled by bert_trainable
print('Using BERT encoder ...')
self._bert_encoder = BertEncoder(
self._bert_name, requires_grad=bert_trainable
)
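# BERT contextualizes tokens itself, so no separate per-utterance context encoder is needed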
self._dist_utter_context = None
self._utter_dim = self._bert_encoder.get_output_dim()
else:
print('Using LSTM encoder ...')
self._bert_encoder = None
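# TimeDistributed applies the utterance encoder to each turn of the dialog independently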
self._dist_utter_context = TimeDistributed(self._utter_context)
self._utter_dim = self._utter_context.get_output_dim()
self._dialog_context = dialog_context
self._fact_ranker = fact_ranker
# Like prediction is binary, implemented as two-class cross entropy for convenience
# Likes are predicted per message, for assistant messages only
self._like_classifier = nn.Linear(
self._dialog_context.get_output_dim(), 2
)
self._like_accuracy = CategoricalAccuracy()
self._like_loss_metric = Average()
# Dialog acts are per message, for all messages
# This network predicts the dialog act of the current message
# for both student and teacher
self._da_classifier = nn.Sequential(
nn.Linear(
self._utter_dim + self._dialog_context.get_output_dim(),
self._dialog_context.get_output_dim()
),
GeLU(),
nn.Linear(
self._dialog_context.get_output_dim(),
vocab.get_vocab_size(DIALOG_ACT_LABELS)
)
)
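# Dialog act prediction is multi-label, so use per-label BCE with a micro-averaged F1 metric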
self._da_bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
self._da_f1_metric = MultilabelMicroF1()
self._da_loss_metric = Average()
# This network predicts which dialog acts the next message should take (the policy)
# It predicts for both user and assistant, since there is no strong
# reason to restrict it to one role
self._policy_classifier = nn.Sequential(
nn.Linear(
self._dialog_context.get_output_dim(),
self._dialog_context.get_output_dim()
),
GeLU(),
nn.Linear(
self._dialog_context.get_output_dim(),
vocab.get_vocab_size(DIALOG_ACT_LABELS)
)
)
self._policy_bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
self._policy_f1_metric = MultilabelMicroF1()
self._policy_loss_metric = Average()
self._fact_mrr = MeanReciprocalRank()
self._fact_loss_metric = Average()
self._dropout_prob = dropout_prob
self._dropout = nn.Dropout(dropout_prob)
# Fact use is much less prevalent (positives are roughly 9 times rarer), so weight them accordingly
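# pos_weight > 1 up-weights positive examples in BCEWithLogitsLoss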
self._fact_bce_loss = torch.nn.BCEWithLogitsLoss(
reduction='none',
pos_weight=torch.Tensor([self._fact_pos_weight])
)