Status ParseNumpyHeader()

in tensorflow_io/core/kernels/numpy_kernels.cc [87:230]


// Parses the header of a NumPy `.npy` stream (format version 1.0 only).
//
// The stream is consumed from its current position: 6 magic bytes, 2 version
// bytes, a 2-byte little-endian header length, and then the header text, which
// is a Python-dict literal such as
//   {'descr': '<i8', 'fortran_order': False, 'shape': (4,), }
// padded with spaces and terminated by '\n'.
//
// Args:
//   stream: source of the bytes; read sequentially via ReadNBytes.
//   dtype:  receives the TensorFlow type matching `descr`. It is set to
//           DT_INVALID (while the function still returns OK) when the header
//           is well formed but `descr` has no mapping below or
//           `fortran_order` is True.
//   shape:  cleared and filled with the dimensions listed under 'shape'
//           (left empty for a `()` shape).
//
// Returns:
//   OK on success, any error produced by the stream reads, or InvalidArgument
//   when the magic, version, alignment or header text is malformed.
Status ParseNumpyHeader(io::InputStreamInterface* stream,
                        ::tensorflow::DataType* dtype,
                        std::vector<int64>* shape) {
  // Parses a non-empty run of ASCII decimal digits into a strictly positive
  // value. Unlike std::stoul it never throws and it rejects signs, whitespace
  // and empty input. Inputs are capped at 18 digits, which always fits in a
  // signed 64-bit integer, so no overflow check is needed.
  auto parse_positive = [](const std::string& text, int64* out) -> bool {
    if (text.empty() || text.size() > 18) {
      return false;
    }
    int64 parsed = 0;
    for (const char c : text) {
      if (c < '0' || c > '9') {
        return false;
      }
      parsed = parsed * 10 + (c - '0');
    }
    if (parsed == 0) {
      return false;
    }
    *out = parsed;
    return true;
  };

  string descr;
  bool fortran_order = false;

  tstring magic;
  TF_RETURN_IF_ERROR(stream->ReadNBytes(6, &magic));
  if (magic != "\x93NUMPY") {
    return errors::InvalidArgument("numpy file header magic number invalid");
  }
  tstring version;
  TF_RETURN_IF_ERROR(stream->ReadNBytes(2, &version));
  // TODO (yongtang): Support 2.0 which use 4 bytes for length.
  // Both bytes must match: major == 1 AND minor == 0. Anything else would be
  // parsed with the wrong header-length width below.
  if (!(version[0] == 1 && version[1] == 0)) {
    return errors::InvalidArgument(
        "numpy file version only support 1.0: ", static_cast<int>(version[0]),
        ".", static_cast<int>(version[1]));
  }
  tstring chunk;
  TF_RETURN_IF_ERROR(stream->ReadNBytes(2, &chunk));
  // Little-endian 16-bit length. Go through unsigned char so that bytes
  // >= 0x80 are not sign-extended on platforms where char is signed.
  const int64 length =
      static_cast<int64>(static_cast<unsigned char>(chunk[0])) |
      (static_cast<int64>(static_cast<unsigned char>(chunk[1])) << 8);
  if ((magic.size() + version.size() + chunk.size() +
       static_cast<size_t>(length)) %
          16 !=
      0) {
    return errors::InvalidArgument(
        "numpy file header length is not aligned properly: ", length);
  }
  tstring tdict;
  TF_RETURN_IF_ERROR(stream->ReadNBytes(length, &tdict));
  string dict = tdict;
  // {'descr': '<i8', 'fortran_order': False, 'shape': (4,), }\x20...\n
  if (dict.empty() || dict.back() != '\n') {
    return errors::InvalidArgument("numpy file header should end with '\\n'");
  }
  dict.pop_back();
  // Strip the space padding; stop if the header consisted of padding only.
  while (!dict.empty() && dict.back() == '\x20') {
    dict.pop_back();
  }
  TrimNumpyHeader(dict);

  // Every structural failure past this point reports the same message.
  auto header_error = [&dict]() {
    return errors::InvalidArgument("numpy file header error: ", dict);
  };

  if (dict.empty() || !(dict.front() == '{' && dict.back() == '}')) {
    return header_error();
  }
  dict = dict.substr(1, dict.size() - 2);
  TrimNumpyHeader(dict);

  // Record where each "'<key>': " marker starts. After sorting, the text
  // between one marker and the next (or the end-of-dict sentinel) is that
  // key's value, whatever order the keys appear in.
  std::vector<std::pair<size_t, std::string>> positions;
  positions.emplace_back(dict.size(), "");
  // find "'descr': ", "'fortran_order': ", "'shape': "
  const std::vector<std::string> keys{"descr", "fortran_order", "shape"};
  for (auto const& key : keys) {
    const size_t p = dict.find("'" + key + "': ");
    if (p == std::string::npos) {
      return header_error();
    }
    positions.emplace_back(p, key);
  }
  std::sort(positions.begin(), positions.end());
  for (size_t i = 0; i + 1 < positions.size(); ++i) {
    const std::string& key = positions[i].second;
    // "'<descr|fortran_order|shape>': " -> two quotes, a colon and a space.
    const size_t value_offset = positions[i].first + key.size() + 4;
    const size_t value_length = positions[i + 1].first - value_offset;
    std::string value = dict.substr(value_offset, value_length);
    TrimNumpyHeader(value);
    // Drop the comma that separates this entry from the next one.
    if (!value.empty() && value.back() == ',') {
      value.pop_back();
    }
    if (key == "descr") {
      // "'([<>|])([bifuc])(\\d+)'"
      if (value.size() < 2 || value.front() != '\'' || value.back() != '\'') {
        return header_error();
      }
      descr = value.substr(1, value.size() - 2);
      // Byte order + type code + at least one size digit.
      if (descr.size() < 3) {
        return header_error();
      }
      const char byte_order = descr[0];
      if (!(byte_order == '<' || byte_order == '>' || byte_order == '|')) {
        return header_error();
      }
      // 'b' must be accepted here, otherwise the "|b1" -> DT_BOOL mapping
      // below can never be reached.
      const char type_code = descr[1];
      if (!(type_code == 'b' || type_code == 'i' || type_code == 'f' ||
            type_code == 'u' || type_code == 'c')) {
        return header_error();
      }
      int64 item_size = 0;
      if (!parse_positive(descr.substr(2), &item_size)) {
        return header_error();
      }
    } else if (key == "fortran_order") {
      if (value != "True" && value != "False") {
        return header_error();
      }
      fortran_order = (value == "True");
    } else if (key == "shape") {
      if (value.size() < 2 || value.front() != '(' || value.back() != ')') {
        return header_error();
      }
      value = value.substr(1, value.size() - 2);
      TrimNumpyHeader(value);
      shape->clear();
      // Consume comma-separated dimensions; a trailing comma as in "(4,)" is
      // allowed, an empty element as in "(3,,4)" is not.
      while (!value.empty()) {
        const size_t comma = value.find(',');
        std::string number = value.substr(0, comma);
        TrimNumpyHeader(number);
        value = (comma == std::string::npos) ? "" : value.substr(comma + 1);
        TrimNumpyHeader(value);
        // Zero-sized dimensions are rejected, as before.
        int64 dim = 0;
        if (!parse_positive(number, &dim)) {
          return header_error();
        }
        shape->push_back(dim);
      }
    }
  }

  // Only C-ordered data of the types listed here is mapped; everything else
  // is reported as DT_INVALID and left to the caller to handle.
  *dtype = ::tensorflow::DataType::DT_INVALID;
  if (!fortran_order) {
    if (descr == "|b1") {
      *dtype = ::tensorflow::DataType::DT_BOOL;
    } else if (descr == "|i1") {
      *dtype = ::tensorflow::DataType::DT_INT8;
    } else if (descr == "<i2") {
      *dtype = ::tensorflow::DataType::DT_INT16;
    } else if (descr == "<i4") {
      *dtype = ::tensorflow::DataType::DT_INT32;
    } else if (descr == "<i8") {
      *dtype = ::tensorflow::DataType::DT_INT64;
    } else if (descr == "|u1") {
      *dtype = ::tensorflow::DataType::DT_UINT8;
    } else if (descr == "<u2") {
      *dtype = ::tensorflow::DataType::DT_UINT16;
    } else if (descr == "<u4") {
      *dtype = ::tensorflow::DataType::DT_UINT32;
    } else if (descr == "<u8") {
      *dtype = ::tensorflow::DataType::DT_UINT64;
    } else if (descr == "<f4") {
      *dtype = ::tensorflow::DataType::DT_FLOAT;
    } else if (descr == "<f8") {
      *dtype = ::tensorflow::DataType::DT_DOUBLE;
    }
  }

  return Status::OK();
}