packetbeat/protos/mongodb/mongodb.go (360 lines of code) (raw):
// Licensed to Elasticsearch B.V. under one or more contributor
// license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright
// ownership. Elasticsearch B.V. licenses this file to you under
// the Apache License, Version 2.0 (the "License"); you may
// not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package mongodb
import (
"fmt"
"strings"
"time"
"github.com/elastic/beats/v7/libbeat/common"
conf "github.com/elastic/elastic-agent-libs/config"
"github.com/elastic/elastic-agent-libs/logp"
"github.com/elastic/elastic-agent-libs/mapstr"
"github.com/elastic/elastic-agent-libs/monitoring"
"github.com/elastic/beats/v7/packetbeat/pb"
"github.com/elastic/beats/v7/packetbeat/procs"
"github.com/elastic/beats/v7/packetbeat/protos"
"github.com/elastic/beats/v7/packetbeat/protos/tcp"
"go.mongodb.org/mongo-driver/bson/primitive"
)
// debugf is this package's debug logger, created via logp.MakeDebug with the
// "mongodb" selector.
var debugf = logp.MakeDebug("mongodb")
// mongodbPlugin is the packetbeat protocol plugin for MongoDB: it parses wire
// messages per TCP stream, correlates requests with responses through two
// caches, and publishes the resulting transactions as events.
type mongodbPlugin struct {
// config
ports []int // TCP ports to analyze; returned by GetPorts
sendRequest bool // when true, publish the full reconstructed request in "request"
sendResponse bool // when true, publish rendered response documents in "response"
maxDocs int // max response documents rendered; <= 0 means no limit
maxDocLength int // max bytes per rendered response document; <= 0 means no limit
requests *common.Cache // requests waiting for their response, keyed by transactionKey
responses *common.Cache // responses seen before their request, keyed by transactionKey
transactionTimeout time.Duration // cache timeout; also returned by ConnectionTimeout
results protos.Reporter // receives published events; nil disables publishing
watcher *procs.ProcessesWatcher // resolves process info for a TCP tuple
}
// transactionKey correlates a request with its response: the connection's
// hashable TCP tuple plus the request's requestID (on the request side) or the
// response's responseTo (on the response side).
type transactionKey struct {
tcp common.HashableTCPTuple
id int32
}
// unmatchedRequests counts requests that were replaced in the requests cache
// by a newer request with the same transactionKey before a response arrived.
var unmatchedRequests = monitoring.NewInt(nil, "mongodb.unmatched_requests")
// init registers New as the constructor for the "mongodb" protocol.
func init() {
protos.Register("mongodb", New)
}
// New is the constructor registered for the "mongodb" protocol.
//
// Outside of test mode the user settings from cfg are unpacked on top of
// defaultConfig; in test mode defaultConfig is used as-is. Any unpacking or
// initialization error is returned and no plugin is created.
func New(
	testMode bool,
	results protos.Reporter,
	watcher *procs.ProcessesWatcher,
	cfg *conf.C,
) (protos.Plugin, error) {
	settings := defaultConfig
	if !testMode {
		if err := cfg.Unpack(&settings); err != nil {
			return nil, err
		}
	}

	plugin := &mongodbPlugin{}
	if err := plugin.init(results, watcher, &settings); err != nil {
		return nil, err
	}
	return plugin, nil
}
// init applies config to the plugin and creates the request and response
// correlation caches. Both caches use transactionTimeout and the default
// transaction hash size, and each gets a janitor started with
// transactionTimeout. It also stores the reporter and process watcher.
// The returned error is currently always nil.
func (mongodb *mongodbPlugin) init(results protos.Reporter, watcher *procs.ProcessesWatcher, config *mongodbConfig) error {
	debugf("Init a MongoDB protocol parser")
	mongodb.setFromConfig(config)

	timeout := mongodb.transactionTimeout

	mongodb.requests = common.NewCache(timeout, protos.DefaultTransactionHashSize)
	mongodb.requests.StartJanitor(timeout)

	mongodb.responses = common.NewCache(timeout, protos.DefaultTransactionHashSize)
	mongodb.responses.StartJanitor(timeout)

	mongodb.results, mongodb.watcher = results, watcher
	return nil
}
// setFromConfig copies the user-configurable settings from config onto the
// plugin. It does not create the caches; init does that afterwards.
func (mongodb *mongodbPlugin) setFromConfig(config *mongodbConfig) {
	mongodb.transactionTimeout = config.TransactionTimeout
	mongodb.ports = config.Ports

	// Output toggles.
	mongodb.sendRequest, mongodb.sendResponse = config.SendRequest, config.SendResponse

	// Response rendering limits.
	mongodb.maxDocs, mongodb.maxDocLength = config.MaxDocs, config.MaxDocLength
}
// GetPorts returns the TCP ports configured for MongoDB traffic.
func (mongodb *mongodbPlugin) GetPorts() []int {
return mongodb.ports
}
// ConnectionTimeout returns the configured transaction timeout, which this
// plugin also uses as its connection timeout.
func (mongodb *mongodbPlugin) ConnectionTimeout() time.Duration {
return mongodb.transactionTimeout
}
// Parse feeds one TCP segment into the connection's parser state and returns
// the state to keep for the next segment of this connection.
func (mongodb *mongodbPlugin) Parse(
	pkt *protos.Packet,
	tcptuple *common.TCPTuple,
	dir uint8,
	private protos.ProtocolData,
) protos.ProtocolData {
	debugf("Parse method triggered")

	state := mongodb.doParse(ensureMongodbConnection(private), pkt, tcptuple, dir)

	// Return an untyped nil: a nil *mongodbConnectionData wrapped in the
	// ProtocolData interface would not compare equal to nil.
	if state == nil {
		return nil
	}
	return state
}
// ensureMongodbConnection returns the *mongodbConnectionData held in private.
// It returns a fresh, empty state when private is nil, holds a different
// type (logged as a warning), or holds a nil pointer (logged at debug level).
func ensureMongodbConnection(private protos.ProtocolData) *mongodbConnectionData {
	if private == nil {
		return &mongodbConnectionData{}
	}

	switch priv := private.(type) {
	case *mongodbConnectionData:
		if priv != nil {
			return priv
		}
		debugf("Unexpected: mongodb connection data not set, create new one")
	default:
		logp.Warn("mongodb connection data type error, create new one")
	}
	return &mongodbConnectionData{}
}
// doParse appends the segment's payload to the stream buffer for direction
// dir and parses as many complete MongoDB messages from it as possible,
// passing each one to handleMongodb.
//
// The stream for dir is discarded (set to nil) when the buffered data exceeds
// tcp.TCPMaxDataInStream or when a message fails to parse. An incomplete
// trailing message stays buffered until more data arrives. The connection
// state is always returned, even when a stream was dropped.
func (mongodb *mongodbPlugin) doParse(
conn *mongodbConnectionData,
pkt *protos.Packet,
tcptuple *common.TCPTuple,
dir uint8,
) *mongodbConnectionData {
st := conn.streams[dir]
if st == nil {
// First segment seen in this direction: the new stream's buffer starts
// as this segment's payload.
st = newStream(pkt, tcptuple)
conn.streams[dir] = st
debugf("new stream: %p (dir=%v, len=%v)", st, dir, len(pkt.Payload))
} else {
// concatenate bytes
st.data = append(st.data, pkt.Payload...)
// Bound per-stream memory: give up on this direction once too much
// unparsed data has accumulated.
if len(st.data) > tcp.TCPMaxDataInStream {
debugf("Stream data too large, dropping TCP stream")
conn.streams[dir] = nil
return conn
}
}
// A single segment may carry several messages, so keep parsing until the
// buffer is empty or only a partial message remains.
for len(st.data) > 0 {
if st.message == nil {
// The message timestamp is the arrival time of the segment in which
// the message starts.
st.message = &mongodbMessage{ts: pkt.Ts}
}
ok, complete := mongodbMessageParser(st)
if !ok {
// drop this tcp stream. Will retry parsing with the next
// segment in it
conn.streams[dir] = nil
debugf("Ignore Mongodb message. Drop tcp stream. Try parsing with the next segment")
return conn
}
if !complete {
// wait for more data
debugf("MongoDB wait for more data before parsing message")
break
}
// all ok, go to next level and reset stream for new message
debugf("MongoDB message complete")
mongodb.handleMongodb(conn, st.message, tcptuple, dir)
st.PrepareForNewMessage()
}
return conn
}
// newStream starts a stream whose buffer is the packet's payload. It already
// holds an empty message stamped with the packet's timestamp.
func newStream(pkt *protos.Packet, tcptuple *common.TCPTuple) *stream {
	return &stream{
		tcptuple: tcptuple,
		message:  &mongodbMessage{ts: pkt.Ts},
		data:     pkt.Payload,
	}
}
// handleMongodb stamps a completely parsed message with its TCP tuple,
// direction and the process tuple resolved by the watcher. It then routes the
// message to onRequest or onResponse based on m.isResponse.
func (mongodb *mongodbPlugin) handleMongodb(
	conn *mongodbConnectionData,
	m *mongodbMessage,
	tcptuple *common.TCPTuple,
	dir uint8,
) {
	m.tcpTuple = *tcptuple
	m.direction = dir
	m.cmdlineTuple = mongodb.watcher.FindProcessesTupleTCP(tcptuple.IPPort())

	if !m.isResponse {
		debugf("MongoDB request message")
		mongodb.onRequest(conn, m)
		return
	}

	debugf("MongoDB response message")
	mongodb.onResponse(conn, m)
}
// onRequest handles a parsed request message.
//
// A request that does not await a reply (per awaitsReply) is published
// immediately as a request-only transaction. Otherwise, a response that
// arrived earlier with a matching key completes the transaction. If there is
// none, the request is stored in the requests cache. A previous request with
// the same key is replaced and counted in unmatchedRequests.
func (mongodb *mongodbPlugin) onRequest(conn *mongodbConnectionData, msg *mongodbMessage) {
	if !awaitsReply(msg) {
		// publish request only transaction
		mongodb.onTransComplete(msg, nil)
		return
	}

	key := transactionKey{
		tcp: msg.tcpTuple.Hashable(),
		id:  msg.requestID,
	}

	// The response may have been seen first; if so, complete right away.
	if early := mongodb.responses.Delete(key); early != nil {
		mongodb.onTransComplete(msg, early.(*mongodbMessage))
		return
	}

	// insert into cache for correlation
	if displaced := mongodb.requests.Put(key, msg); displaced != nil {
		debugf("Two requests without a Response. Dropping old request")
		unmatchedRequests.Add(1)
	}
}
// onResponse looks up the pending request keyed by the connection's TCP tuple
// and the response's responseTo ID. If it is found, the transaction is
// published. Otherwise the response is stored in the responses cache so that
// onRequest can complete it when the request shows up.
func (mongodb *mongodbPlugin) onResponse(conn *mongodbConnectionData, msg *mongodbMessage) {
	key := transactionKey{
		tcp: msg.tcpTuple.Hashable(),
		id:  msg.responseTo,
	}

	pending := mongodb.requests.Delete(key)
	if pending == nil {
		// insert into cache for correlation
		mongodb.responses.Put(key, msg)
		return
	}
	mongodb.onTransComplete(pending.(*mongodbMessage), msg)
}
// onTransComplete builds a transaction from a request and its response
// (resp is nil for request-only transactions) and publishes it.
func (mongodb *mongodbPlugin) onTransComplete(requ, resp *mongodbMessage) {
	t := newTransaction(requ, resp)
	debugf("Mongodb transaction completed: %s", t.mongodb)
	mongodb.publishTransaction(t)
}
// newTransaction merges a request and an optional response into a transaction.
// Either argument may be nil.
//
// The request supplies the event map, method, process tuple, start timestamp,
// endpoints (swapped when the request travelled in the reverse TCP direction),
// params, resource, request documents and inbound byte count. The response
// adds its event fields (written into the request's event map), error,
// documents (overriding the request's), end timestamp and outbound byte count.
func newTransaction(requ, resp *mongodbMessage) *transaction {
	trans := &transaction{}

	// fill request
	if requ != nil {
		trans.mongodb = mapstr.M{}
		trans.event = requ.event
		trans.method = requ.method

		trans.cmdline = requ.cmdlineTuple
		trans.ts = requ.ts
		trans.src, trans.dst = common.MakeEndpointPair(requ.tcpTuple.BaseTuple, requ.cmdlineTuple)
		if requ.direction == tcp.TCPDirectionReverse {
			trans.src, trans.dst = trans.dst, trans.src
		}
		trans.params = requ.params
		trans.resource = requ.resource
		trans.bytesIn = int(requ.messageLength)
		trans.documents = requ.documents
		trans.requestDocuments = requ.documents // preserving request documents that contains mongodb query for the new OP_MSG based protocol
	}

	// fill response
	if resp != nil {
		// With no request (requ == nil), or a request without an event map,
		// trans.event is still nil here. Writing into it would panic, so
		// allocate it first.
		if trans.event == nil {
			trans.event = mapstr.M{}
		}
		for k, v := range resp.event {
			trans.event[k] = v
		}

		trans.error = resp.error
		trans.documents = resp.documents

		trans.endTime = resp.ts
		trans.bytesOut = int(resp.messageLength)
	}

	return trans
}
// GapInStream is invoked when nbytes of stream data are missing. The
// connection state is returned unchanged together with drop=true.
// NOTE(review): drop=true presumably tells the TCP layer to discard the current
// stream, since a gap leaves MongoDB framing unrecoverable; confirm against
// the protos TCP plugin contract.
func (mongodb *mongodbPlugin) GapInStream(tcptuple *common.TCPTuple, dir uint8,
nbytes int, private protos.ProtocolData) (priv protos.ProtocolData, drop bool) {
return private, true
}
// ReceivedFin is invoked when a FIN is seen in direction dir. This plugin has
// no MongoDB-specific handling for it and returns the connection state unchanged.
func (mongodb *mongodbPlugin) ReceivedFin(tcptuple *common.TCPTuple, dir uint8,
private protos.ProtocolData) protos.ProtocolData {
return private
}
// copyMapWithoutKey returns a shallow copy of d that omits every key listed in
// keys. The result is always a new, non-nil map (empty for a nil or empty d);
// d itself is never modified.
func copyMapWithoutKey(d map[string]interface{}, keys ...string) map[string]interface{} {
	excluded := make(map[string]struct{}, len(keys))
	for _, key := range keys {
		excluded[key] = struct{}{}
	}

	out := make(map[string]interface{}, len(d))
	for key, value := range d {
		if _, skip := excluded[key]; skip {
			continue
		}
		out[key] = value
	}
	return out
}
// reconstructQuery renders a transaction as a shell-like query string of the
// form "<resource>.<method>(<document>)". It appends ".skip(n)" and
// ".limit(n)" when the numberToSkip/numberToReturn event fields hold positive
// ints (a limit of 0x7fffffff or more is omitted).
//
// The document is t.params when present. Otherwise it is the first OP_MSG
// request document, with "lsid" removed, when that document is a primitive.M.
// With full == false, the data payload is stripped so that only the query
// shape remains:
//   - "documents" for insert,
//   - "updates" for update,
//   - "update" for findandmodify,
//   - "documents" for OP_MSG documents.
// All other commands keep every parameter. If the document cannot be
// marshalled, the parentheses are left empty and the error is debug-logged.
func reconstructQuery(t *transaction, full bool) (query string) {
	query = t.resource + "." + t.method + "("
	var doc interface{}

	if len(t.params) > 0 {
		// Default to the complete parameter set; only commands that carry
		// bulk data have it stripped below. Without this default, commands
		// such as find or delete lost all of their params when !full.
		doc = t.params
		if !full {
			// remove the actual data.
			// TODO: review if we need to add other commands here
			switch t.method {
			case "insert":
				doc = copyMapWithoutKey(t.params, "documents")
			case "update":
				doc = copyMapWithoutKey(t.params, "updates")
			case "findandmodify":
				doc = copyMapWithoutKey(t.params, "update")
			}
		}
	} else if len(t.requestDocuments) > 0 { // This recovers the query document from OP_MSG
		if m, ok := t.requestDocuments[0].(primitive.M); ok {
			excludeKeys := []string{"lsid"}
			if !full {
				excludeKeys = append(excludeKeys, "documents")
			}
			doc = copyMapWithoutKey(m, excludeKeys...)
		}
	}

	queryString, err := doc2str(doc)
	if err != nil {
		debugf("Error marshaling query document: %v", err)
	} else {
		query += queryString
	}
	query += ")"

	skip, _ := t.event["numberToSkip"].(int)
	if skip > 0 {
		query += fmt.Sprintf(".skip(%d)", skip)
	}

	limit, _ := t.event["numberToReturn"].(int)
	if limit > 0 && limit < 0x7fffffff {
		query += fmt.Sprintf(".limit(%d)", limit)
	}
	return query
}
// publishTransaction converts a completed transaction into a beat event and
// passes it to the configured reporter. Nothing is published when no reporter
// is set.
//
// The event carries:
//   - endpoints with inbound/outbound byte counts,
//   - start and end timestamps,
//   - an OK/ERROR status (the error text is also stored under mongodb.error),
//   - the MongoDB event fields, method and resource,
//   - the data-stripped query.
//
// Optionally it also carries the full request (sendRequest) and the rendered
// response documents (sendResponse). Those documents are limited by maxDocs
// and maxDocLength.
func (mongodb *mongodbPlugin) publishTransaction(t *transaction) {
	if mongodb.results == nil {
		debugf("Try to publish transaction with null results")
		return
	}

	evt, pbf := pb.NewBeatEvent(t.ts)
	pbf.SetSource(&t.src)
	pbf.AddIP(t.src.IP)
	pbf.SetDestination(&t.dst)
	pbf.AddIP(t.dst.IP)
	pbf.Source.Bytes = int64(t.bytesIn)
	pbf.Destination.Bytes = int64(t.bytesOut)
	pbf.Event.Dataset = "mongodb"
	pbf.Event.Start = t.ts
	pbf.Event.End = t.endTime
	pbf.Network.Transport = "tcp"
	pbf.Network.Protocol = pbf.Event.Dataset

	fields := evt.Fields
	fields["type"] = pbf.Event.Dataset
	if t.error == "" {
		fields["status"] = common.OK_STATUS
	} else {
		t.event["error"] = t.error
		fields["status"] = common.ERROR_STATUS
	}
	fields["mongodb"] = t.event
	fields["method"] = t.method
	fields["resource"] = t.resource
	fields["query"] = reconstructQuery(t, false)
	if mongodb.sendRequest {
		fields["request"] = reconstructQuery(t, true)
	}

	if mongodb.sendResponse && len(t.documents) > 0 {
		// response field needs to be a string
		docs := make([]string, 0, len(t.documents))
		for i, doc := range t.documents {
			if mongodb.maxDocs > 0 && i >= mongodb.maxDocs {
				docs = append(docs, "[...]")
				break
			}
			str, err := doc2str(doc)
			if err != nil {
				logp.Warn("Failed to JSON marshal document from Mongo: %v (error: %v)", doc, err)
				continue
			}
			docs = append(docs, truncateDocument(str, mongodb.maxDocLength))
		}
		fields["response"] = strings.Join(docs, "\n")
	}

	mongodb.results(evt)
}

// truncateDocument shortens s to at most maxLen bytes and appends " ..." when
// s is longer than maxLen. A non-positive maxLen disables truncation.
//
// Slicing at an arbitrary byte offset could split a multi-byte UTF-8 character
// and leave invalid UTF-8 in the published event. The cut is therefore moved
// back (at most 3 bytes, the longest UTF-8 continuation run) to the start of
// the character that straddles the limit.
func truncateDocument(s string, maxLen int) string {
	if maxLen <= 0 || len(s) <= maxLen {
		return s
	}
	cut := maxLen
	// Bytes of the form 0b10xxxxxx are UTF-8 continuation bytes.
	for cut > 0 && maxLen-cut < 3 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + " ..."
}