sources/dynamodb/streaming.go (380 lines of code) (raw):

// Copyright 2022 Google LLC // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package dynamodb import ( "context" "encoding/base64" "fmt" "os" "os/signal" "strconv" "strings" "sync" "syscall" "time" sp "cloud.google.com/go/spanner" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/dynamodb" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" "github.com/aws/aws-sdk-go/service/dynamodbstreams" "github.com/aws/aws-sdk-go/service/dynamodbstreams/dynamodbstreamsiface" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/proto" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/metrics" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" "github.com/GoogleCloudPlatform/spanner-migration-tool/sources/common" ) const ( ESC = 27 retryLimit = 100 ) // NewDynamoDBStream initializes a new DynamoDB Stream for a table with NEW_AND_OLD_IMAGES // StreamViewType. If there exists a stream for a given table then it must be of type // NEW_IMAGE or NEW_AND_OLD_IMAGES otherwise streaming changes for this table won't be captured. // It returns latest Stream Arn for the table along with any error if encountered. 
func NewDynamoDBStream(client dynamodbiface.DynamoDBAPI, srcTable string) (string, error) { describeTableInput := &dynamodb.DescribeTableInput{ TableName: aws.String(srcTable), } result, err := client.DescribeTable(describeTableInput) if err != nil { return "", fmt.Errorf("unexpected call to DescribeTable: %v", err) } if result.Table.StreamSpecification != nil { switch *result.Table.StreamSpecification.StreamViewType { case dynamodb.StreamViewTypeKeysOnly: return "", fmt.Errorf("error! there exists a stream with KEYS_ONLY StreamViewType") case dynamodb.StreamViewTypeOldImage: return "", fmt.Errorf("error! there exists a stream with OLD_IMAGE StreamViewType") default: return *result.Table.LatestStreamArn, nil } } else { streamSpecification := &dynamodb.StreamSpecification{ StreamEnabled: aws.Bool(true), StreamViewType: aws.String(dynamodb.StreamViewTypeNewAndOldImages), } updateTableInput := &dynamodb.UpdateTableInput{ StreamSpecification: streamSpecification, TableName: aws.String(srcTable), } res, err := client.UpdateTable(updateTableInput) if err != nil { return "", fmt.Errorf("unexpected call to UpdateTable: %v", err) } return *res.TableDescription.LatestStreamArn, nil } } // catchCtrlC catches the Ctrl+C signal if customer wants to exit. func catchCtrlC(wg *sync.WaitGroup, streamInfo *StreamingInfo) { defer wg.Done() c := make(chan os.Signal) signal.Notify(c, os.Interrupt, syscall.SIGTERM) go func() { <-c streamInfo.UserExit = true }() } // clear erases the last printed line on the output file. var clear = fmt.Sprintf("%c[%dA%c[2K", ESC, 1, ESC) // updateProgress updates the customer every minute with number of records processed // and if the current moment is an optimum condition for cutover or not. 
func updateProgress(optimumCondition, firstCall bool, totalRecordsProcessed int64) {
	// On every call after the first, erase the two previously printed lines so
	// the progress display updates in place instead of scrolling.
	if !firstCall {
		fmt.Print(strings.Repeat(clear, 2))
	}
	fmt.Printf("Optimum time for switching to Cloud Spanner: %s\n", strconv.FormatBool(optimumCondition))
	fmt.Printf("Count of records processed: %s\n", strconv.FormatInt(totalRecordsProcessed, 10))
}

// cutoverHelper analyzes the records processed and makes a decision if current moment is
// optimum for switching to Cloud Spanner or not.
func cutoverHelper(wg *sync.WaitGroup, streamInfo *StreamingInfo) {
	defer wg.Done()
	updateProgress(false, true, streamInfo.recordsProcessed)
	// timer counts elapsed minutes since this helper started.
	timer := int64(0)
	// firstFiveMin: records seen in the first five minutes (baseline rate).
	firstFiveMin := int64(0)
	// lastFiveMin: rolling sum of the last five per-minute deltas.
	lastFiveMin := int64(0)
	// tillLastMin: total records accounted for through the previous minute.
	tillLastMin := int64(0)
	// arr is a 5-slot circular buffer of per-minute record deltas.
	arr := [5]int64{0, 0, 0, 0, 0}
	for {
		time.Sleep(60 * time.Second)
		if streamInfo.UserExit {
			break
		}
		counter := timer % 5
		// Evict the delta from five minutes ago, store this minute's delta,
		// and update the rolling window sum.
		lastFiveMin -= arr[counter]
		arr[counter] = streamInfo.recordsProcessed - tillLastMin
		tillLastMin += arr[counter]
		lastFiveMin += arr[counter]
		if timer < 5 {
			firstFiveMin += arr[counter]
		}
		lastMin := arr[counter]
		// Cutover is "optimum" when throughput over the last five minutes has
		// dropped to at most 5% of the first five minutes' baseline, or when
		// no records at all arrived in the last minute.
		optimumCondition := ((lastFiveMin*100 <= 5*firstFiveMin) || (lastMin == 0))
		updateProgress(optimumCondition, false, tillLastMin)
		timer++
	}
}

