tools/mongodb-hybrid-dlp/dlpfunction.go (597 lines of code) (raw):
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dlpfunction
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"slices"
"strings"
"sync"
"time"
"github.com/GoogleCloudPlatform/functions-framework-go/functions"
"google.golang.org/api/option"
"go.mongodb.org/mongo-driver/v2/bson"
"go.mongodb.org/mongo-driver/v2/mongo"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/mongo/readpref"
dlp "cloud.google.com/go/dlp/apiv2"
"cloud.google.com/go/dlp/apiv2/dlppb"
"cloud.google.com/go/storage"
"github.com/rs/zerolog/log"
gocache "github.com/TwiN/gocache/v2"
)
// MongoSource identifies one change stream source. init fills exactly one shape per
// configured entry: Deployment only (MONGO_DEPLOYMENTS), Database only (MONGO_DATABASES),
// or Database plus Collection (MONGO_COLLECTIONS). It is the key of MongoScanner.LastChanges.
type MongoSource struct {
Collection string `json:"collection"`
Database string `json:"database"`
Deployment string `json:"deployment"`
}
// MongoChangeSource pairs a change stream source with the last resume token recorded for
// it. A JSON array of these is what LoadState reads from, and SaveState writes to, the
// STATE_FILE object.
type MongoChangeSource struct {
Source MongoSource `json:"source"`
// ResumeToken is the "_data" string of the change stream resume token; when empty the
// stream is opened without resuming.
ResumeToken string `json:"resume_token"`
}
// MongoScanner holds the configuration read by init and the runtime state shared by the
// scan cycles run through Process. The package uses a single instance (mongoScanner).
type MongoScanner struct {
// RunPeriod is how long one Process call keeps the change streams open
// (RUN_PERIOD, default 10m).
RunPeriod time.Duration `json:"-"`
// Collections, Databases and Deployments are the change stream sources parsed from
// MONGO_COLLECTIONS, MONGO_DATABASES and MONGO_DEPLOYMENTS respectively.
Collections []MongoChangeSource `json:"-"`
Databases []MongoChangeSource `json:"-"`
Deployments []MongoChangeSource `json:"-"`
// ConnectionString is MONGO_CONNECTION_STRING; it is also copied into every MongoChange.
ConnectionString string `json:"-"`
// Username and Password (MONGO_USERNAME / MONGO_PASSWORD) are applied as client
// credentials when Username is non-empty.
Username string `json:"-"`
Password string `json:"-"`
// Client is the MongoDB client created by connect.
Client *mongo.Client `json:"-"`
// GcpBillingProject is PROJECT_ID; it is used as the DLP quota project.
GcpBillingProject string `json:"-"`
// GcpDlpEndpoint optionally overrides the DLP API endpoint (DLP_ENDPOINT).
GcpDlpEndpoint string `json:"-"`
// GcpDlpTriggerName is the DLP job trigger (DLP_TRIGGER_NAME) that is activated and
// receives the hybrid inspection items.
GcpDlpTriggerName string `json:"-"`
// GcpDlpTriggerJobId is the name of the trigger's active DLP job; Process finishes it
// at the end of a run.
GcpDlpTriggerJobId string `json:"-"`
// StateFile is the gs:// location (STATE_FILE) where resume tokens are persisted.
StateFile *url.URL `json:"-"`
// DlpClient is created lazily by HybridInspect.
DlpClient *dlp.Client `json:"-"`
// DlpJobActive records that a job for the trigger was activated or found running.
DlpJobActive bool `json:"-"`
// Cache holds SHA-256 hashes of redacted documents already accepted by DLP
// (a 16 MB LRU cache created by Process).
Cache *gocache.Cache `json:"-"`
// LastChanges maps each source to its latest resume token; it is filled by LoadState
// and InspectChanges and persisted by SaveState.
LastChanges map[MongoSource]MongoChangeSource `json:"last_changes"`
}
// MongoChange is one event received from a change stream, passed from
// ProcessChangeStream to InspectChanges over a channel.
type MongoChange struct {
// Source is the change stream source the event was read from.
Source MongoSource
// ConnectionString is the MongoDB connection string; DlpFullPath builds the DLP
// container full path from it. The `json:"-"` tags on this field and Document
// presumably keep them out of the JSON-encoded log fields — confirm.
ConnectionString string `json:"-"`
// Document is the raw change event document.
Document *bson.Raw `json:"-"`
// ResumeToken is the decoded resume token of this event: a map[string]interface{}
// when decoding succeeded, nil otherwise.
ResumeToken interface{}
}
var (
	// fieldsToRemove lists top-level document fields stripped before inspection and
	// hashing, so that documents differing only in these fields share a cache entry.
	fieldsToRemove = []string{"_id"}

	// mongoScanner is the process-wide scanner configured by init and driven by
	// DLPFunctionHTTP.
	mongoScanner MongoScanner
)
// DlpType returns the container type reported to Cloud DLP for MongoDB changes.
func (c MongoChange) DlpType() string {
	const containerType = "mongodb"
	return containerType
}
// DlpVersion returns the container version string reported to Cloud DLP.
func (c MongoChange) DlpVersion() string {
	const containerVersion = "1.0"
	return containerVersion
}
// DlpFullPath returns the DLP container full path for the change: the connection string
// (with any "user:password@" credentials removed, so they never reach DLP findings),
// joined with DlpRelativePath by exactly one "/".
func (c MongoChange) DlpFullPath() string {
	base := stripURICredentials(c.ConnectionString)
	if !strings.HasSuffix(base, "/") {
		base += "/"
	}
	return base + c.DlpRelativePath()
}

