in nestedtensor/csrc/python_functions.cpp [52:125]
at::Tensor interpolate(
at::Tensor input,
c10::optional<std::vector<std::vector<int64_t>>> size,
c10::optional<at::ArrayRef<double>> scale_factor,
c10::optional<std::string> mode,
c10::optional<bool> align_corners) {
F::InterpolateFuncOptions::mode_t int_mode;
if (mode.value() == "nearest" || mode.value() == "none") {
int_mode = torch::kNearest;
} else if (mode.value() == "trilinear") {
int_mode = torch::kTrilinear;
} else if (mode.value() == "linear") {
int_mode = torch::kLinear;
} else if (mode.value() == "bicubic") {
int_mode = torch::kBicubic;
} else if (mode.value() == "area") {
int_mode = torch::kArea;
} else if (mode.value() == "bilinear") {
int_mode = torch::kBilinear;
} else {
throw std::runtime_error(
"Unexpected mode for interpolate: " + mode.value());
}
auto options = F::InterpolateFuncOptions().mode(int_mode);
if (align_corners.has_value()) {
options.align_corners() = align_corners.value();
}
// Either scale factor or size can be passed
if (scale_factor.has_value()) {
options = options.scale_factor(scale_factor.value().vec());
return map_nested_tensor(
[&options](at::Tensor input_tensor) {
return F::interpolate(input_tensor.unsqueeze(0), options).squeeze(0);
},
input);
}
// Get input leaves count
auto leaves_count = reduce_nested_tensor(
[](at::Tensor leaf, int64_t input) { return input + 1; }, 0, input);
if (size.has_value()) {
// There can be either 1 size for all tensor or an individual size value per
// tensor
if (size.value().size() != 1 && size.value().size() != leaves_count) {
throw std::runtime_error(
"Interpolate has to take either 1 size tuple or same amount as leaves in Nested Tensor.");
}
if (size.value().size() == 1) {
return map_nested_tensor(
[&options, &size](at::Tensor input_tensor) {
options = options.size(size.value()[0]);
return F::interpolate(input_tensor.unsqueeze(0), options)
.squeeze(0);
},
input);
} else {
int size_i = 0;
return map_nested_tensor(
[&options, &size_i, &size](at::Tensor input_tensor) {
options = options.size(size.value()[size_i]);
size_i++;
return F::interpolate(input_tensor.unsqueeze(0), options)
.squeeze(0);
},
input);
}
}
throw std::runtime_error("Either size or scale_factor should be defined.");
}