// ProcessStream processes the latest enabled DynamoDB Stream for a table. It searches
// for shards within stream and for each shard it creates a separate working thread to
// process records within it.
func ProcessStream(wgStream *sync.WaitGroup, streamClient dynamodbstreamsiface.DynamoDBStreamsAPI, streamInfo *StreamingInfo, conv *internal.Conv, streamArn, srcTable string) {
	defer wgStream.Done()

	wgShard := &sync.WaitGroup{}
	// processingStarted tracks shards for which a worker goroutine has already
	// been launched, so rescans don't start duplicate workers.
	processingStarted := make(map[string]bool)
	// passAfterUserExit: once the user requests exit, perform exactly one more
	// scan pass so shards created just before the request are still picked up.
	passAfterUserExit := false
	for {
		shards, err := scanShards(streamClient, streamArn)
		if err != nil {
			streamInfo.Unexpected(fmt.Sprintf("Couldn't scan shards for table %s: %s", srcTable, err))
			break
		}
		// Register any newly discovered shards as not-yet-processed.
		for _, shard := range shards {
			shardId := *shard.ShardId
			if _, ok := processingStarted[shardId]; !ok {
				processingStarted[shardId] = false
				streamInfo.SetShardStatus(shardId, false)
			}
		}
		// Launch one worker goroutine per shard that hasn't been started yet.
		for _, shard := range shards {
			shardId := *shard.ShardId
			if !processingStarted[shardId] {
				processingStarted[shardId] = true
				wgShard.Add(1)
				go ProcessShard(wgShard, streamInfo, conv, streamClient, shard, streamArn, srcTable)
			}
		}
		if passAfterUserExit {
			break
		} else if streamInfo.UserExit {
			passAfterUserExit = true
		} else {
			// Poll for new shards every 20 seconds.
			time.Sleep(20 * time.Second)
		}
	}
	// Wait for all shard workers to finish before returning.
	wgShard.Wait()
}

// scanShards fetches all the shards from a given DynamoDB Stream.
func scanShards(streamClient dynamodbstreamsiface.DynamoDBStreamsAPI, streamArn string) ([]*dynamodbstreams.Shard, error) {
	describeStreamInput := &dynamodbstreams.DescribeStreamInput{
		ExclusiveStartShardId: nil,
		StreamArn:             &streamArn,
	}
	var scanResult []*dynamodbstreams.Shard
	// DescribeStream paginates: keep following LastEvaluatedShardId until the
	// service reports no further pages (nil).
	for {
		result, err := streamClient.DescribeStream(describeStreamInput)
		if err != nil {
			return nil, fmt.Errorf("unexpected call to DescribeStream: %v", err)
		}
		scanResult = append(scanResult, result.StreamDescription.Shards...)
		if result.StreamDescription.LastEvaluatedShardId == nil {
			break
		} else {
			describeStreamInput.ExclusiveStartShardId = result.StreamDescription.LastEvaluatedShardId
		}
	}
	return scanResult, nil
}

// checkTrimmedDataError checks if the error is a TrimmedDataAccessException.
// checkTrimmedDataError reports whether err is DynamoDB Streams'
// TrimmedDataAccessException (records aged past the 24h retention window),
// detected by substring match on the error text.
func checkTrimmedDataError(err error) bool {
	return strings.Contains(err.Error(), "TrimmedDataAccessException")
}