// stripURICredentials removes the userinfo section ("user:password@") from the authority
// of a "scheme://authority[/path][?query]" string. Strings without "://" or without
// userinfo are returned unchanged. MongoDB URIs require '@', '/' and '?' inside
// credentials to be percent-encoded, so the authority ends at the first '/' or '?'.
func stripURICredentials(uri string) string {
	scheme, rest, ok := strings.Cut(uri, "://")
	if !ok {
		return uri
	}
	authorityEnd := strings.IndexAny(rest, "/?")
	if authorityEnd < 0 {
		authorityEnd = len(rest)
	}
	if at := strings.LastIndex(rest[:authorityEnd], "@"); at >= 0 {
		rest = rest[at+1:]
	}
	return scheme + "://" + rest
}
// DlpRelativePath returns the source path without a leading slash: "<db>/<collection>"
// for collection sources, "<db>" for database sources, and the deployment name otherwise.
func (c MongoChange) DlpRelativePath() string {
	src := c.Source
	switch {
	case src.Collection != "":
		return src.Database + "/" + src.Collection
	case src.Database != "":
		return src.Database
	default:
		return src.Deployment
	}
}
// DlpRootPath returns the root of the source path without a leading slash: the database
// name for collection and database sources, and the deployment name otherwise.
func (c MongoChange) DlpRootPath() string {
	src := c.Source
	switch {
	case src.Collection != "", src.Database != "":
		return src.Database
	default:
		return src.Deployment
	}
}
// init reads the function's configuration from the environment into mongoScanner and
// registers DLPFunctionHTTP with the Functions Framework. A missing or invalid required
// setting is reported with log.Fatal.
//
// Environment variables:
//   - RUN_PERIOD: length of one scan cycle (time.ParseDuration syntax, default 10m).
//   - MONGO_CONNECTION_STRING (required), MONGO_USERNAME, MONGO_PASSWORD.
//   - MONGO_DEPLOYMENTS, MONGO_DATABASES, MONGO_COLLECTIONS: comma-separated sources, at
//     least one overall is required; collections are written as <database>.<collection>.
//   - DLP_TRIGGER_NAME (required), PROJECT_ID (required), DLP_ENDPOINT (optional).
//   - STATE_FILE (required): gs://<bucket>/<object> where resume tokens are stored.
func init() {
	mongoScanner = MongoScanner{
		LastChanges: make(map[MongoSource]MongoChangeSource),
	}

	// splitList splits a comma-separated value, trimming whitespace and dropping empty
	// entries (such as those produced by a trailing comma).
	splitList := func(value string) []string {
		var items []string
		for _, item := range strings.Split(value, ",") {
			if item = strings.TrimSpace(item); item != "" {
				items = append(items, item)
			}
		}
		return items
	}

	mongoScanner.RunPeriod = 10 * time.Minute
	if runPeriodEnv := os.Getenv("RUN_PERIOD"); runPeriodEnv != "" {
		runPeriod, err := time.ParseDuration(runPeriodEnv)
		if err != nil {
			log.Fatal().Err(err).Msgf("Failed to parse RUN_PERIOD: %s", runPeriodEnv)
		}
		mongoScanner.RunPeriod = runPeriod
	}

	mongoScanner.ConnectionString = os.Getenv("MONGO_CONNECTION_STRING")
	if mongoScanner.ConnectionString == "" {
		log.Fatal().Msg("No MONGO_CONNECTION_STRING specified!")
	}
	mongoScanner.Username = os.Getenv("MONGO_USERNAME")
	mongoScanner.Password = os.Getenv("MONGO_PASSWORD")

	for _, deployment := range splitList(os.Getenv("MONGO_DEPLOYMENTS")) {
		mongoScanner.Deployments = append(mongoScanner.Deployments, MongoChangeSource{Source: MongoSource{Deployment: deployment}})
	}
	for _, database := range splitList(os.Getenv("MONGO_DATABASES")) {
		mongoScanner.Databases = append(mongoScanner.Databases, MongoChangeSource{Source: MongoSource{Database: database}})
	}
	for _, spec := range splitList(os.Getenv("MONGO_COLLECTIONS")) {
		// Split on the first dot only: collection names may themselves contain dots.
		database, collection, found := strings.Cut(spec, ".")
		database, collection = strings.TrimSpace(database), strings.TrimSpace(collection)
		if !found || database == "" || collection == "" {
			log.Fatal().Msgf("Invalid MONGO_COLLECTIONS entry %q, expected <database>.<collection>", spec)
		}
		mongoScanner.Collections = append(mongoScanner.Collections, MongoChangeSource{Source: MongoSource{Database: database, Collection: collection}})
	}
	if len(mongoScanner.Deployments) == 0 && len(mongoScanner.Databases) == 0 && len(mongoScanner.Collections) == 0 {
		log.Fatal().Msg("No sources for change streams specified, set at least one of: MONGO_COLLECTIONS, MONGO_DATABASES, MONGO_DEPLOYMENTS")
	}

	mongoScanner.GcpDlpTriggerName = os.Getenv("DLP_TRIGGER_NAME")
	if mongoScanner.GcpDlpTriggerName == "" {
		log.Fatal().Msg("No DLP_TRIGGER_NAME specified!")
	}
	mongoScanner.GcpBillingProject = os.Getenv("PROJECT_ID")
	if mongoScanner.GcpBillingProject == "" {
		log.Fatal().Msg("No PROJECT_ID specified!")
	}
	mongoScanner.GcpDlpEndpoint = os.Getenv("DLP_ENDPOINT")

	stateFile := os.Getenv("STATE_FILE")
	if stateFile == "" {
		log.Fatal().Msg("No STATE_FILE specified!")
	}
	if !strings.HasPrefix(stateFile, "gs://") {
		log.Fatal().Msg("State file should start with gs://!")
	}
	stateURL, err := url.Parse(stateFile)
	if err != nil {
		log.Fatal().Err(err).Msg("Failed to parse STATE_FILE location!")
	}
	// Both the bucket (Host) and the object name (Path) are needed by LoadState/SaveState.
	if stateURL.Host == "" || strings.TrimPrefix(stateURL.Path, "/") == "" {
		log.Fatal().Msg("STATE_FILE must look like gs://<bucket>/<object>!")
	}
	mongoScanner.StateFile = stateURL

	functions.HTTP("DLPFunctionHTTP", DLPFunctionHTTP)
}
// connect creates the MongoDB client from s.ConnectionString, adding explicit credentials
// when s.Username is set, and stores it in s.Client. On error s.Client is left unchanged
// and the error is returned.
//
// ctx is currently unused: the driver's mongo.Connect, as called here, takes no context.
func (s *MongoScanner) connect(ctx context.Context) error {
	clientOptions := options.Client().ApplyURI(s.ConnectionString)
	if s.Username != "" {
		clientOptions.SetAuth(options.Credential{Username: s.Username, Password: s.Password})
	}
	client, err := mongo.Connect(clientOptions)
	if err != nil {
		return fmt.Errorf("connecting to MongoDB: %w", err)
	}
	s.Client = client
	return nil
}
// disconnect closes s.Client if one exists. A failed disconnect is logged rather than
// panicking, so a cleanup problem cannot crash the function instance.
func (s *MongoScanner) disconnect(ctx context.Context) {
	if s.Client == nil {
		return
	}
	if err := s.Client.Disconnect(ctx); err != nil {
		log.Error().Err(err).Msg("Failed to disconnect from MongoDB")
	}
}
// ProcessChangeStream forwards every event read from cs to change until ctx is cancelled
// or the stream reports an error.
//
// Each forwarded MongoChange is a copy of base that carries its own copy of the raw event.
// The copy is needed because cs.Current is replaced by the next call to Next, while the
// consumer reads the document on another goroutine. When the resume token can be decoded,
// the MongoChange also carries it. Decoding and stream errors are logged and offered to
// errors without blocking. A stream error ends processing. cs is closed on every return path.
//
// resumeToken is kept for interface compatibility and is not used.
func (s *MongoScanner) ProcessChangeStream(ctx context.Context, cs *mongo.ChangeStream, base MongoChange, change chan<- MongoChange, resumeToken chan<- string, errors chan<- error) {
	log.Info().Interface("source", base).Msg("Starting to process change stream (this happens for each change stream)...")
	defer func() {
		// ctx is normally cancelled by the time we get here; use a fresh, bounded
		// context so Close is not attempted with an already-dead context.
		closeCtx, closeCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer closeCancel()
		if err := cs.Close(closeCtx); err != nil {
			log.Debug().Err(err).Msg("Closing change stream failed")
		}
	}()

	// report offers err to errors without blocking. Every error passed here has already
	// been logged, so dropping it when nobody is reading loses nothing.
	report := func(err error) {
		select {
		case errors <- err:
		default:
		}
	}

	for {
		if ctx.Err() != nil {
			return
		}
		if !cs.Next(ctx) {
			if ctx.Err() != nil {
				return
			}
			if err := cs.Err(); err != nil {
				log.Error().Err(err).Msg("Received error from change stream iterator")
				report(err)
				return
			}
			continue
		}

		newChange := base
		document := slices.Clone(cs.Current)
		newChange.Document = &document
		log.Debug().Interface("change", newChange).Msg("Received change from MongoDB")

		var token map[string]interface{}
		if err := bson.Unmarshal(cs.ResumeToken(), &token); err != nil {
			// Lets not some unmarshalling errors stop us entirely...
			log.Error().Err(err).Msg("Failed unmarshaling incoming resume token")
			report(err)
		} else {
			newChange.ResumeToken = token
		}

		select {
		case change <- newChange:
		case <-ctx.Done():
			// The event was never handed over, so its resume token was not recorded.
			return
		}
		log.Info().Interface("change", newChange).Msg("Emitting changed document")
	}
}
// HybridInspect sends one redacted change document to Cloud DLP as a hybrid content item
// for the configured job trigger.
//
// On first use it creates s.DlpClient, with the tool's user agent, the optional endpoint
// override and the billing project as quota project. If no job is known to be active, it
// calls activateDlpJob. Documents are deduplicated by the SHA-256 of their canonical
// extended JSON. A hash already in s.Cache is skipped, and a hash is added to the cache
// only after DLP accepted the item.
//
// change supplies the container metadata attached to findings, and redacted is the
// document that is inspected. original (the full parsed change event) is currently unused.
func (s *MongoScanner) HybridInspect(ctx context.Context, change MongoChange, original map[string]interface{}, redacted bson.D) error {
	if s.DlpClient == nil {
		clientOpts := []option.ClientOption{option.WithUserAgent("google-pso-tool/mongodb-hybrid-dlp/0.1.0")}
		if s.GcpDlpEndpoint != "" {
			clientOpts = append(clientOpts, option.WithEndpoint(s.GcpDlpEndpoint))
		}
		if s.GcpBillingProject != "" {
			clientOpts = append(clientOpts, option.WithQuotaProject(s.GcpBillingProject))
		}
		client, err := dlp.NewClient(ctx, clientOpts...)
		if err != nil {
			return err
		}
		s.DlpClient = client
	}

	marshaledDoc, err := bson.MarshalExtJSON(redacted, true, false)
	if err != nil {
		log.Error().Err(err).Msg("Failed to marshal redacted document")
		return err
	}
	digest := sha256.Sum256(marshaledDoc)
	hashSum := fmt.Sprintf("%x", digest[:])
	if _, exists := s.Cache.Get(hashSum); exists {
		log.Info().Str("hash", hashSum).Msg("Document already processed")
		return nil
	}

	hybridContentItem := &dlppb.HybridContentItem{
		Item: &dlppb.ContentItem{
			DataItem: &dlppb.ContentItem_Value{
				Value: string(marshaledDoc),
			},
		},
		FindingDetails: &dlppb.HybridFindingDetails{
			ContainerDetails: &dlppb.Container{
				Type:         change.DlpType(),
				FullPath:     change.DlpFullPath(),
				RelativePath: "/" + change.DlpRelativePath(),
				RootPath:     "/" + change.DlpRootPath(),
				Version:      change.DlpVersion(),
			},
			Labels: map[string]string{},
		},
	}

	if !s.DlpJobActive {
		if err := s.activateDlpJob(ctx); err != nil {
			return err
		}
	}

	req := &dlppb.HybridInspectJobTriggerRequest{
		Name:       s.GcpDlpTriggerName,
		HybridItem: hybridContentItem,
	}
	// Send the hybrid inspect request.
	if _, err := s.DlpClient.HybridInspectJobTrigger(ctx, req); err != nil {
		return err
	}
	s.Cache.Set(hashSum, true)
	return nil
}

