def __init__()

in aiops/RCRank/model/modules/FuseModel/Attention.py [0:0]


    def __init__(self, head_count, model_dim, dropout=0.1, use_metrics=True, use_log=True):
        """Build the per-stream projections, dropouts and fusion layers.

        Args:
            head_count: Number of attention heads. Must be positive and divide
                ``model_dim`` evenly.
            model_dim: Input feature width; every key/value/query projection
                maps ``model_dim -> model_dim``.
            dropout: Drop probability used by the four per-stream dropout
                layers (``dropout_sql``, ``dropout_plan``, ``dropout_log``,
                ``dropout_metrics``).
            use_metrics: When falsy, the input width of ``final_linear`` is
                reduced by ``model_dim``.
            use_log: When falsy, the input width of ``final_linear`` is
                reduced by ``model_dim``.

        Raises:
            ValueError: If ``head_count`` is not positive or does not divide
                ``model_dim``.
        """
        # nn.Module.__init__ must run before any attribute is assigned on the module.
        super().__init__()

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if head_count <= 0 or model_dim % head_count != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by a positive "
                f"head_count ({head_count})"
            )

        self.use_metrics = use_metrics
        self.use_log = use_log
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        self.head_count = head_count

        # Equals model_dim after the divisibility check above.
        proj_dim = head_count * self.dim_per_head

        # NOTE: submodule assignment order determines the order of
        # `parameters()` (and so optimizer state) and the RNG draws made by
        # weight init; keep this order stable.
        self.linear_keys = nn.Linear(model_dim, proj_dim)
        self.linear_values = nn.Linear(model_dim, proj_dim)
        self.linear_query = nn.Linear(model_dim, proj_dim)

        self.linear_plan_keys = nn.Linear(model_dim, proj_dim)
        self.linear_plan_values = nn.Linear(model_dim, proj_dim)

        # NOTE(review): the log/metrics projections are created even when
        # use_log / use_metrics is off, presumably so the state_dict keys do
        # not depend on the flags — confirm before making them conditional.
        self.linear_log_keys = nn.Linear(model_dim, proj_dim)
        self.linear_log_values = nn.Linear(model_dim, proj_dim)

        self.linear_metrics_keys = nn.Linear(model_dim, proj_dim)
        self.linear_metrics_values = nn.Linear(model_dim, proj_dim)

        self.softmax = nn.Softmax(dim=-1)
        self.dropout_sql = nn.Dropout(dropout)
        self.dropout_plan = nn.Dropout(dropout)
        self.dropout_log = nn.Dropout(dropout)
        self.dropout_metrics = nn.Dropout(dropout)

        # Up to four model_dim-wide chunks feed final_linear; each disabled
        # optional stream (log, metrics) removes one chunk.
        model_num = 4
        if not self.use_metrics:
            model_num -= 1
        if not self.use_log:
            model_num -= 1
        self.final_linear = nn.Linear(model_dim * model_num, model_dim)

        # model_dim -> model_dim -> model_dim // 2, with SSP between the layers.
        self.edge_project = nn.Sequential(nn.Linear(model_dim, model_dim),
                                          SSP(),
                                          nn.Linear(model_dim, model_dim // 2))
        # 2 * model_dim -> model_dim -> model_dim, with SSP between the layers.
        self.edge_update = nn.Sequential(nn.Linear(model_dim * 2, model_dim),
                                         SSP(),
                                         nn.Linear(model_dim, model_dim))