// ProcessShard processes records within a shard starting from the first unexpired record. It
// doesn't start processing unless parent shard is processed. For closed shards this process is
// completed after processing all records but for open shards it keeps searching for new records
// until shards gets closed or customer calls for a exit.
func ProcessShard(wgShard *sync.WaitGroup, streamInfo *StreamingInfo, conv *internal.Conv, streamClient dynamodbstreamsiface.DynamoDBStreamsAPI, shard *dynamodbstreams.Shard, streamArn, srcTable string) {
	defer wgShard.Done()

	// DynamoDB Streams order guarantee: a child shard's records must be applied
	// after its parent's, so block until the parent shard is done.
	waitForParentShard(streamInfo, shard.ParentShardId)

	shardId := *shard.ShardId

	// nil means "no record processed yet": start from TRIM_HORIZON; otherwise
	// resume AFTER this sequence number (see getShardIterator).
	var lastEvaluatedSequenceNumber *string = nil
	passAfterUserExit := false
	retryCount := 0
	for {
		shardIterator, err := getShardIterator(streamClient, lastEvaluatedSequenceNumber, shardId, streamArn)
		if err != nil {
			if checkTrimmedDataError(err) {
				// The resume point has expired; restart from the oldest
				// still-available record.
				lastEvaluatedSequenceNumber = nil
				continue
			} else {
				streamInfo.Unexpected(fmt.Sprintf("Couldn't get shardIterator for table %s: %s", srcTable, err))
				break
			}
		}
		getRecordsOutput, err := getRecords(streamClient, shardIterator)
		if err != nil {
			// In case of closed shards, after all data records get expired it still returns a non-nil
			// shardIterator for GetShardIterator query. Using this shardIterator for GetRecords
			// API call results in TrimmedDataAccessException. This will result in same steps being
			// followed again and again. To handle this a retry limit of 5 is set.
			if checkTrimmedDataError(err) && retryCount < 5 {
				lastEvaluatedSequenceNumber = nil
				retryCount++
				continue
			} else {
				streamInfo.Unexpected(fmt.Sprintf("Couldn't fetch records for table %s: %s", srcTable, err))
				break
			}
		} else {
			// Any successful fetch resets the trimmed-data retry budget.
			retryCount = 0
		}
		records := getRecordsOutput.Records
		for _, record := range records {
			ProcessRecord(conv, streamInfo, record, srcTable)
			// Remember the last applied sequence number as the resume point.
			lastEvaluatedSequenceNumber = record.Dynamodb.SequenceNumber
		}
		// A nil NextShardIterator means the shard is closed and fully read;
		// passAfterUserExit means we already did the final post-exit pass.
		if getRecordsOutput.NextShardIterator == nil || passAfterUserExit {
			break
		}
		if streamInfo.UserExit {
			// Do one more iteration to drain records that arrived just before
			// the exit request, then stop.
			passAfterUserExit = true
		} else if len(records) == 0 {
			// Open shard with no new data: back off before polling again.
			time.Sleep(5 * time.Second)
		}
	}
	streamInfo.SetShardStatus(shardId, true)
}

// waitForParentShard checks every 6 seconds if parentShard is processed or
// not and waits as long as parent shard is not processed.
func waitForParentShard(streamInfo *StreamingInfo, parentShard *string) {
	// A nil parent means this is a root shard: nothing to wait for.
	if parentShard != nil {
		for {
			streamInfo.lock.Lock()
			done, ok := streamInfo.ShardProcessed[*parentShard]
			streamInfo.lock.Unlock()
			// An unknown parent (not in the map) is treated as already
			// processed, e.g. it expired before this run discovered it.
			if !ok || done {
				return
			} else {
				time.Sleep(6 * time.Second)
			}
		}
	}
}

