in src/contrib/msc/framework/tensorrt/transform_tensorrt.cc [188:328]
/*!
 * \brief Decompose a relax attention call into primitive relax ops.
 *
 * The emitted sequence is:
 *   1. permute q/k/v from (batch, seq, head, dim) to (batch, head, seq, dim) and fold
 *      batch and head into one leading axis,
 *   2. scores = matmul(q, k^T), scaled either by the explicit `scale` attribute or by
 *      1 / sqrt(head_dim),
 *   3. optionally add the bias (4th argument), given as (batch, head, seq, seq_kv),
 *   4. softmax over the last axis, or a tril-masked softmax when `causal_mask` is set,
 *   5. matmul with v, then restore the (batch, seq, head, dim_v) layout.
 *
 * \param builder The block builder used to emit the intermediate bindings.
 * \param var The variable bound to the original call (not used here).
 * \param src_call The original attention call; its attrs are read as AttentionAttrs.
 * \param new_calls Already rewritten calls; if it contains src_call, the rewritten call's
 *        arguments are used instead of the original ones.
 * \param config The rewrite config (not used here).
 * \return The final call of the decomposition, carrying the span and sinfo_args of the
 *         source call.
 */
Expr RewriteAttention(BlockBuilder builder, const Var& var, const Call& src_call,
                      const Map<Expr, Call>& new_calls, const String& config) {
  const auto& call = new_calls.count(src_call) ? new_calls[src_call] : src_call;
  const auto& in_dtype = ExprUtils::GetDataType(call->args[0]);
  const auto* src_attrs = src_call->attrs.as<AttentionAttrs>();
  // `as<>` yields nullptr on a type mismatch; fail loudly instead of dereferencing it below.
  ICHECK(src_attrs) << "Expect AttentionAttrs on the source call of the attention rewrite";

  // define dims, query is (batch, seq_len, num_head, head_dim), value gives the kv dims
  const auto& in_q_shape = ExprUtils::GetShape(call->args[0]);
  const auto& in_v_shape = ExprUtils::GetShape(call->args[2]);
  const auto& batch_size = in_q_shape[0];
  const auto& seq_len = in_q_shape[1];
  const auto& num_head = in_q_shape[2];
  const auto& head_dim = in_q_shape[3];
  const auto& seq_len_kv = in_v_shape[1];
  const auto& head_dim_v = in_v_shape[3];

  // create ops
  static const Op& permute_dims_op = Op::Get("relax.permute_dims");
  static const Op& reshape_op = Op::Get("relax.reshape");
  static const Op& matmul_op = Op::Get("relax.matmul");
  static const Op& multiply_op = Op::Get("relax.multiply");
  static const Op& add_op = Op::Get("relax.add");
  static const Op& divide_op = Op::Get("relax.divide");
  static const Op& sqrt_op = Op::Get("relax.sqrt");
  static const Op& softmax_op = Op::Get("relax.nn.softmax");
  static const Op& tril_op = Op::Get("relax.tril");
  static const Op& max_op = Op::Get("relax.max");
  static const Op& sum_op = Op::Get("relax.sum");
  static const Op& subtract_op = Op::Get("relax.subtract");
  static const Op& exp_op = Op::Get("relax.exp");

  // prepare q,k,v: swap the seq and head axes. The permutation {0, 2, 1, 3} is its own
  // inverse, so the same attrs are reused to restore the output layout at the end.
  auto permute_attrs = make_object<PermuteDimsAttrs>();
  Array<Integer> axes{Integer(0), Integer(2), Integer(1), Integer(3)};
  permute_attrs->axes = axes;
  const auto& q_trans =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_trans"), permute_dims_op,
                             {call->args[0]}, Attrs(permute_attrs));
  const auto& k_trans =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_trans"), permute_dims_op,
                             {call->args[1]}, Attrs(permute_attrs));
  const auto& v_trans =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_trans"), permute_dims_op,
                             {call->args[2]}, Attrs(permute_attrs));

  // fold batch and head into a single leading axis so that plain 3-D matmuls can be used
  Array<PrimExpr> q_shape({batch_size * num_head, seq_len, head_dim});
  const auto& q_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "q_reshape"),
                                                 reshape_op, {q_trans, ShapeExpr(q_shape)});
  Array<PrimExpr> k_shape({batch_size * num_head, seq_len_kv, head_dim});
  const auto& k_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape"),
                                                 reshape_op, {k_trans, ShapeExpr(k_shape)});
  Array<PrimExpr> v_shape({batch_size * num_head, seq_len_kv, head_dim_v});
  const auto& v_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "v_reshape"),
                                                 reshape_op, {v_trans, ShapeExpr(v_shape)});

  // transpose for batch_matmul
  auto reduce_permute_attrs = make_object<PermuteDimsAttrs>();
  Array<Integer> v_axes{Integer(0), Integer(2), Integer(1)};
  reduce_permute_attrs->axes = v_axes;
  const auto& k_reshape_trans =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "k_reshape_trans"),
                             permute_dims_op, {k_reshape}, Attrs(reduce_permute_attrs));

  // calculate product
  auto matmul_attrs = make_object<MatmulAttrs>();
  matmul_attrs->out_dtype = in_dtype;
  const auto& qk_prod =
      RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "qk_prod"), matmul_op,
                             {q_reshape, k_reshape_trans}, Attrs(matmul_attrs));

  // scale: multiply by the explicit scale if given, otherwise divide by sqrt(head_dim)
  Expr p_scale;
  if (src_attrs->scale.defined()) {
    double value = static_cast<double>(src_attrs->scale.value()->value);
    const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"),
                                                   value, in_dtype, 3);
    p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), multiply_op,
                                     {qk_prod, scale});
  } else {
    // the Downcast requires head_dim to be a constant integer
    double value = static_cast<double>(Downcast<Integer>(head_dim)->value);
    const auto& scale = RewriteUtils::MakeConstant(builder, ExprUtils::GetSpanName(call, "scale"),
                                                   value, in_dtype, 3);
    const auto& sqrt_scale = RewriteUtils::MakeCall(
        builder, ExprUtils::GetSpanName(call, "sqrt_scale"), sqrt_op, {scale});
    p_scale = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_scale"), divide_op,
                                     {qk_prod, sqrt_scale});
  }

  // bias: expand the scores to 4-D, add the bias, then fold back to 3-D
  Expr prod = p_scale;
  if (call->args.size() == 4) {
    Array<PrimExpr> exp_shape{batch_size, num_head, seq_len, seq_len_kv};
    Array<PrimExpr> reduce_shape{batch_size * num_head, seq_len, seq_len_kv};
    const auto& prod_exp = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_exp"),
                                                  reshape_op, {prod, ShapeExpr(exp_shape)});
    const auto& prod_add = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_add"),
                                                  add_op, {prod_exp, call->args[3]});
    prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "prod_reduce"), reshape_op,
                                  {prod_add, ShapeExpr(reduce_shape)});
  }

  // causal_mask
  Expr s_value;
  if (!src_attrs->causal_mask.defined()) {
    auto softmax_attrs = make_object<SoftmaxAttrs>();
    softmax_attrs->axis = 2;
    s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), softmax_op,
                                     {prod}, Attrs(softmax_attrs));
  } else {
    const auto& causal_mask = src_attrs->causal_mask.value();
    PrimValue tril_k;
    if (causal_mask == "TopLeft") {
      tril_k = PrimValue(Integer(0));
    } else if (causal_mask == "BottomRight") {
      // The diagonal offset must be non-negative: a negative offset (seq_len < seq_len_kv)
      // would mask every column of the leading rows and make the normalisation 0 / 0.
      tril_k = PrimValue(tvm::abs(seq_len - seq_len_kv));
    } else {
      LOG_FATAL << "Unexpected causal_mask " << causal_mask;
    }
    const auto& p_masked = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked"),
                                                  tril_op, {prod, tril_k});
    auto reduce_attrs = make_object<StatisticalAttrs>();
    Array<Integer> axis{Integer(2)};
    reduce_attrs->axis = axis;
    reduce_attrs->keepdims = true;
    // The row-wise shift is taken from the unmasked scores; the same shift ends up in both
    // the numerator and the denominator of the final division.
    const auto& p_max = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_max"),
                                               max_op, {prod}, Attrs(reduce_attrs));
    const auto& p_diff = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_diff"),
                                                subtract_op, {p_masked, p_max});
    const auto& p_exp =
        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_exp"), exp_op, {p_diff});
    // mask again after exp so that the masked positions contribute exactly zero to the sum
    const auto& p_masked_exp = RewriteUtils::MakeCall(
        builder, ExprUtils::GetSpanName(call, "p_masked_exp"), tril_op, {p_exp, tril_k});
    const auto& p_masked_sum =
        RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "p_masked_sum"), sum_op,
                               {p_masked_exp}, Attrs(reduce_attrs));
    s_value = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "act"), divide_op,
                                     {p_masked_exp, p_masked_sum});
  }

  // final calculation
  const auto& o_prod = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "o_prod"),
                                              matmul_op, {s_value, v_reshape}, Attrs(matmul_attrs));
  // o_prod is (batch * head, seq, dim_v): unfold the leading axis first ...
  Array<PrimExpr> o_shape{batch_size, num_head, seq_len, head_dim_v};
  const auto& o_reshape = RewriteUtils::MakeCall(builder, ExprUtils::GetSpanName(call, "o_reshape"),
                                                 reshape_op, {o_prod, ShapeExpr(o_shape)});
  // ... then swap head and seq back, so the result has the same axis order as the query
  // input, i.e. (batch, seq, head, dim_v), instead of the head-major working layout.
  return Call(permute_dims_op, {o_reshape}, Attrs(permute_attrs), call->sinfo_args, call->span);
}