fbgemm_gpu/src/merge_pooled_embeddings_cpu.cpp (33 lines of code) (raw):

/* * Copyright (c) Meta Platforms, Inc. and affiliates. * All rights reserved. * This source code is licensed under the BSD-style license found in the * LICENSE file in the root directory of this source tree. */ #include <ATen/ATen.h> #include <ATen/core/op_registration/op_registration.h> #include <c10/core/TensorOptions.h> #include <torch/library.h> #include "fbgemm_gpu/sparse_ops_utils.h" using Tensor = at::Tensor; namespace fbgemm_gpu { Tensor merge_pooled_embeddings_cpu( std::vector<Tensor> pooled_embeddings, int64_t /*uncat_dim_size*/, at::Device target_device, int64_t cat_dim = 1) { auto cat_host_0 = [&](const std::vector<Tensor>& ts) { int64_t n = 0; for (auto& t : ts) { n += t.numel(); } Tensor r; if (n == 0) { r = at::empty({n}); } else { r = at::empty({n}, ts[0].options()); } r.resize_(0); return at::cat_out(r, ts, cat_dim); // concat the tensor list in dim = 1 }; return cat_host_0(pooled_embeddings); } } // namespace fbgemm_gpu TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { DISPATCH_TO_CPU( "merge_pooled_embeddings", fbgemm_gpu::merge_pooled_embeddings_cpu); }