Expr RewriteAttention()

in src/contrib/msc/framework/tensorrt/transform_tensorrt.cc [188:328]


// Rewrite a relax.nn.attention call into primitive relax ops (permute_dims, reshape, matmul,
// multiply/divide, add, softmax or a tril-masked softmax) for the TensorRT backend.
//
// The decomposition mirrors the Relax legalization of attention:
//   q, k, v (B, S, N, H) --permute(0,2,1,3) + reshape--> (B*N, S, H)
//   p = (q @ k^T) * scale        (or / sqrt(H) when the scale attribute is not set)
//   p = p + bias                 (optional 4th argument, broadcast to (B, N, S, S_kv))
//   s = softmax(p) over the last axis, or the causal-masked variant
//   o = s @ v --reshape--> (B, N, S, Hv) --permute(0,2,1,3)--> (B, S, N, Hv)
//
// builder:   block builder through which every intermediate call is emitted.
// var:       the var bound to the attention call (not used by this rewrite).
// src_call:  the original attention call; its AttentionAttrs (scale, causal_mask) are read.
// new_calls: calls rewritten so far; when src_call has a replacement, its args are used.
// config:    rewrite config string (not used by this rewrite).
// Returns the final permute_dims call (constructed, not emitted) carrying the original
// call's sinfo_args and span.
Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call,
                      const Map<Expr, Call>& new_calls, const String& config) {
  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
  const auto* src_attrs = src_call->attrs.as<AttentionAttrs>();

  // define dims: q is read as (batch, seq_len, num_head, head_dim),
  // v as (batch, seq_len_kv, num_head, head_dim_v)
  const auto& in_q_shape = ExprUtils::GetShape(call->args[0]);
  const auto& in_v_shape = ExprUtils::GetShape(call->args[2]);
  const auto& batch_size = in_q_shape[0];
  const auto& seq_len = in_q_shape[1];
  const auto& num_head = in_q_shape[2];
  const auto& head_dim = in_q_shape[3];
  const auto& seq_len_kv = in_v_shape[1];
  const auto& head_dim_v = in_v_shape[3];

  // create ops
  static const Op& permute_dims_op = Op::Get("relax.permute_dims");
  static const Op& reshape_op = Op::Get("relax.reshape");
  static const Op& matmul_op = Op::Get("relax.matmul");
  static const Op& multiply_op = Op::Get("relax.multiply");
  static const Op& add_op = Op::Get("relax.add");
  static const Op& divide_op = Op::Get("relax.divide");
  static const Op& sqrt_op = Op::Get("relax.sqrt");
  static const Op& softmax_op = Op::Get("relax.nn.softmax");
  static const Op& tril_op = Op::Get("relax.tril");
  static const Op& max_op = Op::Get("relax.max");
  static const Op& sum_op = Op::Get("relax.sum");
  static const Op& subtract_op = Op::Get("relax.subtract");
  static const Op& exp_op = Op::Get("relax.exp");

  // prepare q,k,v: swap the seq and head axes, then fold batch and head together
  auto permute_attrs = make_object<PermuteDimsAttrs>();
  Array<Integer> axes{Integer(0), Integer(2), Integer(1), Integer(3)};
  permute_attrs->axes = axes;
  const auto& q_trans =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_trans"), permute_dims_op,
                             {call->args[0]}, Attrs(permute_attrs));
  const auto& k_trans =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_trans"), permute_dims_op,
                             {call->args[1]}, Attrs(permute_attrs));
  const auto& v_trans =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_trans"), permute_dims_op,
                             {call->args[2]}, Attrs(permute_attrs));
  Array<PrimExpr> q_shape({batch_size * num_head, seq_len, head_dim});
  const auto& q_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_reshape"),
                                                 reshape_op, {q_trans, ShapeExpr(q_shape)});
  Array<PrimExpr> k_shape({batch_size * num_head, seq_len_kv, head_dim});
  const auto& k_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape"),
                                                 reshape_op, {k_trans, ShapeExpr(k_shape)});
  Array<PrimExpr> v_shape({batch_size * num_head, seq_len_kv, head_dim_v});
  const auto& v_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_reshape"),
                                                 reshape_op, {v_trans, ShapeExpr(v_shape)});
  auto reduce_permute_attrs = make_object<PermuteDimsAttrs>();
  Array<Integer> v_axes{Integer(0), Integer(2), Integer(1)};
  reduce_permute_attrs->axes = v_axes;
  // transpose k to (B*N, head_dim, seq_len_kv) for batch_matmul
  const auto& k_reshape_trans =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape_trans"),
                             permute_dims_op, {k_reshape}, Attrs(reduce_permute_attrs));

  // calculate product: (B*N, seq_len, seq_len_kv)
  auto matmul_attrs = make_object<MatmulAttrs>();
  matmul_attrs->out_dtype = in_dtype;
  const auto& qk_prod =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "qk_prod"), matmul_op,
                             {q_reshape, k_reshape_trans}, Attrs(matmul_attrs));
  Expr p_scale;
  if (src_attrs->scale.defined()) {
    double value = static_cast<double>(src_attrs->scale.value()->value);
    const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"),
                                                   value, in_dtype, 3);
    p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), multiply_op,
                                     {qk_prod, scale});
  } else {
    // default scaling divides by sqrt(head_dim); head_dim must be a static integer here
    double value = static_cast<double>(Downcast<Integer>(head_dim)->value);
    const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"),
                                                   value, in_dtype, 3);
    const auto& sqrt_scale = RewriteUtils::MakeCall(
        builder, ExprUtils::GetSpanName(call, "sqrt_scale"), sqrt_op, {scale});
    p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), divide_op,
                                     {qk_prod, sqrt_scale});
  }

  // bias: the scores are viewed as (B, N, S, S_kv) for the add. A 2-D (B, S_kv) or 3-D
  // (B, S, S_kv) bias is first reshaped so it broadcasts over the head (and seq) axes rather
  // than being right-aligned against the wrong dimensions; 4-D bias is added as-is.
  Expr prod = p_scale;
  if (call->args.size() == 4) {
    Expr bias = call->args[3];
    const auto& bias_shape = ExprUtils::GetShape(bias);
    const PrimExpr one = IntImm(batch_size.dtype(), 1);
    if (bias_shape.size() == 2) {
      Array<PrimExpr> bias_exp_shape{batch_size, one, one, seq_len_kv};
      bias = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "bias_exp"), reshape_op,
                                    {bias, ShapeExpr(bias_exp_shape)});
    } else if (bias_shape.size() == 3) {
      Array<PrimExpr> bias_exp_shape{batch_size, one, seq_len, seq_len_kv};
      bias = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "bias_exp"), reshape_op,
                                    {bias, ShapeExpr(bias_exp_shape)});
    }
    Array<PrimExpr> exp_shape{batch_size, num_head, seq_len, seq_len_kv};
    Array<PrimExpr> reduce_shape{batch_size * num_head, seq_len, seq_len_kv};
    const auto& prod_exp = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_exp"),
                                                  reshape_op, {prod, ShapeExpr(exp_shape)});
    const auto& prod_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_add"),
                                                  add_op, {prod_exp, bias});
    prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_reduce"), reshape_op,
                                  {prod_add, ShapeExpr(reduce_shape)});
  }

  // causal_mask
  Expr s_value;
  if (!src_attrs->causal_mask.defined()) {
    auto softmax_attrs = make_object<SoftmaxAttrs>();
    softmax_attrs->axis = 2;
    s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), softmax_op,
                                     {prod}, Attrs(softmax_attrs));
  } else {
    const auto& causal_mask = src_attrs->causal_mask.value();
    PrimValue tril_k;
    if (causal_mask == "TopLeft") {
      tril_k = PrimValue(Integer(0));
    } else if (causal_mask == "BottomRight") {
      // Relax defines the bottom-right mask as tril(*, k=abs(seq_len - seq_len_kv)). Without
      // abs the offset goes negative when seq_len < seq_len_kv, leaving fully-masked rows whose
      // softmax denominator is zero.
      tril_k = PrimValue(tvm::abs(seq_len - seq_len_kv));
    } else {
      LOG_FATAL << "Unexpected causal_mask " << causal_mask;
    }
    // masked softmax: exp(masked - rowmax) re-masked, then normalized by its row sum
    const auto& p_masked = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked"),
                                                  tril_op, {prod, tril_k});
    auto reduce_attrs = make_object<StatisticalAttrs>();
    Array<Integer> axis{Integer(2)};
    reduce_attrs->axis = axis;
    reduce_attrs->keepdims = true;
    // Take the stabilizing max over the masked scores (as the Relax legalization does), so that
    // large masked-out scores cannot push every kept exponent to zero.
    const auto& p_max = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_max"),
                                               max_op, {p_masked}, Attrs(reduce_attrs));
    const auto& p_diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_diff"),
                                                subtract_op, {p_masked, p_max});
    const auto& p_exp =
        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_exp"), exp_op, {p_diff});
    const auto& p_masked_exp = RewriteUtils::MakeCall(
        builder, ExprUtils::GetSpanName(call, "p_masked_exp"), tril_op, {p_exp, tril_k});
    const auto& p_masked_sum =
        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked_sum"), sum_op,
                               {p_masked_exp}, Attrs(reduce_attrs));
    s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), divide_op,
                                     {p_masked_exp, p_masked_sum});
  }

  // final calculation: (B*N, S, Hv) -> (B, N, S, Hv) -> (B, S, N, Hv)
  const auto& o_prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "o_prod"),
                                              matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs));
  Array<PrimExpr> o_shape{batch_size, num_head, seq_len, head_dim_v};
  const auto& o_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "o_reshape"),
                                                 reshape_op, {o_prod, ShapeExpr(o_shape)});
  // attention outputs use the same (batch, seq, head, dim) layout as the inputs, so undo the
  // initial head/seq swap; axes (0, 2, 1, 3) is its own inverse.
  return Call(permute_dims_op, {o_reshape}, Attrs(permute_attrs), call->sinfo_args, call->span);
}