// getShardIterator returns an iterator to find records based on the lastEvaluatedSequence number.
// If lastEvaluatedSequenceNumber is nil then it uses TrimHorizon as shardIterator type to point to first
// non-expired record otherwise it finds the first unprocessed record after lastEvaluatedSequence number using
// AfterSequenceNumber shardIterator type.
func getShardIterator(streamClient dynamodbstreamsiface.DynamoDBStreamsAPI, lastEvaluatedSequenceNumber *string, shardId, streamArn string) (*string, error) { var getShardIteratorInput *dynamodbstreams.GetShardIteratorInput if lastEvaluatedSequenceNumber == nil { getShardIteratorInput = &dynamodbstreams.GetShardIteratorInput{ ShardId: &shardId, ShardIteratorType: aws.String(dynamodbstreams.ShardIteratorTypeTrimHorizon), StreamArn: &streamArn, } } else { getShardIteratorInput = &dynamodbstreams.GetShardIteratorInput{ SequenceNumber: lastEvaluatedSequenceNumber, ShardId: &shardId, ShardIteratorType: aws.String(dynamodbstreams.ShardIteratorTypeAfterSequenceNumber), StreamArn: &streamArn, } } result, err := streamClient.GetShardIterator(getShardIteratorInput) if err != nil { err = fmt.Errorf("unexpected call to GetShardIterator: %v", err) return nil, err } return result.ShardIterator, nil } // getRecords fetches the records from DynamoDB Streams by using the shardIterator. func getRecords(streamClient dynamodbstreamsiface.DynamoDBStreamsAPI, shardIterator *string) (*dynamodbstreams.GetRecordsOutput, error) { getRecordsInput := &dynamodbstreams.GetRecordsInput{ ShardIterator: shardIterator, } result, err := streamClient.GetRecords(getRecordsInput) if err != nil { err = fmt.Errorf("unexpected call to GetRecords: %v", err) return result, err } return result, nil } // ProcessRecord processes records retrieved from shards. It first converts the data // to Spanner data (based on the source and Spanner schemas), and then writes that data // to Cloud Spanner. 
func ProcessRecord(conv *internal.Conv, streamInfo *StreamingInfo, record *dynamodbstreams.Record, srcTable string) {
	// EventName is INSERT, MODIFY or REMOVE.
	eventName := *record.EventName
	streamInfo.StatsAddRecord(srcTable, eventName)

	// todo - write a function that will compute schemas and colums and return
	tableId, err := internal.GetTableIdFromSrcName(conv.SrcSchema, srcTable)
	srcSchema, ok1 := conv.SrcSchema[tableId]
	spSchema, ok2 := conv.SpSchema[tableId]
	if err != nil || !ok1 || !ok2 {
		streamInfo.Unexpected(fmt.Sprintf("Can't get tableId and schemas for table %s: %v", srcTable, err))
		return
	}
	spTable := spSchema.Name
	spCols := []string{}
	srcCols := []string{}
	srcColIds := srcSchema.ColIds
	spColIds := spSchema.ColIds
	// Only columns present in both the source and Spanner schemas are migrated.
	commonIds := common.IntersectionOfTwoStringSlices(spColIds, srcColIds)
	for _, colId := range commonIds {
		spCols = append(spCols, spSchema.ColDefs[colId].Name)
		srcCols = append(srcCols, srcSchema.ColDefs[colId].Name)
	}
	var srcImage map[string]*dynamodb.AttributeValue
	// REMOVE records carry only the key attributes; INSERT/MODIFY carry the
	// full new item image (stream is NEW_IMAGE or NEW_AND_OLD_IMAGES).
	if eventName == "REMOVE" {
		srcImage = record.Dynamodb.Keys
	} else {
		srcImage = record.Dynamodb.NewImage
	}
	spVals, badCols, srcStrVals := cvtRow(srcImage, srcSchema, spSchema, commonIds)
	// Any column that failed conversion makes the whole record "bad": it is
	// collected for the bad-data report instead of being written.
	if len(badCols) == 0 {
		writeRecord(streamInfo, srcTable, spTable, eventName, spCols, spVals, srcSchema)
	} else {
		streamInfo.StatsAddBadRecord(srcTable, eventName)
		streamInfo.CollectBadRecord(eventName, srcTable, srcCols, srcStrVals)
	}
	streamInfo.StatsAddRecordProcessed()
}

