in tensorflow/libs/octree_property_op.cc [35:125]
// Extracts one property of the octree stored in input(0) at depth `depth_`
// and writes it to output 0 as a [channel, length] tensor, where
// `length = info().node_num(depth_)` and `channel` is the property's channel
// count recorded in the octree info.
//
// Member attributes read (declared outside this method):
//   property_name_ - "key", "xyz", "index", "child", "neigh", "feature",
//                    "label" or "split".
//   depth_         - octree depth whose property is extracted.
//   channel_       - channel count the caller expects; must equal `channel`.
//   dtype_         - output dtype; must match the property's dtype:
//                    key/xyz -> DT_UINT32 or DT_UINT64 (chosen by `uintk`),
//                    index/child/neigh -> DT_INT32,
//                    feature/label/split -> DT_FLOAT.
//
// "xyz" converts keys with key2xyz_gpu unless the octree info reports
// is_key2xyz(); "index" always converts keys with key2idx_gpu. Both
// conversions write into a temporary buffer of `length` elements.
//
// Unsupported property names, dtype/channel mismatches, a missing property
// buffer and CUDA copy failures are reported through the op status
// (OP_REQUIRES) instead of aborting the process.
//
// NOTE(review): the conversion helpers and cudaMemcpy below are not given the
// op's compute stream, so they presumably run on the CUDA default stream;
// confirm this is correctly ordered with whatever produced input(0).
void Compute(OpKernelContext* context) override {
  auto octree_ptr = context->input(0).flat<int8>().data();
  OctreeParser octree;
  octree.set_gpu(octree_ptr);

  const bool key32 = std::is_same<uintk, uint32>::value;
  const DataType key_dtype =
      key32 ? DataType::DT_UINT32 : DataType::DT_UINT64;
  const int length = octree.info().node_num(depth_);

  // Which conversion (if any) must be applied to the raw keys before copying.
  enum class Conversion { kNone, kKeyToXyz, kKeyToIndex };
  Conversion conversion = Conversion::kNone;
  const void* property_ptr = nullptr;
  int channel = 1;
  DataType expected_dtype = DataType::DT_INVALID;

  // 1) Resolve the source buffer, channel count and dtype of the property.
  //    Nothing is allocated or launched here so that invalid attributes are
  //    rejected before any GPU work happens.
  if (property_name_ == "key") {
    property_ptr = octree.key_gpu(depth_);
    channel = octree.info().channel(OctreeInfo::kKey);
    expected_dtype = key_dtype;
  } else if (property_name_ == "xyz") {
    property_ptr = octree.key_gpu(depth_);
    channel = octree.info().channel(OctreeInfo::kKey);
    expected_dtype = key_dtype;
    if (!octree.info().is_key2xyz()) conversion = Conversion::kKeyToXyz;
  } else if (property_name_ == "index") {
    property_ptr = octree.key_gpu(depth_);
    channel = octree.info().channel(OctreeInfo::kKey);
    expected_dtype = DataType::DT_INT32;
    conversion = Conversion::kKeyToIndex;
  } else if (property_name_ == "child") {
    property_ptr = octree.children_gpu(depth_);
    channel = octree.info().channel(OctreeInfo::kChild);
    expected_dtype = DataType::DT_INT32;
  } else if (property_name_ == "neigh") {
    property_ptr = octree.neighbor_gpu(depth_);
    channel = octree.info().channel(OctreeInfo::kNeigh);
    expected_dtype = DataType::DT_INT32;
  } else if (property_name_ == "feature") {
    property_ptr = octree.feature_gpu(depth_);
    channel = octree.info().channel(OctreeInfo::kFeature);
    expected_dtype = DataType::DT_FLOAT;
  } else if (property_name_ == "label") {
    property_ptr = octree.label_gpu(depth_);
    channel = octree.info().channel(OctreeInfo::kLabel);
    expected_dtype = DataType::DT_FLOAT;
  } else if (property_name_ == "split") {
    property_ptr = octree.split_gpu(depth_);
    channel = octree.info().channel(OctreeInfo::kSplit);
    expected_dtype = DataType::DT_FLOAT;
  } else {
    OP_REQUIRES(context, false,
                errors::InvalidArgument("Unsupported Octree Property: ",
                                        property_name_));
  }

  // 2) Validate the op attributes against what the octree actually stores.
  OP_REQUIRES(context, dtype_ == expected_dtype,
              errors::InvalidArgument(
                  "Property '", property_name_, "' has dtype ",
                  DataTypeString(expected_dtype),
                  ", but the dtype attribute is ", DataTypeString(dtype_)));
  OP_REQUIRES(context, channel_ == channel,
              errors::InvalidArgument(
                  "The specified channel_ is wrong. Property name: ",
                  property_name_, ", expected: ", channel,
                  ", specified: ", channel_));

  // Element count as 64-bit so channel * length cannot overflow `int`.
  const int64_t num = static_cast<int64_t>(channel) * length;
  OP_REQUIRES(context, num == 0 || property_ptr != nullptr,
              errors::InvalidArgument("Octree property '", property_name_,
                                      "' is not available at depth ",
                                      depth_));

  // 3) Optional key conversion into a temporary buffer.
  Tensor buf0;
  if (conversion != Conversion::kNone) {
    // The temporary buffer holds exactly `length` elements, while `num`
    // elements are copied out of it below; that is only in bounds when the
    // key has a single channel.
    OP_REQUIRES(context, channel == 1,
                errors::InvalidArgument(
                    "Property '", property_name_,
                    "' requires a single-channel key, but the key has ",
                    channel, " channels"));
    const uintk* key_ptr = static_cast<const uintk*>(property_ptr);
    if (conversion == Conversion::kKeyToXyz) {
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  key_dtype, TensorShape({length}), &buf0));
      uintk* xyz_ptr = buf0.flat<uintk>().data();
      key2xyz_gpu(xyz_ptr, key_ptr, length, depth_);
      property_ptr = xyz_ptr;
    } else {
      OP_REQUIRES_OK(context, context->allocate_temp(
                                  DT_INT32, TensorShape({length}), &buf0));
      int* idx_ptr = buf0.flat<int>().data();
      key2idx_gpu(idx_ptr, key_ptr, length);
      property_ptr = idx_ptr;
    }
  }

  // 4) Allocate the [channel, length] output and copy device-to-device.
  Tensor* out_tensor = nullptr;
  OP_REQUIRES_OK(context, context->allocate_output(
                              0, TensorShape({channel, length}), &out_tensor));
  if (num == 0) return;  // Nothing to copy.

  void* dst = nullptr;
  size_t elem_size = 0;
  switch (dtype_) {
    case DataType::DT_UINT32:
      dst = out_tensor->flat<uint32>().data();
      elem_size = sizeof(uint32);
      break;
    case DataType::DT_UINT64:
      dst = out_tensor->flat<uint64>().data();
      elem_size = sizeof(uint64);
      break;
    case DataType::DT_INT32:
      dst = out_tensor->flat<int>().data();
      elem_size = sizeof(int);
      break;
    case DataType::DT_FLOAT:
      dst = out_tensor->flat<float>().data();
      elem_size = sizeof(float);
      break;
    default:
      OP_REQUIRES(context, false,
                  errors::Internal("Invalid DataType: ",
                                   DataTypeString(dtype_)));
  }

  // cudaMemcpy also surfaces errors from earlier asynchronous work (e.g. the
  // conversion kernels), so a failure here is reported instead of ignored.
  const cudaError_t err =
      cudaMemcpy(dst, property_ptr, elem_size * static_cast<size_t>(num),
                 cudaMemcpyDeviceToDevice);
  OP_REQUIRES(context, err == cudaSuccess,
              errors::Internal("cudaMemcpy of octree property '",
                               property_name_,
                               "' failed: ", cudaGetErrorString(err)));
}