# Excerpt: forward() from aiops/RCRank/model/modules/FuseModel/Attention.py

    def forward(self, sql, plan, log, metrics, sql_mask, plan_mask, mask=None, additional_mask=None, layer_cache=None, type=None, edge_feature=None, pair_indices=None):
        """Fuse SQL, plan, and optionally metrics/log embeddings via multi-head attention.

        The SQL embedding serves as the shared query against each modality's
        keys/values. The SQL and plan branches use softmax attention; the
        metrics and log branches use element-wise sigmoid gating instead of a
        softmax (no normalization across keys). The per-modality contexts are
        concatenated along the feature dimension and projected by
        ``self.final_linear``.

        Args:
            sql: SQL embedding ``(batch, sql_len, model_dim)``; used both as the
                query and as key/value of the first branch.
            plan: plan embedding ``(batch, plan_len, model_dim)``.
            log: log embedding; only read when ``self.use_log`` is true
                (unsqueezed to a length-1 sequence).
            metrics: metrics embedding; only read when ``self.use_metrics`` is
                true (unsqueezed to a length-1 sequence).
            sql_mask, plan_mask, mask, additional_mask, layer_cache, type,
            edge_feature, pair_indices: accepted for interface compatibility
                but unused by this implementation (no masking is applied).

        Returns:
            Tuple ``(output, top_score)`` where ``output`` is the fused
            projection and ``top_score`` is the head-0 SQL attention score
            matrix ``(batch, sql_query_len, sql_key_len)``.
        """
        batch_size = sql.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            # (batch, len, model_dim) -> (batch, heads, len, dim_per_head)
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            # inverse of shape(): back to (batch, len, model_dim)
            return x.transpose(1, 2).contiguous() \
                .view(batch_size, -1, head_count * dim_per_head)

        # Project and scale the shared query ONCE. The original code recomputed
        # query_shaped / sqrt(dim_per_head) separately in every branch.
        query_shaped = shape(self.linear_query(sql))
        scaled_query = query_shaped / math.sqrt(dim_per_head)

        # --- SQL branch: softmax self-attention over the SQL sequence ---
        sql_key_shaped = shape(self.linear_keys(sql))
        sql_value_shaped = shape(self.linear_values(sql))
        scores = torch.matmul(scaled_query, sql_key_shaped.transpose(2, 3))
        # Head-0 scores are returned for inspection. scores already has shape
        # (batch, heads, query_len, sql_key_len), so the original no-op view
        # was dropped.
        top_score = scores[:, 0, :, :].contiguous()
        attn = self.softmax(scores)
        drop_attn = self.dropout_sql(attn)
        sql_context = unshape(torch.matmul(drop_attn, sql_value_shaped))

        # --- plan branch: softmax cross-attention (SQL query over plan keys) ---
        plan_key_shaped = shape(self.linear_plan_keys(plan))
        plan_value_shaped = shape(self.linear_plan_values(plan))
        scores = torch.matmul(scaled_query, plan_key_shaped.transpose(2, 3))
        attn = self.softmax(scores)
        drop_attn = self.dropout_plan(attn)
        plan_context = unshape(torch.matmul(drop_attn, plan_value_shaped))

        contexts = [sql_context, plan_context]

        # --- metrics branch: sigmoid gating (no softmax normalization) ---
        if self.use_metrics:
            metrics = metrics.unsqueeze(1)  # treat metrics vector as a length-1 sequence
            metrics_key_shaped = shape(self.linear_metrics_keys(metrics))
            metrics_value_shaped = shape(self.linear_metrics_values(metrics))
            scores = torch.matmul(scaled_query, metrics_key_shaped.transpose(2, 3))
            attn = torch.sigmoid(scores)
            drop_attn = self.dropout_metrics(attn)
            contexts.append(unshape(torch.matmul(drop_attn, metrics_value_shaped)))

        # --- log branch: same sigmoid gating as the metrics branch ---
        if self.use_log:
            log = log.unsqueeze(1)  # treat log vector as a length-1 sequence
            log_key_shaped = shape(self.linear_log_keys(log))
            log_value_shaped = shape(self.linear_log_values(log))
            scores = torch.matmul(scaled_query, log_key_shaped.transpose(2, 3))
            attn = torch.sigmoid(scores)
            drop_attn = self.dropout_log(attn)
            contexts.append(unshape(torch.matmul(drop_attn, log_value_shaped)))

        # Concatenate whichever contexts were produced and project.
        output = self.final_linear(torch.cat(contexts, dim=-1))

        return output, top_score