openr/nl/NetlinkRuleMessage.cpp (110 lines of code) (raw):
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* This source code is licensed under the MIT license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <folly/logging/xlog.h>
#include <openr/nl/NetlinkRuleMessage.h>
namespace openr::fbnl {
NetlinkRuleMessage::NetlinkRuleMessage() : NetlinkMessageBase() {}
// Destructor. CHECK-fails unless rulePromise_ has been fulfilled. Within this
// file the promise is only fulfilled by setReturnStatus(), so this catches a
// message being destroyed before its netlink status was delivered.
NetlinkRuleMessage::~NetlinkRuleMessage() {
CHECK(rulePromise_.isFulfilled());
}
// Buffers one rule received in response to this request. The buffered rules
// are published through rulePromise_ by setReturnStatus() when the status is 0.
void
NetlinkRuleMessage::rcvdRule(Rule&& rule) {
  rcvdRules_.push_back(std::move(rule));
}
// Completes rulePromise_ from the request's return status.
//   - non-zero status: the promise holds the status as its unexpected value;
//   - zero status: the promise receives all rules buffered by rcvdRule().
// The base class is notified of the status in both cases.
void
NetlinkRuleMessage::setReturnStatus(int status) {
  if (status != 0) {
    rulePromise_.setValue(folly::makeUnexpected(status));
  } else {
    rulePromise_.setValue(std::move(rcvdRules_));
  }
  NetlinkMessageBase::setReturnStatus(status);
}
// Prepares the message buffer for a rule request of the given type.
//
// @param type  RTM_NEWRULE, RTM_DELRULE or RTM_GETRULE. Any other value is
//              logged and ignored; the buffer and rulehdr_ are left untouched.
//
// Writes the netlink header (length covering only the fib_rule_hdr, flags
// NLM_F_REQUEST | NLM_F_ACK) and points rulehdr_ at the fib_rule_hdr that
// follows the aligned netlink header. RTM_GETRULE additionally sets
// NLM_F_DUMP to fetch all rules; RTM_NEWRULE adds NLM_F_CREATE |
// NLM_F_REPLACE to request create-or-replace semantics.
void
NetlinkRuleMessage::init(int type) {
  if (type != RTM_NEWRULE && type != RTM_DELRULE && type != RTM_GETRULE) {
    XLOG(ERR) << "Incorrect Netlink message type: " << type;
    return;
  }

  // initialize netlink header
  msghdr_->nlmsg_len = NLMSG_LENGTH(sizeof(struct fib_rule_hdr));
  msghdr_->nlmsg_type = type;
  msghdr_->nlmsg_flags = NLM_F_REQUEST | NLM_F_ACK;
  if (type == RTM_GETRULE) {
    // Get all rules
    msghdr_->nlmsg_flags |= NLM_F_DUMP;
  } else if (type == RTM_NEWRULE) {
    // We create new rule or replace existing
    msghdr_->nlmsg_flags |= NLM_F_CREATE | NLM_F_REPLACE;
  }

  // initialize the rule message header, which sits right after the
  // (aligned) netlink header in the same buffer
  const auto nlmsgAlen = NLMSG_ALIGN(sizeof(struct nlmsghdr));
  rulehdr_ = reinterpret_cast<struct fib_rule_hdr*>(
      reinterpret_cast<char*>(msghdr_) + nlmsgAlen);
}
// Decodes one netlink rule message into a Rule.
//
// The fixed fib_rule_hdr supplies family, action and the 8-bit table id. The
// rtattr list that follows may override the table (FRA_TABLE carries the full
// 32-bit id) and set the fwmark (FRA_FWMARK) and priority (FRA_PRIORITY).
// Any other attribute is ignored. A recognised attribute whose payload is
// shorter than a uint32_t is skipped instead of being read past its end.
//
// NOTE(review): the fib_rule_hdr itself is read without checking nlmsg_len;
// callers are assumed to pass complete kernel messages.
Rule
NetlinkRuleMessage::parseMessage(const struct nlmsghdr* nlmsg) {
  const struct fib_rule_hdr* const ruleEntry =
      static_cast<const struct fib_rule_hdr*>(NLMSG_DATA(nlmsg));
  const uint16_t family = ruleEntry->family;
  const uint8_t action = ruleEntry->action;
  // the header's table field is only 8 bits wide; FRA_TABLE may widen it
  const uint32_t table = static_cast<uint32_t>(ruleEntry->table);
  // construct rule
  Rule rule(family, action, table);

  // Attributes start right after the aligned fib_rule_hdr. RTM_RTA and
  // RTM_PAYLOAD are defined in terms of struct rtmsg, so compute the offset
  // and length from fib_rule_hdr explicitly rather than relying on the two
  // structs happening to have the same size.
  const struct rtattr* ruleAttr = reinterpret_cast<const struct rtattr*>(
      reinterpret_cast<const char*>(ruleEntry) +
      NLMSG_ALIGN(sizeof(struct fib_rule_hdr)));
  // Signed on purpose: a message shorter than its header yields a negative
  // length that RTA_OK rejects, instead of an unsigned wrap-around.
  int ruleAttrLen = static_cast<int>(nlmsg->nlmsg_len) -
      static_cast<int>(NLMSG_SPACE(sizeof(struct fib_rule_hdr)));

  // Copies a 32-bit attribute payload into `out`; returns false (leaving
  // `out` untouched) when the payload is too short to hold a uint32_t.
  const auto readU32 = [](const struct rtattr* attr, uint32_t& out) {
    if (RTA_PAYLOAD(attr) < sizeof(uint32_t)) {
      return false;
    }
    out = *static_cast<const uint32_t*>(RTA_DATA(attr));
    return true;
  };

  // process all rule attributes
  for (; RTA_OK(ruleAttr, ruleAttrLen);
       ruleAttr = RTA_NEXT(ruleAttr, ruleAttrLen)) {
    uint32_t value{0};
    switch (ruleAttr->rta_type) {
      case FRA_FWMARK:
        if (readU32(ruleAttr, value)) {
          rule.setFwmark(value);
        }
        break;
      case FRA_TABLE:
        if (readU32(ruleAttr, value)) {
          rule.setTable(value);
        }
        break;
      case FRA_PRIORITY:
        if (readU32(ruleAttr, value)) {
          rule.setPriority(value);
        }
        break;
      default:
        // attributes not modelled by Rule are ignored
        break;
    }
  }
  XLOG(DBG3) << "Netlink parsed rule message. " << rule.str();
  return rule;
}
// Builds an RTM_NEWRULE request describing `rule`.
// Returns the status of addRuleAttributes() (0 on success).
int
NetlinkRuleMessage::addRule(const Rule& rule) {
  init(RTM_NEWRULE);
  const int status = addRuleAttributes(rule);
  return status;
}
// Builds an RTM_DELRULE request describing `rule`.
// Returns the status of addRuleAttributes() (0 on success).
int
NetlinkRuleMessage::deleteRule(const Rule& rule) {
  init(RTM_DELRULE);
  const int status = addRuleAttributes(rule);
  return status;
}
int
NetlinkRuleMessage::addRuleAttributes(const Rule& rule) {
int status{0};
const uint32_t table = rule.getTable();
// set rulehdr fields
rulehdr_->table = table < 256 ? table : RT_TABLE_COMPAT;
rulehdr_->action = rule.getAction();
rulehdr_->family = rule.getFamily();
// add attributes
if ((status = addAttributes(
FRA_TABLE,
reinterpret_cast<const char*>(&table),
sizeof(uint32_t)))) {
return status;
}
if (rule.getFwmark()) {
const uint32_t fwmark = rule.getFwmark().value();
if ((status = addAttributes(
FRA_FWMARK,
reinterpret_cast<const char*>(&fwmark),
sizeof(uint32_t)))) {
return status;
}
}
if (rule.getPriority()) {
const uint32_t priority = rule.getPriority().value();
if ((status = addAttributes(
FRA_PRIORITY,
reinterpret_cast<const char*>(&priority),
sizeof(uint32_t)))) {
return status;
}
}
return status;
}
} // namespace openr::fbnl