bool DoNetlinkCollection()

in auomscollect.cpp [147:304]


/**
 * Register this process as the kernel audit collector and copy every received
 * audit record into raw_queue until told to stop.
 *
 * Each record is written to the queue as a RecordType header immediately
 * followed by the raw netlink payload.
 *
 * The function stops when any of the following happens:
 *   - Signals::IsExit() becomes true or the exit handler opens the stop gate,
 *   - a file named "auditd" is created in / moved into /sbin,
 *   - raw_queue.Allocate() returns nullptr,
 *   - the periodic (every 10 seconds) audit pid check fails or shows that a
 *     different pid owns the audit stream.
 *
 * @param raw_queue             Destination for raw records.
 * @param bytes_metric          Updated with the payload size of each stored record.
 * @param records_metric        Incremented once per stored record.
 * @param lost_bytes_metric     Updated with the loss byte count reported by Allocate().
 * @param lost_segments_metric  Incremented each time Allocate() reports a loss.
 * @return true only when the kernel audit pid was found reset to 0 (logged as
 *         "restarting..."; presumably the caller calls this function again);
 *         false for every other exit path, including setup failures.
 */
bool DoNetlinkCollection(SPSCDataQueue& raw_queue, std::shared_ptr<Metric>& bytes_metric, std::shared_ptr<Metric>& records_metric, std::shared_ptr<Metric>& lost_bytes_metric, std::shared_ptr<Metric>& lost_segments_metric) {
    // Request that this process receive a SIGTERM if the parent process (thread in parent) dies/exits.
    auto ret = prctl(PR_SET_PDEATHSIG, SIGTERM);
    if (ret != 0) {
        Logger::Warn("prctl(PR_SET_PDEATHSIG, SIGTERM) failed: %s", std::strerror(errno));
    }

    // data_netlink receives the audit records (and is the connection the audit
    // pid is registered on); netlink is used for the status/pid/enable requests.
    Netlink data_netlink;
    Netlink netlink;
    Gate _stop_gate;

    // Stop collecting as soon as an "auditd" entry shows up in /sbin.
    FileWatcher::notify_fn_t fn = [&_stop_gate](const std::string& /*dir*/, const std::string& name, uint32_t mask) {
        if (name == "auditd" && (mask & (IN_CREATE|IN_MOVED_TO)) != 0) {
            Logger::Info("/sbin/auditd found on the system, exiting.");
            _stop_gate.Open();
        }
    };

    FileWatcher watcher(std::move(fn), {
            {"/sbin", IN_CREATE|IN_MOVED_TO},
    });

    // Per-message callback for data_netlink. It always returns false; the meaning
    // of the return value is defined by Netlink::Open (not visible here).
    std::function handler = [&](uint16_t type, uint16_t flags, const void* data, size_t len) -> bool {
        // Ignore AUDIT_REPLACE for now since replying to it doesn't actually do anything.
        if (type >= AUDIT_FIRST_USER_MSG && type != static_cast<uint16_t>(RecordType::REPLACE)) {
            size_t loss_bytes = 0;
            auto ptr = raw_queue.Allocate(len+sizeof(RecordType), &loss_bytes);
            if (ptr == nullptr) {
                // No buffer can be obtained any more, so shut the collection down.
                _stop_gate.Open();
                return false;
            }
            if (loss_bytes > 0) {
                lost_bytes_metric->Update(loss_bytes);
                lost_segments_metric->Update(1);
            }
            // Header is copied with memcpy instead of a cast-and-store so that no
            // assumption is made about the alignment of the queue's buffer.
            const RecordType record_type = static_cast<RecordType>(type);
            std::memcpy(ptr, &record_type, sizeof(RecordType));
            std::memcpy(ptr+sizeof(RecordType), data, len);
            raw_queue.Commit(len+sizeof(RecordType));
            bytes_metric->Update(len);
            records_metric->Update(1.0);
        }
        return false;
    };

    Logger::Info("Connecting to AUDIT NETLINK socket");
    ret = data_netlink.Open(std::move(handler));
    if (ret != 0) {
        Logger::Error("Failed to open AUDIT NETLINK connection: %s", std::strerror(-ret));
        return false;
    }
    Defer _close_data_netlink([&data_netlink]() { data_netlink.Close(); });

    ret = netlink.Open(nullptr);
    if (ret != 0) {
        Logger::Error("Failed to open AUDIT NETLINK connection: %s", std::strerror(-ret));
        return false;
    }
    Defer _close_netlink([&netlink]() { netlink.Close(); });

    watcher.Start();
    Defer _stop_watcher([&watcher]() { watcher.Stop(); });

    uint32_t our_pid = getpid();

    Logger::Info("Checking assigned audit pid");
    audit_status status;
    ret = NetlinkRetry([&netlink,&status]() { return netlink.AuditGet(status); } );
    if (ret != 0) {
        Logger::Error("Failed to get audit status: %s", std::strerror(-ret));
        return false;
    }
    uint32_t pid = status.pid;
    uint32_t enabled = status.enabled;

    // A registered pid only blocks us if that process still exists.
    if (pid != 0 && PathExists("/proc/" + std::to_string(pid))) {
        Logger::Error("There is another process (pid = %d) already assigned as the audit collector", pid);
        return false;
    }

    Logger::Info("Enabling AUDIT event collection");
    constexpr int max_set_pid_retries = 5;
    int retry_count = 0;
    do {
        if (retry_count > max_set_pid_retries) {
            Logger::Error("Failed to set audit pid: Max retried exceeded");
            // Give up: without this return the loop would spin (and log) forever
            // when every attempt times out and the pid never becomes ours.
            return false;
        }
        ret = data_netlink.AuditSetPid(our_pid);
        if (ret == -ETIMEDOUT) {
            // If setpid timedout, it may have still succeeded, so re-fetch pid
            ret = NetlinkRetry([&]() { return netlink.AuditGetPid(pid); });
            if (ret != 0) {
                Logger::Error("Failed to get audit pid: %s", std::strerror(-ret));
                return false;
            }
        } else if (ret != 0) {
            Logger::Error("Failed to set audit pid: %s", std::strerror(-ret));
            return false;
        } else {
            break;
        }
        retry_count += 1;
    } while (pid != our_pid);

    if (enabled == 0) {
        ret = NetlinkRetry([&netlink]() { return netlink.AuditSetEnabled(1); });
        if (ret != 0) {
            Logger::Error("Failed to enable auditing: %s", std::strerror(-ret));
            return false;
        }
    }

    // NOTE(review): the name suggests this should restore the original (disabled)
    // state on exit, yet it sets enabled to 1 again. Behaviour is kept as-is;
    // confirm whether AuditSetEnabled(0) was intended.
    Defer _revert_enabled([&netlink,enabled]() {
        if (enabled == 0) {
            int revert_ret = NetlinkRetry([&netlink]() { return netlink.AuditSetEnabled(1); });
            if (revert_ret != 0) {
                Logger::Error("Failed to enable auditing: %s", std::strerror(-revert_ret));
            }
        }
    });

    // NOTE(review): this handler references the local _stop_gate and is not reset
    // before returning — confirm Signals cannot invoke it after this function exits.
    Signals::SetExitHandler([&_stop_gate]() { _stop_gate.Open(); });

    auto _last_pid_check = std::chrono::steady_clock::now();
    while(!Signals::IsExit()) {
        // Doubles as the loop's sleep: waits up to 1000 for the gate to open.
        if (_stop_gate.Wait(Gate::OPEN, 1000)) {
            return false;
        }

        auto now = std::chrono::steady_clock::now();
        if (_last_pid_check >= now - std::chrono::seconds(10)) {
            continue;
        }

        // Every 10 seconds verify that we are still the registered audit pid.
        _last_pid_check = now;
        pid = 0;
        ret = NetlinkRetry([&netlink,&pid]() { return netlink.AuditGetPid(pid); });
        if (ret != 0) {
            if (ret == -ECANCELED || ret == -ENOTCONN) {
                // A closed connection is only worth reporting when we are not exiting.
                if (!Signals::IsExit()) {
                    Logger::Error("AUDIT NETLINK connection has closed unexpectedly");
                }
            } else {
                Logger::Error("Failed to get audit pid: %s", std::strerror(-ret));
            }
            return false;
        }

        if (pid != our_pid) {
            if (pid != 0) {
                Logger::Warn("Another process (pid = %d) has taken over AUDIT NETLINK event collection.", pid);
                return false;
            }
            Logger::Warn("Audit pid was unexpectedly set to 0, restarting...");
            return true;
        }
    }
    return false;
}