// activateDlpJob activates s.GcpDlpTriggerName, or, when its job is already running, looks
// the running job up, and records the job name in s.GcpDlpTriggerJobId. On success (or
// "already running") s.DlpJobActive is set. Any other activation error is returned.
func (s *MongoScanner) activateDlpJob(ctx context.Context) error {
	log.Info().Str("triggerID", s.GcpDlpTriggerName).Msg("Activating DLP job...")
	activateRes, err := s.DlpClient.ActivateJobTrigger(ctx, &dlppb.ActivateJobTriggerRequest{Name: s.GcpDlpTriggerName})
	if err == nil {
		s.GcpDlpTriggerJobId = activateRes.GetName()
		log.Info().Str("jobID", s.GcpDlpTriggerJobId).Msg("DLP trigger job activated")
		s.DlpJobActive = true
		return nil
	}
	if !strings.Contains(err.Error(), "already running") {
		log.Error().Err(err).Msg("DLP job activation failed")
		return err
	}
	s.DlpJobActive = true
	s.GcpDlpTriggerJobId = ""
	log.Warn().Str("triggerID", s.GcpDlpTriggerName).Msg("Job is already running")

	// Jobs live under the same parent as their trigger, which may include a location
	// ("projects/P/locations/L/jobTriggers/T"); a bare "projects/P" parent only lists
	// global jobs. Fall back to the billing project if the trigger name has another shape.
	parent := fmt.Sprintf("projects/%s", s.GcpBillingProject)
	if triggerParent, _, ok := strings.Cut(s.GcpDlpTriggerName, "/jobTriggers/"); ok && triggerParent != "" {
		parent = triggerParent
	}
	listReq := &dlppb.ListDlpJobsRequest{Parent: parent}
	for job, err := range s.DlpClient.ListDlpJobs(ctx, listReq).All() {
		if err != nil {
			log.Error().Err(err).Str("parent", parent).Msg("Listing DLP jobs failed")
			break
		}
		if job.GetJobTriggerName() == s.GcpDlpTriggerName && job.GetState() == dlppb.DlpJob_ACTIVE {
			s.GcpDlpTriggerJobId = job.GetName()
			log.Warn().Str("jobID", s.GcpDlpTriggerJobId).Msg("Found existing active job in ACTIVE state")
			break
		}
	}
	if s.GcpDlpTriggerJobId == "" {
		log.Warn().Str("triggerID", s.GcpDlpTriggerName).Msg("No ACTIVE job found for trigger; it will not be finished at the end of this run")
	}
	return nil
}
// InspectChanges consumes change events from changes and hands each one to inspectChange
// until ctx is cancelled or changes is closed. It blocks while waiting instead of
// spinning.
//
// Errors are offered to errors without blocking. If the channel is full, the error is
// logged and dropped, so an unread error channel can never stall inspection. On
// cancellation, events still buffered in changes are discarded without being inspected.
// Discarding them lets producers blocked on a full buffer exit. Their resume tokens are
// not recorded, so the next run resumes before them.
func (s *MongoScanner) InspectChanges(ctx context.Context, changes <-chan MongoChange, errors chan<- error) {
	report := func(err error) {
		select {
		case errors <- err:
		default:
			log.Warn().Err(err).Msg("Dropping inspection error, error channel is full")
		}
	}
	for {
		// Check cancellation first so a ready document never wins the race against shutdown.
		if ctx.Err() != nil {
			for {
				select {
				case _, ok := <-changes:
					if !ok {
						return
					}
				default:
					return
				}
			}
		}
		select {
		case <-ctx.Done():
			// Handled by the drain at the top of the loop.
		case doc, ok := <-changes:
			if !ok {
				return
			}
			s.inspectChange(ctx, doc, report)
		}
	}
}

