absl::Status DecodeLin16WaveAsFloatVector()

in tensorflow_lite_support/cc/task/audio/utils/wav_io.cc [114:223]


// Decodes an in-memory 16-bit PCM WAV file into a flat vector of floats.
//
// The header is parsed and validated first (RIFF/WAVE tags, a 16- or 18-byte
// format chunk, audio format 1, at least one channel, 16 bits per sample, and
// self-consistent bytes-per-sample / bytes-per-second fields). The remaining
// chunks are then walked: the single data chunk is decoded and every other
// chunk is skipped.
//
// Args:
//   wav_string:    Raw bytes of the WAV file.
//   float_values:  Resized to `sample_count * channel_count` and filled with
//                  the samples in file order, each converted with
//                  Int16SampleToFloat().
//   sample_count:  Set to the data chunk size divided by the header's
//                  bytes-per-sample field (which covers all channels), i.e.
//                  the number of sample frames.
//   channel_count: Set to the channel count read from the header.
//   sample_rate:   Set to the sample rate read from the header.
//
// Returns OkStatus on success. Returns InvalidArgumentError when a header
// field fails validation, when a chunk is larger than INT32_MAX, when more
// than one data chunk is present, or when no data chunk is found; errors from
// the reading helpers are propagated through RETURN_IF_ERROR. Output
// parameters may already have been written when an error is returned.
absl::Status DecodeLin16WaveAsFloatVector(const std::string& wav_string,
                                          std::vector<float>* float_values,
                                          uint32_t* sample_count,
                                          uint16_t* channel_count,
                                          uint32_t* sample_rate) {
  int offset = 0;
  RETURN_IF_ERROR(ExpectText(wav_string, kRiffChunkId, &offset));
  // Read only to advance past the field; the value itself is not validated.
  uint32_t total_file_size;
  RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, &total_file_size, &offset));
  RETURN_IF_ERROR(ExpectText(wav_string, kRiffType, &offset));
  RETURN_IF_ERROR(ExpectText(wav_string, kFormatChunkId, &offset));
  uint32_t format_chunk_size;
  RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, &format_chunk_size, &offset));
  if ((format_chunk_size != 16) && (format_chunk_size != 18)) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Bad format chunk size for WAV: Expected 16 or 18, but got %" PRIu32,
        format_chunk_size));
  }
  uint16_t audio_format;
  RETURN_IF_ERROR(ReadValue<uint16_t>(wav_string, &audio_format, &offset));
  if (audio_format != 1) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Bad audio format for WAV: Expected 1 (PCM), but got %" PRIu16,
        audio_format));
  }
  RETURN_IF_ERROR(ReadValue<uint16_t>(wav_string, channel_count, &offset));
  if (*channel_count < 1) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "Bad number of channels for WAV: Expected at least 1, but got %" PRIu16,
        *channel_count));
  }
  RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, sample_rate, &offset));
  uint32_t bytes_per_second;
  RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, &bytes_per_second, &offset));
  uint16_t bytes_per_sample;
  RETURN_IF_ERROR(ReadValue<uint16_t>(wav_string, &bytes_per_sample, &offset));
  // Confusingly, bits per sample is defined as holding the number of bits for
  // one channel, unlike the definition of sample used elsewhere in the WAV
  // spec. For example, bytes per sample is the memory needed for all channels
  // for one point in time.
  uint16_t bits_per_sample;
  RETURN_IF_ERROR(ReadValue<uint16_t>(wav_string, &bits_per_sample, &offset));
  if (bits_per_sample != 16) {
    return absl::InvalidArgumentError(
        absl::StrFormat("Can only read 16-bit WAV files, but received %" PRIu16,
                        bits_per_sample));
  }
  // With 16-bit samples this is 2 * channel_count. Because the header field
  // must match it exactly, bytes_per_sample is known to be non-zero below.
  const uint32_t expected_bytes_per_sample =
      ((bits_per_sample * *channel_count) + 7) / 8;
  if (bytes_per_sample != expected_bytes_per_sample) {
    return absl::InvalidArgumentError(
        absl::StrFormat("Bad bytes per sample in WAV header: Expected %" PRIu32
                        " but got %" PRIu16,
                        expected_bytes_per_sample, bytes_per_sample));
  }
  const uint32_t expected_bytes_per_second = bytes_per_sample * *sample_rate;
  if (bytes_per_second != expected_bytes_per_second) {
    return absl::InvalidArgumentError(
        absl::StrFormat("Bad bytes per second in WAV header: Expected %" PRIu32
                        " but got %" PRIu32 " (sample_rate=%" PRIu32
                        ", bytes_per_sample=%" PRIu16 ")",
                        expected_bytes_per_second, bytes_per_second,
                        *sample_rate, bytes_per_sample));
  }
  if (format_chunk_size == 18) {
    // Skip over this unused section.
    offset += 2;
  }

  bool was_data_found = false;
  // `offset` is never negative here, so the cast only makes the
  // signed/unsigned comparison explicit.
  while (static_cast<std::string::size_type>(offset) < wav_string.size()) {
    std::string chunk_id;
    RETURN_IF_ERROR(ReadString(wav_string, 4, &chunk_id, &offset));
    uint32_t chunk_size;
    RETURN_IF_ERROR(ReadValue<uint32_t>(wav_string, &chunk_size, &offset));
    if (chunk_size > std::numeric_limits<int32_t>::max()) {
      return absl::InvalidArgumentError(absl::StrFormat(
          "WAV data chunk '%s' is too large: %" PRIu32
          " bytes, but the limit is %d",
          chunk_id.c_str(), chunk_size, std::numeric_limits<int32_t>::max()));
    }
    if (chunk_id == kDataChunkId) {
      if (was_data_found) {
        return absl::InvalidArgumentError(
            "More than one data chunk found in WAV");
      }
      was_data_found = true;
      *sample_count = chunk_size / bytes_per_sample;
      const uint32_t data_count = *sample_count * *channel_count;
      int unused_new_offset = 0;
      // Validate that the data exists before allocating space for it
      // (prevent easy OOM errors).
      RETURN_IF_ERROR(IncrementOffset(offset, sizeof(int16_t) * data_count,
                                      wav_string.size(), &unused_new_offset));
      float_values->resize(data_count);
      for (float& decoded_value : *float_values) {
        int16_t single_channel_value = 0;
        RETURN_IF_ERROR(
            ReadValue<int16_t>(wav_string, &single_channel_value, &offset));
        decoded_value = Int16SampleToFloat(single_channel_value);
      }
      // NOTE(review): when chunk_size is not a multiple of bytes_per_sample,
      // the trailing (chunk_size % bytes_per_sample) bytes are not skipped and
      // will be parsed as the start of the next chunk header. Confirm whether
      // that is intended before changing it.
    } else {
      // Skip chunks we do not decode. chunk_size comes from untrusted input,
      // so make sure the advance neither runs past the end of the buffer nor
      // overflows the signed offset (which would be undefined behaviour).
      const std::string::size_type bytes_remaining =
          wav_string.size() - static_cast<std::string::size_type>(offset);
      const int offset_headroom = std::numeric_limits<int>::max() - offset;
      if (chunk_size >= bytes_remaining ||
          chunk_size > static_cast<uint32_t>(offset_headroom)) {
        // The chunk reaches or passes the end of what can be addressed, so
        // there is nothing left to scan.
        break;
      }
      offset += static_cast<int>(chunk_size);
    }
  }
  if (!was_data_found) {
    return absl::InvalidArgumentError("No data chunk found in WAV");
  }
  return absl::OkStatus();
}