extensions/http-curl/protocols/RESTSender.cpp (166 lines of code) (raw):
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "RESTSender.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include <utility>
#include <limits>
#include "utils/file/FileUtils.h"
#include "core/Resource.h"
#include "properties/Configuration.h"
#include "io/ZlibStream.h"
using namespace std::literals::chrono_literals;
namespace org::apache::nifi::minifi::c2 {
/**
 * Constructs a RESTSender by forwarding its name and identifier to the C2Protocol base class.
 * @param name name of this protocol instance; moved into the base class
 * @param uuid identifier handed to the base class
 */
RESTSender::RESTSender(std::string name, const utils::Identifier &uuid)
    : C2Protocol(std::move(name), uuid) {}
/**
 * Initializes both base protocols and then reads the C2 REST settings from the configuration.
 *
 * When a configuration is supplied this reads the REST and acknowledge URLs, resolves the SSL
 * context service (the controller service named in the configuration, or a locally created one
 * when secure remote input is enabled) and selects the request body encoding.
 *
 * @param controller provider used to look up the configured SSL context service; may be null
 * @param configure agent configuration; when null only the base classes are initialized
 */
void RESTSender::initialize(core::controller::ControllerServiceProvider* controller, const std::shared_ptr<Configure> &configure) {
  C2Protocol::initialize(controller, configure);
  RESTProtocol::initialize(controller, configure);

  if (nullptr != configure) {
    // NOTE(review): the string literal passed as second argument is presumably an alternate key
    // for the same property - confirm against Configure::get.
    configure->get(Configuration::nifi_c2_rest_url, "c2.rest.url", rest_uri_);
    configure->get(Configuration::nifi_c2_rest_url_ack, "c2.rest.url.ack", ack_uri_);

    std::string ssl_context_service_str;
    if (controller && configure->get(Configuration::nifi_c2_rest_ssl_context_service, "c2.rest.ssl.context.service", ssl_context_service_str)) {
      if (auto service = controller->getControllerService(ssl_context_service_str)) {
        // The service is looked up by a user-configured name, so its dynamic type is not guaranteed.
        // A checked downcast yields nullptr for a service of the wrong type instead of undefined behaviour.
        ssl_context_service_ = std::dynamic_pointer_cast<minifi::controllers::SSLContextService>(service);
        if (nullptr == ssl_context_service_) {
          logger_->log_error("Controller service '%s' is not an SSL context service", ssl_context_service_str);
        }
      }
    }

    // No usable controller service: create our own when secure remote input is enabled.
    if (nullptr == ssl_context_service_) {
      std::string ssl_context_str;
      if (configure->get(Configure::nifi_remote_input_secure, ssl_context_str) && utils::StringUtils::toBool(ssl_context_str).value_or(false)) {
        ssl_context_service_ = std::make_shared<minifi::controllers::SSLContextService>("RESTSenderSSL", configure);
        ssl_context_service_->onEnable();
      }
    }

    if (auto req_encoding_str = configure->get(Configuration::nifi_c2_rest_request_encoding)) {
      if (auto req_encoding = magic_enum::enum_cast<RequestEncoding>(*req_encoding_str, magic_enum::case_insensitive)) {
        logger_->log_debug("Using request encoding '%s'", std::string{magic_enum::enum_name(*req_encoding)});
        req_encoding_ = *req_encoding;
      } else {
        // Unknown value: report it and send uncompressed requests.
        logger_->log_error("Invalid request encoding '%s'", req_encoding_str.value());
        req_encoding_ = RequestEncoding::none;
      }
    } else {
      logger_->log_debug("Request encoding is not specified, using default '%s'", std::string{magic_enum::enum_name(RequestEncoding::none)});
      req_encoding_ = RequestEncoding::none;
    }
  }

  logger_->log_debug("Submitting to %s", rest_uri_);
}
/**
 * Sends the payload to an explicit URL.
 *
 * Payloads that are transmitted and are not transfer operations are serialized to JSON first;
 * in every other case the request is sent without a serialized body.
 *
 * @param url target URL
 * @param payload payload to send
 * @param direction whether the payload is transmitted or received
 * @return the payload produced by sendPayload
 */
C2Payload RESTSender::consumePayload(const std::string &url, const C2Payload &payload, Direction direction, bool /*async*/) {
  const bool serialize_as_json = direction == Direction::TRANSMIT && payload.getOperation() != Operation::transfer;

  std::optional<std::string> serialized_body;
  if (serialize_as_json) {
    serialized_body = serializeJsonRootPayload(payload);
  }

  return sendPayload(url, direction, payload, std::move(serialized_body));
}
/**
 * Sends the payload to the URL matching its operation: acknowledgements go to ack_uri_,
 * everything else goes to rest_uri_.
 *
 * @param payload payload to send
 * @param direction whether the payload is transmitted or received
 * @param async forwarded unchanged to the URL-taking overload
 * @return the payload produced by the URL-taking overload
 */
C2Payload RESTSender::consumePayload(const C2Payload &payload, Direction direction, bool async) {
  const bool is_acknowledge = payload.getOperation() == Operation::acknowledge;
  const std::string& target_url = is_acknowledge ? ack_uri_ : rest_uri_;
  return consumePayload(target_url, payload, direction, async);
}
/**
 * Configuration update hook. This implementation is a no-op: the supplied configuration is not read.
 */
void RESTSender::update(const std::shared_ptr<Configure> & /*configure*/) {
  // Nothing to do.
}
/**
 * Initializes the client with a newly created SSL context service.
 *
 * A fresh SSLContextService named "Service" is built from configuration_, enabled, and passed to
 * the client together with the request type and URL. This function does not inspect the URL
 * scheme itself; deciding whether the URL is secure is left to the caller.
 *
 * @param client HTTP client to initialize
 * @param type request method forwarded to the client
 * @param url request URL forwarded to the client
 */
void RESTSender::setSecurityContext(extensions::curl::HTTPClient &client, const std::string &type, const std::string &url) {
  auto fallback_service = std::make_shared<minifi::controllers::SSLContextService>("Service", configuration_);
  fallback_service->onEnable();
  client.initialize(type, url, fallback_service);
}
/**
 * Performs one HTTP request against the given URL and converts the response into a C2Payload.
 *
 * For Direction::TRANSMIT a POST is issued: transfer operations upload every nested payload as an
 * "application/octet-stream" form part, any other operation uploads `data` (gzip compressed when
 * the configured request encoding is gzip, with a fallback to the uncompressed body if compression
 * fails). Every other direction issues a GET without an upload callback.
 *
 * @param url target URL; an empty URL yields a READ_ERROR payload without any request being made
 * @param direction selects between POST (TRANSMIT) and GET
 * @param payload request payload; error and raw-response payloads carry its operation
 * @param data serialized request body for non-transfer transmissions; an absent value is sent as an empty body
 * @param accepted_formats when set, a successful response body is returned as raw data instead of being
 *        parsed as JSON; for transfer operations a non-empty list is also sent as the Accept header
 * @return the raw or parsed response on success; a READ_ERROR payload when submitting fails or the
 *         response code is in the 4xx or 5xx range
 * @throws std::logic_error when a nested payload of a transmitted transfer operation has an empty label
 */
C2Payload RESTSender::sendPayload(const std::string& url, const Direction direction, const C2Payload &payload, std::optional<std::string> data,
    const std::optional<std::vector<std::string>>& accepted_formats) {
  if (url.empty()) {
    return {payload.getOperation(), state::UpdateState::READ_ERROR};
  }

  // The upload/read callbacks created below are moved into the client, which then owns them.
  extensions::curl::HTTPClient client(url, ssl_context_service_);
  client.setKeepAliveProbe(extensions::curl::KeepAliveProbeData{2s, 2s});
  client.setConnectionTimeout(2s);

  // Sets the request method and, for secure URLs only, (re)initializes the client with an SSL context.
  auto setUpHttpRequest = [&](const std::string& http_method) {
    client.set_request_method(http_method);
    // rfind anchored at position 0 is a pure prefix test: unlike find() it never scans the rest of the URL.
    if (url.rfind("https://", 0) == 0) {
      if (!ssl_context_service_) {
        setSecurityContext(client, http_method, url);
      } else {
        client.initialize(http_method, url, ssl_context_service_);
      }
    }
  };

  if (direction == Direction::TRANSMIT) {
    setUpHttpRequest("POST");
    if (payload.getOperation() == Operation::transfer) {
      // treat nested payloads as files
      for (const auto& file : payload.getNestedPayloads()) {
        const std::string filename = file.getLabel();
        if (filename.empty()) {
          throw std::logic_error("Missing filename");
        }
        auto file_cb = std::make_unique<utils::HTTPUploadByteArrayInputCallback>();
        file_cb->write(file.getRawDataAsString());
        client.addFormPart("application/octet-stream", "file", std::move(file_cb), filename);
      }
    } else {
      auto data_input = std::make_unique<utils::HTTPUploadByteArrayInputCallback>();
      if (data && req_encoding_ == RequestEncoding::gzip) {
        io::BufferStream compressed_payload;
        // The compressor lives only inside this lambda, so it is destroyed before the buffer is moved out.
        bool compression_success = [&] {
          io::ZlibCompressStream compressor(gsl::make_not_null(&compressed_payload), io::ZlibCompressionFormat::GZIP, Z_BEST_COMPRESSION);
          auto ret = compressor.write(as_bytes(std::span(data.value())));
          if (ret != data->length()) {
            return false;
          }
          compressor.close();
          return compressor.isFinished();
        }();
        if (compression_success) {
          data_input->setBuffer(compressed_payload.moveBuffer());
          client.setRequestHeader("Content-Encoding", "gzip");
        } else {
          logger_->log_error("Failed to compress request body, falling back to no compression");
          data_input->write(data.value());
        }
      } else {
        data_input->write(data.value_or(""));
      }
      // The size is read before the callback is moved into the client.
      client.setPostSize(data_input->getBufferSize());
      client.setUploadCallback(std::move(data_input));
    }
  } else {
    // we do not need to set the upload callback
    // since we are not uploading anything on a get
    setUpHttpRequest("GET");
  }

  if (payload.getOperation() == Operation::transfer) {
    auto read = std::make_unique<utils::HTTPReadCallback>(std::numeric_limits<size_t>::max());
    client.setReadCallback(std::move(read));
    if (accepted_formats && !accepted_formats->empty()) {
      client.setRequestHeader("Accept", utils::StringUtils::join(", ", accepted_formats.value()));
    }
  } else {
    // Due to a bug in MiNiFi C2 the Accept header is not handled properly thus we need to exclude it to be compatible
    // TODO(lordgamez): The header should be re-added when the issue in MiNiFi C2 is fixed: https://issues.apache.org/jira/browse/NIFI-10535
    // client.setRequestHeader("Accept", "application/json");
    client.setContentType("application/json");
  }

  bool isOkay = client.submit();
  int64_t respCode = client.getResponseCode();
  const bool clientError = 400 <= respCode && respCode < 500;
  const bool serverError = 500 <= respCode && respCode < 600;
  if (clientError || serverError) {
    logger_->log_error("Error response code '" "%" PRId64 "' from '%s'", respCode, url);
  } else {
    logger_->log_debug("Response code '" "%" PRId64 "' from '%s'", respCode, url);
  }

  const auto response_body_bytes = gsl::make_span(client.getResponseBody()).as_span<const std::byte>();
  // The body is escaped inside a lambda that is handed to the logger rather than evaluated here.
  logger_->log_trace("Received response: \"%s\"", [&] {return utils::StringUtils::escapeUnprintableBytes(response_body_bytes);});

  if (isOkay && !clientError && !serverError) {
    if (accepted_formats) {
      // Callers that specify accepted formats receive the body as raw data, not as parsed JSON.
      C2Payload response_payload(payload.getOperation(), state::UpdateState::READ_COMPLETE, true);
      response_payload.setRawData(response_body_bytes);
      return response_payload;
    }
    return parseJsonResponse(payload, response_body_bytes);
  } else {
    return {payload.getOperation(), state::UpdateState::READ_ERROR};
  }
}
/**
 * Retrieves the content at the given URL through sendPayload, using a transfer-operation
 * payload and Direction::RECEIVE.
 *
 * @param url URL to fetch from
 * @param accepted_formats formats forwarded to sendPayload as the accepted response formats
 * @return the payload produced by sendPayload
 */
C2Payload RESTSender::fetch(const std::string& url, const std::vector<std::string>& accepted_formats, bool /*async*/) {
  const C2Payload transfer_request(Operation::transfer, true);
  return sendPayload(url, Direction::RECEIVE, transfer_request, std::nullopt, accepted_formats);
}
// Registers RESTSender via the REGISTER_RESOURCE macro (presumably provided by core/Resource.h).
// NOTE(review): the exact effect of the DescriptionOnly argument is defined by that macro - confirm there.
REGISTER_RESOURCE(RESTSender, DescriptionOnly);
} // namespace org::apache::nifi::minifi::c2