// writeRecord handles creation and processing of mutation from the converted data to Cloud Spanner.
// If the writer which writes mutations to Cloud Spanner is not configured then it treats the record
// as a bad record.
func writeRecord(streamInfo *StreamingInfo, srcTable, spTable, eventName string, spCols []string, spVals []interface{}, srcSchema schema.Table) { if streamInfo.write == nil { msg := "Internal error: writeRecord called but writer not configured" streamInfo.StatsAddBadRecord(srcTable, eventName) streamInfo.Unexpected(msg) } else { m := getMutation(eventName, srcTable, spTable, spCols, spVals, srcSchema) err := writeMutation(m, streamInfo) if err != nil { streamInfo.StatsAddDroppedRecord(srcTable, eventName) streamInfo.CollectDroppedRecord(eventName, spTable, spCols, spVals, err) } } } // getMutation creates a mutation for writing to Cloud Spanner from the converted data. func getMutation(eventName, srcTable, spTable string, spCols []string, spVals []interface{}, srcSchema schema.Table) (m *sp.Mutation) { if eventName == "INSERT" { m = sp.Insert(spTable, spCols, spVals) } else if eventName == "MODIFY" { m = sp.InsertOrUpdate(spTable, spCols, spVals) } else { m = removeMutation(srcSchema, spTable, srcTable, spVals) } return m } // removeMutation create a mutation from converted data for records of type 'REMOVE'. // It ensures that when keyset is created the order for primary keys passed is same // as the original database i.e. HASH Key, Partition Key. func removeMutation(srcSchema schema.Table, spTable, srcTable string, spVals []interface{}) (m *sp.Mutation) { var srcKeys []string var reqSpVals []interface{} for i := 0; i < len(spVals); i++ { if spVals[i] == nil { continue } srcKeys = append(srcKeys, srcSchema.ColIds[i]) reqSpVals = append(reqSpVals, spVals[i]) } primaryKeys := srcSchema.PrimaryKeys if primaryKeys[0].ColId != srcKeys[0] { reqSpVals[0], reqSpVals[1] = reqSpVals[1], reqSpVals[0] } if len(reqSpVals) == 1 { m = sp.Delete(spTable, sp.Key{reqSpVals[0]}) } else { m = sp.Delete(spTable, sp.Key{reqSpVals[0], reqSpVals[1]}) } return m } // parentDataMissingError is used to track errors where insertions fail because of missing parent data. 
// // Note: If error code and description for parent row missing error is changed in future, then this // function is subject to change. func parentDataMissingError(err error) bool { return strings.Contains(err.Error(), "NotFound") && strings.Contains(err.Error(), "Parent row") && strings.Contains(err.Error(), "is missing") } // writeMutation handles writing of a mutation to Cloud Spanner. To handle insertions failing // because of missing parent data, a retryLimit is set. func writeMutation(m *sp.Mutation, streamInfo *StreamingInfo) error { var err error tryNum := 0 for tryNum < retryLimit { err = streamInfo.write(m) if err == nil || !parentDataMissingError(err) { break } time.Sleep(4 * time.Second) tryNum++ } return err } // setWriter initializes the write function used to write mutations to Cloud Spanner. func setWriter(streamInfo *StreamingInfo, client *sp.Client, conv *internal.Conv) { streamInfo.write = func(m *sp.Mutation) error { migrationData := metrics.GetMigrationData(conv, "", constants.DataConv) serializedMigrationData, _ := proto.Marshal(migrationData) migrationMetadataValue := base64.StdEncoding.EncodeToString(serializedMigrationData) _, err := client.Apply(metadata.AppendToOutgoingContext(context.Background(), constants.MigrationMetadataKey, migrationMetadataValue), []*sp.Mutation{m}) return err } } // fillConvWithStreamingStats passes the information related to processing of DynamoDB Streams // to conv object for report and bad data file. 
func fillConvWithStreamingStats(streamInfo *StreamingInfo, conv *internal.Conv) {
	// Pass Unexpected Conditions
	for unexpectedCondition, count := range streamInfo.Unexpecteds {
		conv.Unexpected(unexpectedCondition)
		// NOTE(review): adding (count - 1) assumes conv.Unexpected above
		// incremented Stats.Unexpected by exactly 1, so the total ends up
		// equal to count — confirm against conv.Unexpected's implementation.
		if _, ok := conv.Stats.Unexpected[unexpectedCondition]; ok {
			conv.Stats.Unexpected[unexpectedCondition] += (count - 1)
		}
	}

	// Mark this conversion as having run in streaming mode.
	conv.Audit.StreamingStats.Streaming = true

	// Pass count stats to conv
	conv.Audit.StreamingStats.TotalRecords = streamInfo.Records
	conv.Audit.StreamingStats.BadRecords = streamInfo.BadRecords
	conv.Audit.StreamingStats.DroppedRecords = streamInfo.DroppedRecords

	// Pass badRecords and droppedRecords
	conv.Audit.StreamingStats.SampleBadRecords = streamInfo.SampleBadRecords
	conv.Audit.StreamingStats.SampleBadWrites = streamInfo.SampleBadWrites
}