bool ImageInputOp::GetImageAndLabelAndInfoFromDBValue()

in caffe2/image/image_input_op.h [450:790]


template <class Context>
bool ImageInputOp<Context>::GetImageAndLabelAndInfoFromDBValue(
    const string& value,
    cv::Mat* img,
    PerImageArg& info,
    int item_id,
    std::mt19937* randgen) {
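  // Parses one serialized DB record (value), decodes the image into *img,
  // fills the label (and any additional outputs) for slot item_id in the
  // prefetch buffers, and records any per-image bounding-box parameters in
  // info.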
  //
  // Recommend running with --caffe2_use_fatal_for_enforce=1 when using
  // ImageInputOp: this function runs on a worker thread, and exceptions
  // thrown by CAFFE_ENFORCE are silently dropped by the thread worker
  // functions.
  //
  cv::Mat src;

  // Use the default information for images
  info = default_arg_;
  if (use_caffe_datum_) {
    // The input is in Caffe Datum format.
    CaffeDatum datum;
    CAFFE_ENFORCE(datum.ParseFromString(value));

    prefetched_label_.mutable_data<int>()[item_id] = datum.label();
    if (datum.encoded()) {
      // Encoded image in the datum. Decode it; if decoding fails, count the
      // error and fall back to a blank 224x224 image.
      try {
        src = cv::imdecode(
            cv::Mat(
                1,
                datum.data().size(),
                CV_8UC1,
                const_cast<char*>(datum.data().data())),
            color_ ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE);
        if (src.rows == 0 || src.cols == 0) {
          num_decode_errors_in_batch_++;
          src = cv::Mat::zeros(cv::Size(224, 224), CV_8UC3);
        }
      } catch (cv::Exception& e) {
        num_decode_errors_in_batch_++;
        src = cv::Mat::zeros(cv::Size(224, 224), CV_8UC3);
      }
    } else {
      // Raw image in datum.
      CAFFE_ENFORCE(datum.channels() == 3 || datum.channels() == 1);

      int src_c = datum.channels();
      src.create(
          datum.height(), datum.width(), (src_c == 3) ? CV_8UC3 : CV_8UC1);

      if (src_c == 1) {
        memcpy(src.ptr<uchar>(0), datum.data().data(), datum.data().size());
      } else {
        // Datum stores data in CHW order; convert to HWC here to stay
        // consistent with conventional image storage.
        for (const auto c : c10::irange(3)) {
          const char* datum_buffer =
              datum.data().data() + datum.height() * datum.width() * c;
          uchar* ptr = src.ptr<uchar>(0) + c;
          for (const auto h : c10::irange(datum.height())) {
            for (const auto w : c10::irange(datum.width())) {
              *ptr = *(datum_buffer++);
              ptr += 3;
            }
          }
        }
      }
    }
  } else {
    // The input is in caffe2 TensorProtos format.
    TensorProtos protos;
    CAFFE_ENFORCE(protos.ParseFromString(value));
    const TensorProto& image_proto = protos.protos(0);
    const TensorProto& label_proto = protos.protos(1);
    // Collect the additional output protos.
    vector<TensorProto> additional_output_protos;
    int start = additional_inputs_offset_;
    int end = start + additional_inputs_count_;
    for (const auto i : c10::irange(start, end)) {
      additional_output_protos.push_back(protos.protos(i));
    }

    if (protos.protos_size() == end + 1) {
      // We have bounding box information
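      // It is stored as four INT32 values: [ymin, xmin, height, width].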
      const TensorProto& bounding_proto = protos.protos(end);
      DCHECK_EQ(bounding_proto.data_type(), TensorProto::INT32);
      DCHECK_EQ(bounding_proto.int32_data_size(), 4);
      info.bounding_params.valid = true;
      info.bounding_params.ymin = bounding_proto.int32_data(0);
      info.bounding_params.xmin = bounding_proto.int32_data(1);
      info.bounding_params.height = bounding_proto.int32_data(2);
      info.bounding_params.width = bounding_proto.int32_data(3);
    }

    if (image_proto.data_type() == TensorProto::STRING) {
      // encoded image string.
      DCHECK_EQ(image_proto.string_data_size(), 1);
      const string& encoded_image_str = image_proto.string_data(0);
      int encoded_size = encoded_image_str.size();
      // Wrap the encoded string in a cv::Mat so we do not need a copy, then
      // decode it; if decoding fails, count the error and fall back to a
      // blank 224x224 image.
      try {
        src = cv::imdecode(
            cv::Mat(
                1,
                &encoded_size,
                CV_8UC1,
                const_cast<char*>(encoded_image_str.data())),
            color_ ? cv::IMREAD_COLOR : cv::IMREAD_GRAYSCALE);
        if (src.rows == 0 || src.cols == 0) {
          num_decode_errors_in_batch_++;
          src = cv::Mat::zeros(cv::Size(224, 224), CV_8UC3);
        }
      } catch (cv::Exception& e) {
        num_decode_errors_in_batch_++;
        src = cv::Mat::zeros(cv::Size(224, 224), CV_8UC3);
      }
    } else if (image_proto.data_type() == TensorProto::BYTE) {
      // raw image content.
      int src_c = (image_proto.dims_size() == 3) ? image_proto.dims(2) : 1;
      CAFFE_ENFORCE(src_c == 3 || src_c == 1);

      src.create(
          image_proto.dims(0),
          image_proto.dims(1),
          (src_c == 3) ? CV_8UC3 : CV_8UC1);
      memcpy(
          src.ptr<uchar>(0),
          image_proto.byte_data().data(),
          image_proto.byte_data().size());
    } else {
      LOG(FATAL) << "Unknown image data type.";
    }

    // TODO: if image decoding was unsuccessful, set label to 0
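    // Label layout depends on label_type_:
    //  - SINGLE_LABEL / SINGLE_LABEL_WEIGHTED: a single scalar value.
    //  - MULTI_LABEL_SPARSE: the data lists active label indices, expanded
    //    here into a dense 0/1 vector of length num_labels_.
    //  - MULTI_LABEL_WEIGHTED_SPARSE: indices in protos(1) with matching
    //    weights in protos(2).
    //  - MULTI_LABEL_DENSE / EMBEDDING_LABEL: num_labels_ values copied
    //    verbatim.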
    if (label_proto.data_type() == TensorProto::FLOAT) {
      if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) {
        DCHECK_EQ(label_proto.float_data_size(), 1);
        prefetched_label_.mutable_data<float>()[item_id] =
            label_proto.float_data(0);
      } else if (label_type_ == MULTI_LABEL_SPARSE) {
        float* label_data =
            prefetched_label_.mutable_data<float>() + item_id * num_labels_;
        memset(label_data, 0, sizeof(float) * num_labels_);
        for (const auto i : c10::irange(label_proto.float_data_size())) {
          label_data[(int)label_proto.float_data(i)] = 1.0;
        }
      } else if (label_type_ == MULTI_LABEL_WEIGHTED_SPARSE) {
        const TensorProto& weight_proto = protos.protos(2);
        float* label_data =
            prefetched_label_.mutable_data<float>() + item_id * num_labels_;
        memset(label_data, 0, sizeof(float) * num_labels_);
        for (const auto i : c10::irange(label_proto.float_data_size())) {
          label_data[(int)label_proto.float_data(i)] =
              weight_proto.float_data(i);
        }
      } else if (
          label_type_ == MULTI_LABEL_DENSE || label_type_ == EMBEDDING_LABEL) {
        CAFFE_ENFORCE(label_proto.float_data_size() == num_labels_);
        float* label_data =
            prefetched_label_.mutable_data<float>() + item_id * num_labels_;
        for (const auto i : c10::irange(label_proto.float_data_size())) {
          label_data[i] = label_proto.float_data(i);
        }
      } else {
        LOG(ERROR) << "Unknown label type:" << label_type_;
      }
    } else if (label_proto.data_type() == TensorProto::INT32) {
      if (label_type_ == SINGLE_LABEL || label_type_ == SINGLE_LABEL_WEIGHTED) {
        DCHECK_EQ(label_proto.int32_data_size(), 1);
        prefetched_label_.mutable_data<int>()[item_id] =
            label_proto.int32_data(0);
      } else if (label_type_ == MULTI_LABEL_SPARSE) {
        int* label_data =
            prefetched_label_.mutable_data<int>() + item_id * num_labels_;
        memset(label_data, 0, sizeof(int) * num_labels_);
        for (const auto i : c10::irange(label_proto.int32_data_size())) {
          label_data[label_proto.int32_data(i)] = 1;
        }
      } else if (label_type_ == MULTI_LABEL_WEIGHTED_SPARSE) {
        const TensorProto& weight_proto = protos.protos(2);
        float* label_data =
            prefetched_label_.mutable_data<float>() + item_id * num_labels_;
        memset(label_data, 0, sizeof(float) * num_labels_);
        for (const auto i : c10::irange(label_proto.int32_data_size())) {
          label_data[label_proto.int32_data(i)] = weight_proto.float_data(i);
        }
      } else if (
          label_type_ == MULTI_LABEL_DENSE || label_type_ == EMBEDDING_LABEL) {
        CAFFE_ENFORCE(label_proto.int32_data_size() == num_labels_);
        int* label_data =
            prefetched_label_.mutable_data<int>() + item_id * num_labels_;
        for (const auto i : c10::irange(label_proto.int32_data_size())) {
          label_data[i] = label_proto.int32_data(i);
        }
      } else {
        LOG(ERROR) << "Unknown label type:" << label_type_;
      }
    } else {
      LOG(FATAL) << "Unsupported label data type.";
    }

    // Copy any additional outputs into their prefetch buffers; each buffer is
    // strided by the per-item element count of the corresponding proto.
    for (const auto i : c10::irange(additional_output_protos.size())) {
      auto additional_output_proto = additional_output_protos[i];
      if (additional_output_proto.data_type() == TensorProto::FLOAT) {
        float* additional_output =
            prefetched_additional_outputs_[i].template mutable_data<float>() +
            item_id * additional_output_proto.float_data_size();

        for (const auto j : c10::irange(additional_output_proto.float_data_size())) {
          additional_output[j] = additional_output_proto.float_data(j);
        }
      } else if (additional_output_proto.data_type() == TensorProto::INT32) {
        int* additional_output =
            prefetched_additional_outputs_[i].template mutable_data<int>() +
            item_id * additional_output_proto.int32_data_size();

        for (const auto j : c10::irange(additional_output_proto.int32_data_size())) {
          additional_output[j] = additional_output_proto.int32_data(j);
        }
      } else if (additional_output_proto.data_type() == TensorProto::INT64) {
        int64_t* additional_output =
            prefetched_additional_outputs_[i].template mutable_data<int64_t>() +
            item_id * additional_output_proto.int64_data_size();

        for (const auto j : c10::irange(additional_output_proto.int64_data_size())) {
          additional_output[j] = additional_output_proto.int64_data(j);
        }
      } else if (additional_output_proto.data_type() == TensorProto::UINT8) {
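        // UINT8 payloads are transported in int32_data and narrowed here.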
        uint8_t* additional_output =
            prefetched_additional_outputs_[i].template mutable_data<uint8_t>() +
            item_id * additional_output_proto.int32_data_size();

        for (const auto j : c10::irange(additional_output_proto.int32_data_size())) {
          additional_output[j] =
              static_cast<uint8_t>(additional_output_proto.int32_data(j));
        }
      } else {
        LOG(FATAL) << "Unsupported output type.";
      }
    }
  }

  //
  // convert source to the color format requested from Op
  //
  int out_c = color_ ? 3 : 1;
  if (out_c == src.channels()) {
    *img = src;
  } else {
    cv::cvtColor(
        src, *img, (out_c == 1) ? cv::COLOR_BGR2GRAY : cv::COLOR_GRAY2BGR);
  }

  // Note(Yangqing): I believe that the mat should be created continuous.
  CAFFE_ENFORCE(img->isContinuous());

  // Sanity check now that we decoded everything

  // Ensure that the bounding box is legit
  if (info.bounding_params.valid &&
      (src.rows < info.bounding_params.ymin + info.bounding_params.height ||
       src.cols < info.bounding_params.xmin + info.bounding_params.width)) {
    info.bounding_params.valid = false;
  }

  // Apply the bounding box if requested
  if (info.bounding_params.valid) {
    // If we reach here, we know the parameters are sane
    cv::Rect bounding_box(
        info.bounding_params.xmin,
        info.bounding_params.ymin,
        info.bounding_params.width,
        info.bounding_params.height);
    *img = (*img)(bounding_box);

    /*
    LOG(INFO) << "Did bounding with ymin:"
              << info.bounding_params.ymin << " xmin:" <<
    info.bounding_params.xmin
              << " height:" << info.bounding_params.height
              << " width:" << info.bounding_params.width << "\n";
    LOG(INFO) << "Bounded matrix: " << img;
    */
  } else {
    // LOG(INFO) << "No bounding\n";
  }

  cv::Mat scaled_img;
  bool inception_scale_jitter = false;
  if (scale_jitter_type_ == INCEPTION_STYLE) {
    if (!is_test_) {
      // Inception-style scale jittering is only used during training.
      inception_scale_jitter =
          RandomSizedCropping<Context>(img, crop_, randgen);
      // if a random crop is still not found, do simple random cropping later
    }
  }

  if ((scale_jitter_type_ == NO_SCALE_JITTER) ||
      (scale_jitter_type_ == INCEPTION_STYLE && !inception_scale_jitter)) {
    int scaled_width, scaled_height;
    int scale_to_use = scale_ > 0 ? scale_ : minsize_;

    // Randomize scale_to_use within [random_scale_[0], random_scale_[1]].
    if (random_scaling_) {
      scale_to_use = std::uniform_int_distribution<>(
          random_scale_[0], random_scale_[1])(*randgen);
    }

    // warp_ resizes both sides to scale_to_use, ignoring the aspect ratio;
    // otherwise the shorter side becomes scale_to_use and the longer side is
    // scaled proportionally.
    if (warp_) {
      scaled_width = scale_to_use;
      scaled_height = scale_to_use;
    } else if (img->rows > img->cols) {
      scaled_width = scale_to_use;
      scaled_height = static_cast<float>(img->rows) * scale_to_use / img->cols;
    } else {
      scaled_height = scale_to_use;
      scaled_width = static_cast<float>(img->cols) * scale_to_use / img->rows;
    }
    if ((scale_ > 0 &&
         (scaled_height != img->rows || scaled_width != img->cols)) ||
        (scaled_height > img->rows || scaled_width > img->cols)) {
      // We rescale in all cases if we are using scale_
      // but only to make the image bigger if using minsize_
      /*
      LOG(INFO) << "Scaling to " << scaled_width << " x " << scaled_height
                << " From " << img->cols << " x " << img->rows;
      */
      cv::resize(
          *img,
          scaled_img,
          cv::Size(scaled_width, scaled_height),
          0,
          0,
          cv::INTER_AREA);
      *img = scaled_img;
    }
  }

  // TODO(Yangqing): return false if any error happens.
  return true;
}
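
For context, the caffe2-format branch above expects the DB value to be a serialized TensorProtos message laid out as: protos(0) holds the image (either one encoded STRING element, or raw BYTE data with dims [height, width, channels]), protos(1) holds the label, any additional outputs follow starting at additional_inputs_offset_, and an optional trailing INT32 proto of [ymin, xmin, height, width] supplies a bounding box. The sketch below assembles a minimal two-proto record with an encoded image and a single integer label; the function name, the jpeg_bytes argument, and the proto header path are illustrative assumptions, not part of the operator.

#include <string>

#include "caffe2/proto/caffe2_pb.h"  // proto header path may differ per build

// Minimal sketch: build a DB value that the TensorProtos branch of
// GetImageAndLabelAndInfoFromDBValue can parse (SINGLE_LABEL case).
std::string MakeImageLabelRecord(const std::string& jpeg_bytes, int label) {
  caffe2::TensorProtos protos;

  // protos(0): the encoded image stored as a single STRING element.
  caffe2::TensorProto* image = protos.add_protos();
  image->set_data_type(caffe2::TensorProto::STRING);
  image->add_string_data(jpeg_bytes);

  // protos(1): a single integer label.
  caffe2::TensorProto* label_proto = protos.add_protos();
  label_proto->set_data_type(caffe2::TensorProto::INT32);
  label_proto->add_int32_data(label);

  // Additional outputs would be appended here (starting at
  // additional_inputs_offset_), and one extra trailing INT32 proto with
  // [ymin, xmin, height, width] would be read as a bounding box.

  std::string value;
  protos.SerializeToString(&value);
  return value;
}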