in torch_xla/csrc/init_python_bindings.cpp [441:502]
// Reads the next record from `reader`, parses it as a tensorflow::Example and
// converts its features into a Python dict keyed by feature name.
//
// Returns py::none() when RecordRead() reports that no record was read.
// Otherwise each feature is mapped as follows:
//   - bytes_list holding exactly one value -> rank-1 at::kChar tensor
//   - bytes_list holding any other count   -> list of rank-1 at::kChar tensors
//   - float_list                           -> rank-1 at::kFloat tensor
//   - int64_list                           -> rank-1 at::kLong tensor
// A record that fails to parse, or a feature of any other kind, is reported
// through XLA_ERROR() together with the reader path.
py::object RecordReadExample(
    const std::shared_ptr<xla::util::RecordReader>& reader) {
  auto make_r1_size = [](int64_t size) -> std::vector<int64_t> {
    return std::vector<int64_t>({size});
  };
  // Copies the raw bytes of `bytes` into a freshly allocated rank-1 int8
  // tensor. The copy is skipped for empty input: calling memcpy with a null
  // pointer is undefined behaviour even when the byte count is zero.
  auto bytes_to_tensor = [&make_r1_size](const std::string& bytes) {
    at::Tensor data = at::empty(make_r1_size(bytes.size()),
                                at::TensorOptions(at::kChar));
    if (!bytes.empty()) {
      std::memcpy(data.data_ptr<int8_t>(), bytes.data(), bytes.size());
    }
    return data;
  };

  xla::util::RecordReader::Data value;
  if (!RecordRead(reader, &value)) {
    return py::none();
  }
  tensorflow::Example exmsg;
  if (!exmsg.ParseFromArray(value.data(), value.size())) {
    XLA_ERROR() << "Unable to parse TF example from " << reader->path();
  }

  auto example = py::dict();
  for (const auto& name_feat : exmsg.features().feature()) {
    const tensorflow::Feature& feature = name_feat.second;
    switch (feature.kind_case()) {
      case tensorflow::Feature::kBytesList: {
        const tensorflow::BytesList& bvalue = feature.bytes_list();
        if (bvalue.value_size() == 1) {
          // A single bytes value is exposed directly as a tensor rather than
          // as a one-element list.
          example[py::str(name_feat.first)] =
              torch::autograd::make_variable(bytes_to_tensor(bvalue.value(0)));
        } else {
          auto tlist = py::list(bvalue.value_size());
          for (int i = 0; i < bvalue.value_size(); ++i) {
            tlist[i] = torch::autograd::make_variable(
                bytes_to_tensor(bvalue.value(i)));
          }
          example[py::str(name_feat.first)] = tlist;
        }
      } break;
      case tensorflow::Feature::kFloatList: {
        const tensorflow::FloatList& fvalue = feature.float_list();
        at::Tensor data = at::empty(make_r1_size(fvalue.value_size()),
                                    at::TensorOptions(at::kFloat));
        // Guard the copy so an empty list never hands memcpy a null pointer.
        if (fvalue.value_size() > 0) {
          std::memcpy(data.data_ptr<float>(), fvalue.value().data(),
                      fvalue.value_size() * sizeof(float));
        }
        example[py::str(name_feat.first)] =
            torch::autograd::make_variable(data);
      } break;
      case tensorflow::Feature::kInt64List: {
        const tensorflow::Int64List& ivalue = feature.int64_list();
        at::Tensor data = at::empty(make_r1_size(ivalue.value_size()),
                                    at::TensorOptions(at::kLong));
        // Guard the copy so an empty list never hands memcpy a null pointer.
        if (ivalue.value_size() > 0) {
          std::memcpy(data.data_ptr<int64_t>(), ivalue.value().data(),
                      ivalue.value_size() * sizeof(int64_t));
        }
        example[py::str(name_feat.first)] =
            torch::autograd::make_variable(data);
      } break;
      default:
        XLA_ERROR() << "Unknown data type from " << reader->path();
    }
  }
  return example;
}