at::Tensor bt_min_mha()

in nestedtensor/csrc/cuda/mha.cpp [22:101]


at::Tensor bt_min_mha(
    int64_t num_heads,
    int64_t head_dim,
    double dropout_p,
    bool training,
    at::Tensor query,
    at::Tensor key,
    at::Tensor value,
    at::Tensor attr_kernel,
    at::Tensor attr_bias,
    double scaling,
    at::Tensor out_proj_weight,
    at::Tensor out_proj_bias) {
  // TODO: Assert that max seq_len is 1024!
  TORCH_CHECK(get_dim(query) == 3, "query needs to be 3 dim.");
  TORCH_CHECK(get_dim(key) == 3, "key needs to be 3 dim.");
  TORCH_CHECK(get_dim(value) == 3, "value needs to be 3 dim.");
  TORCH_CHECK(get_nested_dim(query) == 1, "Query nested dim isn't 1.");
  TORCH_CHECK(get_nested_dim(key) == 1, "Key nested dim isn't 1.");
  TORCH_CHECK(get_nested_dim(value) == 1, "Value nested dim isn't 1.");
  // TORCH_CHECK(in_proj_bias, "Input projection bias needs to be defined.");
  // auto opt_sizes = get_opt_sizes(query);
  // if (!opt_sizes[2]) {
  //   throw std::runtime_error("query's third dimension must be regular.");
  // }
  // TODO: Add explicit check that verifies query, key and value are the same
  // auto start = std::chrono::system_clock::now();
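  // Derive a dense int32 [batch_size, seq_len] padding mask from the nested
  // sizes of query; all subsequent dense buffers use these padded dimensions.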
  auto options =
      torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA);
  at::Tensor input_mask = to_mask(query, 2);
  input_mask = input_mask.to(options);
  int64_t batch_size = input_mask.size(0);
  int64_t seq_len = input_mask.size(1);
  int64_t embedding_dim = head_dim * num_heads; //*(opt_sizes[2]);
  int64_t head_num = num_heads;
  int64_t size_per_head = embedding_dim / head_num;
  auto float_options =
      torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
  at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
  at::cuda::setCurrentCUDAStream(defaultStream);

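  // Fused in-projection: attr_kernel holds the q, k and v weights stacked
  // along dim 0, so a single matmul yields all three projections packed
  // along the last dimension of `packed`.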
  at::Tensor packed = at::matmul(query, attr_kernel.t()) + attr_bias;

  // TODO: Move into implementation of chunk for NestedTensor
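  // Split the packed buffer into contiguous q/k/v buffers and re-wrap each
  // one as a NestedTensor with query's nested sizes and strides.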
  at::Tensor packed_buf =
      get_buffer(packed).contiguous().reshape({-1, 3 * embedding_dim});
  std::vector<at::Tensor> packed_chunks = packed_buf.chunk(3, -1);
  at::Tensor q_buf_ = packed_chunks[0].contiguous().reshape({-1});
  at::Tensor k_buf_ = packed_chunks[1].contiguous().reshape({-1});
  at::Tensor v_buf_ = packed_chunks[2].contiguous().reshape({-1});
  at::Tensor q = wrap_buffer(
      std::move(q_buf_),
      get_efficient_nested_size(query),
      get_efficient_nested_stride(query));
  at::Tensor k = wrap_buffer(
      std::move(k_buf_),
      get_efficient_nested_size(query),
      get_efficient_nested_stride(query));
  at::Tensor v = wrap_buffer(
      std::move(v_buf_),
      get_efficient_nested_size(query),
      get_efficient_nested_stride(query));

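  // Pad the nested q/k/v to dense [batch_size, seq_len, embedding_dim]
  // tensors, then expose the heads:
  // [batch_size, head_num, seq_len, size_per_head].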
  at::Tensor query_buf = to_padded_tensor(q, 0).contiguous();
  at::Tensor key_buf = to_padded_tensor(k, 0).contiguous();
  at::Tensor val_buf = to_padded_tensor(v, 0).contiguous();
  query_buf =
      query_buf.reshape({batch_size, seq_len, head_num, size_per_head})
          .transpose(1, 2);
  key_buf =
      key_buf.reshape({batch_size, seq_len, head_num, size_per_head})
          .transpose(1, 2);
  val_buf =
      val_buf.reshape({batch_size, seq_len, head_num, size_per_head})
          .transpose(1, 2);

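  // Raw attention scores q @ k^T of shape
  // [batch_size, head_num, seq_len, seq_len]; `scaling` is applied later,
  // inside the softmax kernel.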
  key_buf = key_buf.transpose(2, 3);
  at::Tensor attn_output_weights = at::matmul(query_buf, key_buf).contiguous();

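  // Expand the 1-D padding mask to 2-D via an outer product so that both
  // the rows and the columns of padded positions are masked out.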
  at::Tensor attr_mask = input_mask.view({-1, 1, 1, seq_len}).to(float_options);
  attr_mask = attr_mask * attr_mask.transpose(2, 3);

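  // Fused CUDA kernel: applies `scaling` and the mask, then computes the
  // softmax in place on attn_output_weights.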
  nteffectivetransformer::cuda::softmax_kernel_kernelLauncher<float>(
      attn_output_weights.data_ptr<float>(),
      attr_mask.data_ptr<float>(),
      batch_size,
      head_num,
      seq_len,
      (float)(scaling),
      defaultStream);

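  // Weighted sum over values: [B, H, S, S] @ [B, H, S, D] -> [B, H, S, D],
  // then restore the [batch_size, seq_len, embedding_dim] layout.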
  auto attn_output = at::matmul(attn_output_weights, val_buf).contiguous();
  attn_output = attn_output.transpose(1, 2)
                    .reshape({batch_size, seq_len, embedding_dim})
                    .contiguous();
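  // Strip the padding again and apply the output projection. Note that
  // out_proj_bias, dropout_p and training are accepted but never used here.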
  at::Tensor attr_out =
      from_padded_tensor(attn_output, get_efficient_nested_size(query));
  return at::matmul(attr_out, out_proj_weight.t());
}
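
For intuition, here is a dense-tensor reference sketch of the same computation in plain ATen. It assumes inputs already padded to [batch, seq_len, embed], keeps the packed [3*embed, embed] attr_kernel layout used above, and stands in for softmax_kernel_kernelLauncher with an additive -10000 mask and an explicit at::softmax. That masking scheme, the key-side-only mask, and the helper name dense_mha_reference are assumptions for illustration, not the repo's implementation.

#include <ATen/ATen.h>

// Hypothetical dense reference for bt_min_mha on already-padded inputs.
// query: [B, S, E] (self-attention, as above: only query is projected);
// mask:  [B, S] float, 1.0 for real tokens, 0.0 for padding.
at::Tensor dense_mha_reference(
    int64_t num_heads,
    at::Tensor query,
    at::Tensor mask,
    at::Tensor attr_kernel,       // [3 * E, E], packed q/k/v weights
    at::Tensor attr_bias,         // [3 * E]
    double scaling,
    at::Tensor out_proj_weight) { // [E, E]
  int64_t B = query.size(0), S = query.size(1), E = query.size(2);
  int64_t D = E / num_heads;
  // One matmul yields q|k|v packed along the last dimension.
  at::Tensor packed = at::matmul(query, attr_kernel.t()) + attr_bias;
  auto chunks = packed.chunk(3, -1);
  auto split_heads = [&](const at::Tensor& t) {
    return t.reshape({B, S, num_heads, D}).transpose(1, 2); // [B, H, S, D]
  };
  at::Tensor q = split_heads(chunks[0]);
  at::Tensor k = split_heads(chunks[1]);
  at::Tensor v = split_heads(chunks[2]);
  // Scaled scores; an additive mask stands in for the fused kernel
  // (assumed semantics: padded keys pushed toward -inf before the softmax).
  at::Tensor scores = at::matmul(q, k.transpose(2, 3)) * scaling;
  at::Tensor attn_mask = mask.view({B, 1, 1, S});
  scores = scores + (attn_mask - 1.0) * 10000.0;
  at::Tensor weights = at::softmax(scores, -1);
  at::Tensor out = at::matmul(weights, v) // [B, H, S, D]
                       .transpose(1, 2)
                       .reshape({B, S, E});
  return at::matmul(out, out_proj_weight.t()); // bias omitted, as above
}

Under these assumptions the output should agree with bt_min_mha at the non-padded positions; padded query rows differ, since the kernel above also masks along the query dimension via the 2-D outer-product mask.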