// inspectChange records doc's resume token in s.LastChanges and submits its fullDocument,
// minus fieldsToRemove, to HybridInspect. Problems are passed to report instead of panicking.
func (s *MongoScanner) inspectChange(ctx context.Context, doc MongoChange, report func(error)) {
	if doc.Document == nil {
		report(fmt.Errorf("change from %s carries no document", doc.DlpRelativePath()))
		return
	}
	var parsedDoc map[string]interface{}
	if err := bson.Unmarshal(*doc.Document, &parsedDoc); err != nil {
		report(err)
		return
	}

	// ResumeToken is nil when ProcessChangeStream could not decode it.
	if token, ok := doc.ResumeToken.(map[string]interface{}); ok {
		if data, ok := token["_data"].(string); ok {
			s.LastChanges[doc.Source] = MongoChangeSource{
				Source:      doc.Source,
				ResumeToken: data,
			}
		}
	}

	longDocId := parsedDoc["_id"]
	// fullDocument is absent (e.g. deletes) or null (e.g. the document no longer exists).
	fullDocument, ok := parsedDoc["fullDocument"].(bson.D)
	if !ok {
		report(fmt.Errorf("Change was missing full document: %v", doc.Document))
		return
	}

	// Clean up the document to remove stuff that is different from every document even if
	// they are semantically equal, so we can cache hits
	redactedDocument := bson.D{}
	var shortDocId string
	for _, field := range fullDocument {
		if !slices.Contains(fieldsToRemove, field.Key) {
			redactedDocument = append(redactedDocument, field)
		}
		if field.Key == "_id" {
			// _id is not necessarily an ObjectID (strings, numbers, ... are allowed).
			if oid, isOID := field.Value.(bson.ObjectID); isOID {
				shortDocId = oid.Hex()
			} else {
				shortDocId = fmt.Sprint(field.Value)
			}
		}
	}
	log.Debug().Interface("docId", longDocId).Interface("objectID", shortDocId).Msg("Inspecting document")
	if err := s.HybridInspect(ctx, doc, parsedDoc, redactedDocument); err != nil {
		log.Error().Err(err).Msg("Hybrid inspection returned an error")
		report(err)
	}
}
// LoadState merges the resume tokens persisted in the STATE_FILE object into s.LastChanges.
//
// The object must contain a JSON array of MongoChangeSource values (the format SaveState
// writes). Each entry replaces any in-memory entry for the same source. The read is bounded
// by a 50 second timeout. Errors are returned unchanged, except that a failure to open the
// object is wrapped with its gs:// location.
func (s *MongoScanner) LoadState() error {
	bucket := s.StateFile.Host
	object := strings.TrimPrefix(s.StateFile.Path, "/")

	client, err := storage.NewClient(context.Background())
	if err != nil {
		return err
	}
	defer client.Close()

	readCtx, cancel := context.WithTimeout(context.Background(), 50*time.Second)
	defer cancel()

	reader, err := client.Bucket(bucket).Object(object).NewReader(readCtx)
	if err != nil {
		return fmt.Errorf("Error reading file gs://%s/%s: %w", bucket, object, err)
	}
	defer reader.Close()

	payload, err := io.ReadAll(reader)
	if err != nil {
		return err
	}
	var saved []MongoChangeSource
	if err := json.Unmarshal(payload, &saved); err != nil {
		return err
	}
	for _, entry := range saved {
		s.LastChanges[entry.Source] = entry
	}
	return nil
}
// SaveState writes s.LastChanges to the STATE_FILE object as a JSON array of
// MongoChangeSource values, the format LoadState reads.
//
// The payload is serialised before the upload is opened. The upload is bounded by a 60
// second timeout. If Write fails, the deferred cancel aborts the upload, so a partial
// object is not committed. Errors from Close, which commits the object, are wrapped.
func (s *MongoScanner) SaveState() error {
	bucket := s.StateFile.Host
	object := strings.TrimPrefix(s.StateFile.Path, "/")

	lastChangesList := make([]MongoChangeSource, 0, len(s.LastChanges))
	for _, v := range s.LastChanges {
		lastChangesList = append(lastChangesList, v)
	}
	jsonData, err := json.Marshal(lastChangesList)
	if err != nil {
		return err
	}

	client, err := storage.NewClient(context.Background())
	if err != nil {
		return err
	}
	defer client.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	log.Info().Str("bucket", bucket).Str("object", object).Int("sources", len(lastChangesList)).Msg("Saving state")
	wc := client.Bucket(bucket).Object(object).NewWriter(ctx)
	if _, err := wc.Write(jsonData); err != nil {
		return err
	}
	if err := wc.Close(); err != nil {
		return fmt.Errorf("Writer.Close: %w", err)
	}
	return nil
}
// Process runs one scan cycle; DLPFunctionHTTP calls it once per request.
//
// Steps:
//   - Connect to MongoDB, ping the primary, and ensure the dedup cache exists.
//   - Load the saved resume tokens.
//   - Start InspectChanges.
//   - Open one change stream per configured deployment, database and collection.
//   - Every 10 seconds, write and flush (when w supports it) a progress line to w, until
//     RunPeriod minus one 10-second cycle has elapsed.
//   - Stop every goroutine, save the resume tokens, and finish the active DLP job.
//
// After finishing the job, the DLP client is closed and cleared and the job state is reset.
// As a result, the next run on the same (warm) instance starts a fresh client and job.
//
// Returns setup errors (connect, ping, opening a stream), the state-save error, or the
// job-finishing error. Per-change errors are logged and do not fail the run.
func (s *MongoScanner) Process(w http.ResponseWriter) error {
	connectCtx, connectCancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer connectCancel()
	if err := s.connect(connectCtx); err != nil {
		log.Error().Err(err).Msg("Failed to create MongoDB client")
		return err
	}
	if s.Client == nil {
		return fmt.Errorf("no MongoDB client available after connect")
	}
	// connectCtx expires long before the run ends, so disconnect with a fresh, bounded context.
	defer func() {
		disconnectCtx, disconnectCancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer disconnectCancel()
		s.disconnect(disconnectCtx)
	}()

	pingCtx, pingCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer pingCancel()
	if err := s.Client.Ping(pingCtx, readpref.Primary()); err != nil {
		log.Error().Err(err).Msg("Great difficulties connecting to MongoDB server, probably you got no connectivity or bad connection string")
		return err
	}

	// 16MB cache with LRU eviction policy
	if s.Cache == nil {
		cacheSize := 16 * 1024 * 1024
		log.Info().Int("cacheSize", cacheSize).Msgf("Initialized cache: %d MB", cacheSize/(1024*1024))
		s.Cache = gocache.NewCache().WithMaxMemoryUsage(cacheSize).WithEvictionPolicy(gocache.LeastRecentlyUsed)
	}

	// Load resume tokens
	if err := s.LoadState(); err != nil {
		log.Warn().Err(err).Msg("Failed to load state from bucket, possibly no state yet.")
	}

	var wg sync.WaitGroup
	documents := make(chan MongoChange, 10)

	// All workers report into errCh. It is always drained, so a worker can never block
	// forever on an error send; the drainer exits once errCh is closed in shutdown.
	errCh := make(chan error, 16)
	errDrained := make(chan struct{})
	go func() {
		defer close(errDrained)
		for err := range errCh {
			log.Warn().Err(err).Msg("Error while processing changes")
		}
	}()

	// The inspector starts before any stream so every produced document has a consumer.
	inspectCtx, inspectCancel := context.WithCancel(context.Background())
	defer inspectCancel()
	wg.Add(1)
	go func() {
		defer wg.Done()
		s.InspectChanges(inspectCtx, documents, errCh)
	}()

	var cancels []context.CancelFunc
	// shutdown stops the streams first, then the inspector, and waits for all of them.
	shutdown := func() {
		for _, cancel := range cancels {
			cancel()
		}
		inspectCancel()
		wg.Wait()
		close(errCh)
		<-errDrained
	}

	// startStream consumes cs in its own goroutine until streamCtx is cancelled.
	startStream := func(streamCtx context.Context, cs *mongo.ChangeStream, src MongoSource) {
		base := MongoChange{
			Source:           src,
			ConnectionString: s.ConnectionString,
		}
		// ProcessChangeStream's resume-token channel is not read by anything here.
		resumeTokens := make(chan string, 1)
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.ProcessChangeStream(streamCtx, cs, base, documents, resumeTokens, errCh)
		}()
	}

	// There is really only one deployment, but we keep the parameters consistent
	for _, deployment := range s.Deployments {
		streamCtx, streamCancel := context.WithCancel(context.Background())
		cancels = append(cancels, streamCancel)
		streamOpts := options.ChangeStream()
		if token := s.LastChanges[deployment.Source].ResumeToken; token != "" {
			log.Info().Str("resumeToken", token).Msg("Using resume token for deployment")
			streamOpts.SetResumeAfter(bson.M{"_data": token})
		}
		cs, err := s.Client.Watch(streamCtx, mongo.Pipeline{}, streamOpts)
		if err != nil {
			shutdown()
			return err
		}
		startStream(streamCtx, cs, deployment.Source)
	}

	// Watch entire databases
	for _, database := range s.Databases {
		streamCtx, streamCancel := context.WithCancel(context.Background())
		cancels = append(cancels, streamCancel)
		streamOpts := options.ChangeStream()
		if token := s.LastChanges[database.Source].ResumeToken; token != "" {
			log.Info().Str("resumeToken", token).Msg("Using resume token for database")
			streamOpts.SetResumeAfter(bson.M{"_data": token})
		}
		cs, err := s.Client.Database(database.Source.Database).Watch(streamCtx, mongo.Pipeline{}, streamOpts)
		if err != nil {
			shutdown()
			return err
		}
		startStream(streamCtx, cs, database.Source)
	}

	// Watch for specific collections
	for _, col := range s.Collections {
		streamCtx, streamCancel := context.WithCancel(context.Background())
		cancels = append(cancels, streamCancel)
		streamOpts := options.ChangeStream()
		if token := s.LastChanges[col.Source].ResumeToken; token != "" {
			log.Info().Str("resumeToken", token).Msg("Using resume token for collection")
			streamOpts.SetResumeAfter(bson.M{"_data": token})
		}
		coll := s.Client.Database(col.Source.Database).Collection(col.Source.Collection)
		cs, err := coll.Watch(streamCtx, mongo.Pipeline{}, streamOpts)
		if err != nil {
			shutdown()
			return err
		}
		startStream(streamCtx, cs, col.Source)
	}

	sleepCycles := int(s.RunPeriod / (10 * time.Second))
	// We take one 10 second cycle off to save stuff
	for i := 0; i < sleepCycles-1; i++ {
		time.Sleep(10 * time.Second)
		fmt.Fprintf(w, "Still processing, cycle=%d/%d ...\n", i+1, sleepCycles-1)
		if flush, ok := w.(http.Flusher); ok {
			flush.Flush()
		}
	}
	shutdown()

	// Save resume tokens
	if err := s.SaveState(); err != nil {
		log.Error().Err(err).Msg("Failed to save state to bucket!")
		return err
	}

	// We've done our bit here, finish the job and let the next run start another one
	if s.DlpClient != nil && s.DlpJobActive {
		if s.GcpDlpTriggerJobId != "" {
			log.Info().Str("jobID", s.GcpDlpTriggerJobId).Msg("Finishing the DLP job...")
			finishCtx, finishCancel := context.WithTimeout(context.Background(), 30*time.Second)
			err := s.DlpClient.FinishDlpJob(finishCtx, &dlppb.FinishDlpJobRequest{Name: s.GcpDlpTriggerJobId})
			finishCancel()
			if err != nil {
				log.Error().Err(err).Msg("Finishing the DLP job errored out")
				return err
			}
		}
		if err := s.DlpClient.Close(); err != nil {
			log.Warn().Err(err).Msg("Closing the DLP client failed")
		}
		// Reset so the next run on this instance creates a new client and activates a new job
		// instead of reusing a closed client and a finished job.
		s.DlpClient = nil
		s.DlpJobActive = false
		s.GcpDlpTriggerJobId = ""
	}
	return nil
}
// DLPFunctionHTTP is the HTTP entry point registered in init. It runs one scan cycle via
// mongoScanner.Process, streaming progress text to w. A failure is appended to the
// response body; the status code is whatever was already sent with the first write.
func DLPFunctionHTTP(w http.ResponseWriter, r *http.Request) {
	io.WriteString(w, "Processing now.\n")
	err := mongoScanner.Process(w)
	if err == nil {
		return
	}
	fmt.Fprintf(w, "Processing failed: %v", err)
}