at::Tensor interpolate()

in nestedtensor/csrc/python_functions.cpp [52:125]


/// Nested-tensor counterpart of torch.nn.functional.interpolate.
///
/// Each leaf tensor of `input` is given a leading batch dimension
/// (unsqueeze(0)), interpolated with F::interpolate, and squeezed back.
///
/// @param input          Nested tensor whose leaves are interpolated.
/// @param size           Output size(s): either a single size used for every
///                       leaf, or exactly one size per leaf, consumed in the
///                       order map_nested_tensor visits the leaves.
/// @param scale_factor   Scale factor(s) applied to every leaf. When given it
///                       takes precedence and `size` is ignored.
/// @param mode           Interpolation mode name. When absent, "nearest" is
///                       used (the torch.nn.functional.interpolate default);
///                       "none" is also treated as "nearest".
/// @param align_corners  Forwarded to F::interpolate when provided.
/// @return Nested tensor of interpolated leaves.
/// @throws std::runtime_error for an unknown mode, a `size` list whose length
///         is neither 1 nor the number of leaves, or when neither `size` nor
///         `scale_factor` is provided.
at::Tensor interpolate(
    at::Tensor input,
    c10::optional<std::vector<std::vector<int64_t>>> size,
    c10::optional<at::ArrayRef<double>> scale_factor,
    c10::optional<std::string> mode,
    c10::optional<bool> align_corners) {
  // A missing mode used to throw bad_optional_access; fall back to the
  // PyTorch default instead.
  const std::string mode_name = mode.value_or("nearest");

  F::InterpolateFuncOptions::mode_t int_mode;
  if (mode_name == "nearest" || mode_name == "none") {
    int_mode = torch::kNearest;
  } else if (mode_name == "trilinear") {
    int_mode = torch::kTrilinear;
  } else if (mode_name == "linear") {
    int_mode = torch::kLinear;
  } else if (mode_name == "bicubic") {
    int_mode = torch::kBicubic;
  } else if (mode_name == "area") {
    int_mode = torch::kArea;
  } else if (mode_name == "bilinear") {
    int_mode = torch::kBilinear;
  } else {
    throw std::runtime_error(
        "Unexpected mode for interpolate: " + mode_name);
  }

  auto options = F::InterpolateFuncOptions().mode(int_mode);
  if (align_corners.has_value()) {
    options.align_corners() = align_corners.value();
  }

  // Either scale factor or size can be passed; scale factor wins if both are.
  if (scale_factor.has_value()) {
    options.scale_factor(scale_factor.value().vec());
    return map_nested_tensor(
        [&options](at::Tensor input_tensor) {
          return F::interpolate(input_tensor.unsqueeze(0), options).squeeze(0);
        },
        input);
  }

  // Fail before paying for a full traversal of the nested tensor.
  if (!size.has_value()) {
    throw std::runtime_error("Either size or scale_factor should be defined.");
  }
  const std::vector<std::vector<int64_t>>& sizes = size.value();

  // Count the leaves so the per-leaf size list can be validated.
  const auto leaves_count = reduce_nested_tensor(
      [](at::Tensor /*leaf*/, int64_t count) { return count + 1; }, 0, input);

  // There can be either 1 size for all tensors or an individual size per
  // tensor. Compare as signed to avoid a size_t / integer mix.
  if (sizes.size() != 1 &&
      static_cast<int64_t>(sizes.size()) != static_cast<int64_t>(leaves_count)) {
    throw std::runtime_error(
        "Interpolate has to take either 1 size tuple or same amount as leaves in Nested Tensor.");
  }

  if (sizes.size() == 1) {
    // The same size applies to every leaf: set it once, outside the mapped
    // function, instead of re-assigning it for every leaf.
    options.size(sizes[0]);
    return map_nested_tensor(
        [&options](at::Tensor input_tensor) {
          return F::interpolate(input_tensor.unsqueeze(0), options).squeeze(0);
        },
        input);
  }

  // One size per leaf, consumed in traversal order. Each leaf gets its own
  // copy of the options so the shared object is never mutated while mapping.
  size_t size_i = 0;
  return map_nested_tensor(
      [&options, &sizes, &size_i](at::Tensor input_tensor) {
        auto leaf_options = options;
        leaf_options.size(sizes[size_i]);
        ++size_i;
        return F::interpolate(input_tensor.unsqueeze(0), leaf_options)
            .squeeze(0);
      },
      input);
}