netlink.go (132 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. //go:build linux // +build linux package libaudit import ( "errors" "fmt" "io" "os" "sync/atomic" "syscall" "unsafe" ) // Generic Netlink Client // NetlinkSender sends a netlink message and returns the sequence number used // in the message and an error if it occurred. type NetlinkSender interface { Send(msg syscall.NetlinkMessage) (uint32, error) SendNoWait(msg syscall.NetlinkMessage) (uint32, error) } // NetlinkReceiver receives data from the netlink socket and uses the provided // parser to convert the raw bytes to NetlinkMessages. For most uses cases // syscall.ParseNetlinkMessage should be used. If nonBlocking is true then // instead of blocking when no data is available, EWOULDBLOCK is returned. type NetlinkReceiver interface { Receive(nonBlocking bool, p NetlinkParser) ([]syscall.NetlinkMessage, error) } // NetlinkSendReceiver combines the Send and Receive into one interface. type NetlinkSendReceiver interface { io.Closer NetlinkSender NetlinkReceiver } // NetlinkParser parses the raw bytes read from the netlink socket into // netlink messages. type NetlinkParser func([]byte) ([]syscall.NetlinkMessage, error) // NetlinkClient is a generic client for sending and receiving netlink messages. type NetlinkClient struct { fd int // File descriptor used for communication. src syscall.Sockaddr // Local socket address. dest syscall.Sockaddr // Remote socket address (client assumes the dest is the kernel). pid uint32 // Port ID of the local socket. seq uint32 // Sequence number used in outgoing messages. readBuf []byte respWriter io.Writer } // NewNetlinkClient creates a new NetlinkClient. It creates a socket and binds // it. readBuf is an optional byte buffer used for reading data from the socket. // The size of the buffer limits the maximum message size the can be read. If no // buffer is provided one will be allocated using the OS page size. resp is // optional and can be used to receive a copy of all bytes read from the socket // (this is useful for debugging). // // The returned NetlinkClient must be closed with Close() when finished. func NewNetlinkClient(proto int, groups uint32, readBuf []byte, resp io.Writer) (*NetlinkClient, error) { s, err := syscall.Socket(syscall.AF_NETLINK, syscall.SOCK_RAW|syscall.SOCK_CLOEXEC, proto) if err != nil { return nil, err } src := &syscall.SockaddrNetlink{Family: syscall.AF_NETLINK, Groups: groups} if err = syscall.Bind(s, src); err != nil { syscall.Close(s) return nil, fmt.Errorf("bind failed: %w", err) } pid, err := getPortID(s) if err != nil { syscall.Close(s) return nil, err } if len(readBuf) == 0 { // Default size used in libnl. readBuf = make([]byte, os.Getpagesize()) } return &NetlinkClient{ fd: s, src: src, dest: &syscall.SockaddrNetlink{}, pid: pid, readBuf: readBuf, respWriter: resp, }, nil } // getPortID gets the kernel assigned port ID (PID) of the local netlink socket. // The kernel assigns the processes PID to the first socket then assigns arbitrary values // to any follow-on sockets. See man netlink for details. func getPortID(fd int) (uint32, error) { address, err := syscall.Getsockname(fd) if err != nil { return 0, err } addr, ok := address.(*syscall.SockaddrNetlink) if !ok { return 0, errors.New("unexpected socket address type") } return addr.Pid, nil } // SendNoWait sends a message to the netlink client in non-blocking mode. Behavior is otherwise identical to Send() func (c *NetlinkClient) SendNoWait(msg syscall.NetlinkMessage) (uint32, error) { return c.send(msg, syscall.MSG_DONTWAIT) } // Send sends a netlink message and returns the sequence number used // in the message and an error if it occurred. If the PID is not set then // the value will be populated automatically (recommended). func (c *NetlinkClient) Send(msg syscall.NetlinkMessage) (uint32, error) { return c.send(msg, 0) } func (c *NetlinkClient) send(msg syscall.NetlinkMessage, flags int) (uint32, error) { if msg.Header.Pid == 0 { msg.Header.Pid = c.pid } msg.Header.Seq = atomic.AddUint32(&c.seq, 1) to := &syscall.SockaddrNetlink{} return msg.Header.Seq, syscall.Sendto(c.fd, serialize(msg), flags, to) } func serialize(msg syscall.NetlinkMessage) []byte { msg.Header.Len = uint32(syscall.SizeofNlMsghdr + len(msg.Data)) b := make([]byte, msg.Header.Len) *(*syscall.NlMsghdr)(unsafe.Pointer(&b[0])) = msg.Header copy(b[syscall.SizeofNlMsghdr:], msg.Data) return b } // Receive receives data from the netlink socket and uses the provided // parser to convert the raw bytes to NetlinkMessages. See NetlinkReceiver docs. func (c *NetlinkClient) Receive(nonBlocking bool, p NetlinkParser) ([]syscall.NetlinkMessage, error) { var flags int if nonBlocking { flags |= syscall.MSG_DONTWAIT } // XXX (akroh): A possible enhancement is to use the MSG_PEEK flag to // check the message size and increase the buffer size to handle it all. nr, from, err := syscall.Recvfrom(c.fd, c.readBuf, flags) if err != nil { // EAGAIN or EWOULDBLOCK will be returned for non-blocking reads where // the read would normally have blocked. return nil, err } if nr < syscall.NLMSG_HDRLEN { return nil, fmt.Errorf("not enough bytes (%v) received to form a netlink header", nr) } fromNetlink, ok := from.(*syscall.SockaddrNetlink) if !ok || fromNetlink.Pid != 0 { // Spoofed packet received on audit netlink socket. return nil, errors.New("message received was not from the kernel") } buf := c.readBuf[:nr] // Dump raw data for inspection purposes. if c.respWriter != nil { if _, err = c.respWriter.Write(buf); err != nil { return nil, err } } msgs, err := p(buf) if err != nil { return nil, fmt.Errorf("failed to parse netlink messages (bytes_received=%v): %w", nr, err) } return msgs, nil } // Close closes the netlink client's raw socket. func (c *NetlinkClient) Close() error { return syscall.Close(c.fd) } // Netlink Error Code Handling // ParseNetlinkError parses the errno from the data section of a // syscall.NetlinkMessage. If netlinkData is less than 4 bytes an error // describing the problem will be returned. func ParseNetlinkError(netlinkData []byte) error { if len(netlinkData) >= 4 { errno := -*(*int32)(unsafe.Pointer(&netlinkData[0])) if errno == 0 { return nil } return syscall.Errno(errno) } return errors.New("received netlink error (data too short to read errno)") }