packetbeat/protos/pgsql/pgsql.go (398 lines of code) (raw):

// Licensed to Elasticsearch B.V. under one or more contributor // license agreements. See the NOTICE file distributed with // this work for additional information regarding copyright // ownership. Elasticsearch B.V. licenses this file to you under // the Apache License, Version 2.0 (the "License"); you may // not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. package pgsql import ( "errors" "strings" "time" "github.com/elastic/beats/v7/libbeat/common" conf "github.com/elastic/elastic-agent-libs/config" "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/mapstr" "github.com/elastic/elastic-agent-libs/monitoring" "github.com/elastic/beats/v7/packetbeat/pb" "github.com/elastic/beats/v7/packetbeat/procs" "github.com/elastic/beats/v7/packetbeat/protos" "github.com/elastic/beats/v7/packetbeat/protos/tcp" "go.uber.org/zap" ) type pgsqlPlugin struct { log, debug, detail *logp.Logger isDebug, isDetail bool // config ports []int maxStoreRows int maxRowLength int sendRequest bool sendResponse bool transactions *common.Cache transactionTimeout time.Duration results protos.Reporter watcher *procs.ProcessesWatcher // function pointer for mocking handlePgsql func(pgsql *pgsqlPlugin, m *pgsqlMessage, tcp *common.TCPTuple, dir uint8, raw_msg []byte) } type pgsqlMessage struct { start int end int isSSLResponse bool isSSLRequest bool toExport bool ts time.Time isRequest bool query string size uint64 fields []string fieldsFormat []byte rows [][]string numberOfRows int numberOfFields int isOK bool isError bool errorInfo string errorCode string errorSeverity string notes []string direction uint8 tcpTuple common.TCPTuple cmdlineTuple *common.ProcessTuple } type pgsqlTransaction struct { tuple common.TCPTuple src common.Endpoint dst common.Endpoint ts time.Time endTime time.Time query string method string bytesOut uint64 bytesIn uint64 notes []string isError bool pgsql mapstr.M requestRaw string responseRaw string } type pgsqlStream struct { data []byte parseOffset int parseState int seenSSLRequest bool expectSSLResponse bool message *pgsqlMessage } const ( pgsqlStartState = iota pgsqlGetDataState pgsqlExtendedQueryState ) const ( sslRequest = iota startupMessage cancelRequest ) var errInvalidLength = errors.New("invalid length") var unmatchedResponses = monitoring.NewInt(nil, "pgsql.unmatched_responses") func init() { protos.Register("pgsql", New) } func New( testMode bool, results protos.Reporter, watcher *procs.ProcessesWatcher, cfg *conf.C, ) (protos.Plugin, error) { p := &pgsqlPlugin{} config := defaultConfig if !testMode { if err := cfg.Unpack(&config); err != nil { return nil, err } } if err := p.init(results, watcher, &config); err != nil { return nil, err } return p, nil } func (pgsql *pgsqlPlugin) init(results protos.Reporter, watcher *procs.ProcessesWatcher, config *pgsqlConfig) error { pgsql.setFromConfig(config) pgsql.log = logp.NewLogger("pgsql") pgsql.debug = logp.NewLogger("pgsql", zap.AddCallerSkip(1)) pgsql.detail = logp.NewLogger("pgsqldetailed", zap.AddCallerSkip(1)) pgsql.isDebug, pgsql.isDetail = logp.IsDebug("pgsql"), logp.IsDebug("pgsqldetailed") pgsql.transactions = common.NewCache( pgsql.transactionTimeout, protos.DefaultTransactionHashSize) pgsql.transactions.StartJanitor(pgsql.transactionTimeout) pgsql.handlePgsql = handlePgsql pgsql.results = results pgsql.watcher = watcher return nil } func (pgsql *pgsqlPlugin) setFromConfig(config *pgsqlConfig) { pgsql.ports = config.Ports pgsql.maxRowLength = config.MaxRowLength pgsql.maxStoreRows = config.MaxRows pgsql.sendRequest = config.SendRequest pgsql.sendResponse = config.SendResponse pgsql.transactionTimeout = config.TransactionTimeout } func (pgsql *pgsqlPlugin) getTransaction(k common.HashableTCPTuple) []*pgsqlTransaction { v := pgsql.transactions.Get(k) if v != nil { return v.([]*pgsqlTransaction) } return nil } //go:inline func (pgsql *pgsqlPlugin) debugf(format string, v ...interface{}) { if pgsql.isDebug { pgsql.debug.Debugf(format, v...) } } //go:inline func (pgsql *pgsqlPlugin) detailf(format string, v ...interface{}) { if pgsql.isDetail { pgsql.detail.Debugf(format, v...) } } func (pgsql *pgsqlPlugin) GetPorts() []int { return pgsql.ports } func (stream *pgsqlStream) prepareForNewMessage() { stream.data = stream.data[stream.message.end:] stream.parseState = pgsqlStartState stream.parseOffset = 0 stream.message = nil } // Extract the method from a SQL query func getQueryMethod(q string) string { index := strings.Index(q, " ") var method string if index > 0 { method = strings.ToUpper(q[:index]) } else { method = strings.ToUpper(q) } return method } type pgsqlPrivateData struct { data [2]*pgsqlStream } func (pgsql *pgsqlPlugin) ConnectionTimeout() time.Duration { return pgsql.transactionTimeout } func (pgsql *pgsqlPlugin) Parse(pkt *protos.Packet, tcptuple *common.TCPTuple, dir uint8, private protos.ProtocolData, ) protos.ProtocolData { priv := pgsqlPrivateData{} if private != nil { var ok bool priv, ok = private.(pgsqlPrivateData) if !ok { priv = pgsqlPrivateData{} } } if priv.data[dir] == nil { priv.data[dir] = &pgsqlStream{ data: pkt.Payload, message: &pgsqlMessage{ts: pkt.Ts}, } pgsql.detailf("New stream created") } else { // concatenate bytes priv.data[dir].data = append(priv.data[dir].data, pkt.Payload...) pgsql.detailf("Len data: %d cap data: %d", len(priv.data[dir].data), cap(priv.data[dir].data)) if len(priv.data[dir].data) > tcp.TCPMaxDataInStream { pgsql.debugf("Stream data too large, dropping TCP stream") priv.data[dir] = nil return priv } } stream := priv.data[dir] if priv.data[1-dir] != nil && priv.data[1-dir].seenSSLRequest { stream.expectSSLResponse = true } for len(stream.data) > 0 { if stream.message == nil { stream.message = &pgsqlMessage{ts: pkt.Ts} } ok, complete := pgsql.pgsqlMessageParser(priv.data[dir]) if !ok { // drop this tcp stream. Will retry parsing with the next // segment in it priv.data[dir] = nil pgsql.debugf("Ignore Postgresql message. Drop tcp stream. Try parsing with the next segment") return priv } if complete { // all ok, ship it msg := stream.data[stream.message.start:stream.message.end] if stream.message.isSSLRequest { // SSL request stream.seenSSLRequest = true } else if stream.message.isSSLResponse { // SSL request answered stream.expectSSLResponse = false priv.data[1-dir].seenSSLRequest = false } else { if stream.message.toExport { pgsql.handlePgsql(pgsql, stream.message, tcptuple, dir, msg) } } // and reset message stream.prepareForNewMessage() } else { // wait for more data break } } return priv } func messageHasEnoughData(msg *pgsqlMessage) bool { if msg == nil { return false } if msg.isSSLRequest || msg.isSSLResponse { return false } if msg.isRequest { return len(msg.query) > 0 } return len(msg.rows) > 0 } // Called when there's a drop packet func (pgsql *pgsqlPlugin) GapInStream(tcptuple *common.TCPTuple, dir uint8, nbytes int, private protos.ProtocolData) (priv protos.ProtocolData, drop bool, ) { if private == nil { return private, false } pgsqlData, ok := private.(pgsqlPrivateData) if !ok { return private, false } if pgsqlData.data[dir] == nil { return pgsqlData, false } // If enough data was received, send it to the // next layer but mark it as incomplete. stream := pgsqlData.data[dir] if messageHasEnoughData(stream.message) { pgsql.debugf("Message not complete, but sending to the next layer") m := stream.message m.toExport = true m.end = stream.parseOffset if m.isRequest { m.notes = append(m.notes, "Packet loss while capturing the request") } else { m.notes = append(m.notes, "Packet loss while capturing the response") } msg := stream.data[stream.message.start:stream.message.end] pgsql.handlePgsql(pgsql, stream.message, tcptuple, dir, msg) // and reset message stream.prepareForNewMessage() } return pgsqlData, true } func (pgsql *pgsqlPlugin) ReceivedFin(tcptuple *common.TCPTuple, dir uint8, private protos.ProtocolData) protos.ProtocolData { return private } var handlePgsql = func(pgsql *pgsqlPlugin, m *pgsqlMessage, tcptuple *common.TCPTuple, dir uint8, raw_msg []byte, ) { m.tcpTuple = *tcptuple m.direction = dir m.cmdlineTuple = pgsql.watcher.FindProcessesTupleTCP(tcptuple.IPPort()) if m.isRequest { pgsql.receivedPgsqlRequest(m) } else { pgsql.receivedPgsqlResponse(m) } } func (pgsql *pgsqlPlugin) receivedPgsqlRequest(msg *pgsqlMessage) { tuple := msg.tcpTuple // parse the query, as it might contain a list of pgsql command // separated by ';' queries := pgsqlQueryParser(msg.query) pgsql.debugf("Queries (%d) :%s", len(queries), queries) transList := pgsql.getTransaction(tuple.Hashable()) if transList == nil { transList = []*pgsqlTransaction{} } for _, query := range queries { trans := &pgsqlTransaction{tuple: tuple} trans.ts = msg.ts trans.src, trans.dst = common.MakeEndpointPair(msg.tcpTuple.BaseTuple, msg.cmdlineTuple) if msg.direction == tcp.TCPDirectionReverse { trans.src, trans.dst = trans.dst, trans.src } trans.pgsql = mapstr.M{} trans.query = query trans.method = getQueryMethod(query) trans.bytesIn = msg.size trans.notes = msg.notes trans.requestRaw = query transList = append(transList, trans) } pgsql.transactions.Put(tuple.Hashable(), transList) } func (pgsql *pgsqlPlugin) receivedPgsqlResponse(msg *pgsqlMessage) { tuple := msg.tcpTuple transList := pgsql.getTransaction(tuple.Hashable()) if len(transList) == 0 { pgsql.debugf("Response from unknown transaction. Ignoring.") unmatchedResponses.Add(1) return } // extract the first transaction from the array trans := pgsql.removeTransaction(transList, tuple, 0) // check if the request was received if trans.pgsql == nil { pgsql.debugf("Response from unknown transaction. Ignoring.") unmatchedResponses.Add(1) return } trans.pgsql.Update(mapstr.M{ "num_rows": msg.numberOfRows, "num_fields": msg.numberOfFields, }) if msg.isError { trans.pgsql.Update(mapstr.M{ "error_code": msg.errorCode, "error_message": msg.errorInfo, "error_severity": msg.errorSeverity, }) } trans.bytesOut = msg.size trans.isError = msg.isError trans.endTime = msg.ts trans.responseRaw = common.DumpInCSVFormat(msg.fields, msg.rows) trans.notes = append(trans.notes, msg.notes...) pgsql.publishTransaction(trans) pgsql.debugf("Postgres transaction completed: %s\n%s", trans.pgsql, trans.responseRaw) } func (pgsql *pgsqlPlugin) publishTransaction(t *pgsqlTransaction) { if pgsql.results == nil { return } evt, pbf := pb.NewBeatEvent(t.ts) pbf.SetSource(&t.src) pbf.SetDestination(&t.dst) pbf.Source.Bytes = int64(t.bytesIn) pbf.Destination.Bytes = int64(t.bytesOut) pbf.Event.Start = t.ts pbf.Event.End = t.endTime pbf.Event.Dataset = "pgsql" pbf.Network.Transport = "tcp" pbf.Network.Protocol = pbf.Event.Dataset pbf.Error.Message = t.notes fields := evt.Fields fields["type"] = pbf.Event.Dataset fields["query"] = t.query fields["method"] = t.method fields["pgsql"] = t.pgsql if t.isError { fields["status"] = common.ERROR_STATUS } else { fields["status"] = common.OK_STATUS } if pgsql.sendRequest { fields["request"] = t.requestRaw } if pgsql.sendResponse { fields["response"] = t.responseRaw } pgsql.results(evt) } func (pgsql *pgsqlPlugin) removeTransaction(transList []*pgsqlTransaction, tuple common.TCPTuple, index int, ) *pgsqlTransaction { trans := transList[index] transList = append(transList[:index], transList[index+1:]...) if len(transList) == 0 { pgsql.transactions.Delete(trans.tuple.Hashable()) } else { pgsql.transactions.Put(tuple.Hashable(), transList) } return trans }