common/persistence/nosql/nosqlplugin/cassandra/shard.go (250 lines of code) (raw):

// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package cassandra

import (
	"context"
	"fmt"
	"sort"
	"strings"
	"time"

	"github.com/uber/cadence/common"
	"github.com/uber/cadence/common/persistence"
	"github.com/uber/cadence/common/persistence/nosql/nosqlplugin"
)

// InsertShard creates a new shard, return error if there is any.
// Return ShardOperationConditionFailure if the condition doesn't meet func (db *cdb) InsertShard(ctx context.Context, row *nosqlplugin.ShardRow) error { cqlNowTimestamp := persistence.UnixNanoToDBTimestamp(db.timeSrc.Now().UnixNano()) markerData, markerEncoding := persistence.FromDataBlob(row.PendingFailoverMarkers) transferPQS, transferPQSEncoding := persistence.FromDataBlob(row.TransferProcessingQueueStates) crossClusterPQS, crossClusterPQSEncoding := persistence.FromDataBlob(row.CrossClusterProcessingQueueStates) timerPQS, timerPQSEncoding := persistence.FromDataBlob(row.TimerProcessingQueueStates) query := db.session.Query(templateCreateShardQuery, row.ShardID, rowTypeShard, rowTypeShardDomainID, rowTypeShardWorkflowID, rowTypeShardRunID, defaultVisibilityTimestamp, rowTypeShardTaskID, row.ShardID, row.Owner, row.RangeID, row.StolenSinceRenew, cqlNowTimestamp, row.ReplicationAckLevel, row.TransferAckLevel, row.TimerAckLevel, row.ClusterTransferAckLevel, row.ClusterTimerAckLevel, transferPQS, transferPQSEncoding, crossClusterPQS, crossClusterPQSEncoding, timerPQS, timerPQSEncoding, row.DomainNotificationVersion, row.ClusterReplicationLevel, row.ReplicationDLQAckLevel, markerData, markerEncoding, row.RangeID, ).WithContext(ctx) previous := make(map[string]interface{}) applied, err := query.MapScanCAS(previous) if err != nil { return err } if !applied { return convertToConflictedShardRow(previous) } return nil } func convertToConflictedShardRow(previous map[string]interface{}) error { rangeID := previous["range_id"].(int64) var columns []string for k, v := range previous { columns = append(columns, fmt.Sprintf("%s=%v", k, v)) } return &nosqlplugin.ShardOperationConditionFailure{ RangeID: rangeID, Details: strings.Join(columns, ","), } } // SelectShard gets a shard func (db *cdb) SelectShard(ctx context.Context, shardID int, currentClusterName string) (int64, *nosqlplugin.ShardRow, error) { query := db.session.Query(templateGetShardQuery, shardID, rowTypeShard, 
rowTypeShardDomainID, rowTypeShardWorkflowID, rowTypeShardRunID, defaultVisibilityTimestamp, rowTypeShardTaskID, ).WithContext(ctx) result := make(map[string]interface{}) if err := query.MapScan(result); err != nil { return 0, nil, err } rangeID := result["range_id"].(int64) shard := result["shard"].(map[string]interface{}) shardInfoRangeID := shard["range_id"].(int64) return rangeID, convertToShardInfo(currentClusterName, shardInfoRangeID, shard), nil } func convertToShardInfo( currentCluster string, rangeID int64, shard map[string]interface{}, ) *nosqlplugin.ShardRow { var pendingFailoverMarkersRawData []byte var pendingFailoverMarkersEncoding string var transferProcessingQueueStatesRawData []byte var transferProcessingQueueStatesEncoding string var crossClusterProcessingQueueStatesRawData []byte var crossClusterProcessingQueueStatesEncoding string var timerProcessingQueueStatesRawData []byte var timerProcessingQueueStatesEncoding string info := &persistence.InternalShardInfo{} info.RangeID = rangeID for k, v := range shard { switch k { case "shard_id": info.ShardID = v.(int) case "owner": info.Owner = v.(string) case "stolen_since_renew": info.StolenSinceRenew = v.(int) case "updated_at": info.UpdatedAt = v.(time.Time) case "replication_ack_level": info.ReplicationAckLevel = v.(int64) case "transfer_ack_level": info.TransferAckLevel = v.(int64) case "timer_ack_level": info.TimerAckLevel = v.(time.Time) case "cluster_transfer_ack_level": info.ClusterTransferAckLevel = v.(map[string]int64) case "cluster_timer_ack_level": info.ClusterTimerAckLevel = v.(map[string]time.Time) case "transfer_processing_queue_states": transferProcessingQueueStatesRawData = v.([]byte) case "transfer_processing_queue_states_encoding": transferProcessingQueueStatesEncoding = v.(string) case "cross_cluster_processing_queue_states": crossClusterProcessingQueueStatesRawData = v.([]byte) case "cross_cluster_processing_queue_states_encoding": crossClusterProcessingQueueStatesEncoding = 
v.(string) case "timer_processing_queue_states": timerProcessingQueueStatesRawData = v.([]byte) case "timer_processing_queue_states_encoding": timerProcessingQueueStatesEncoding = v.(string) case "domain_notification_version": info.DomainNotificationVersion = v.(int64) case "cluster_replication_level": info.ClusterReplicationLevel = v.(map[string]int64) case "replication_dlq_ack_level": info.ReplicationDLQAckLevel = v.(map[string]int64) case "pending_failover_markers": pendingFailoverMarkersRawData = v.([]byte) case "pending_failover_markers_encoding": pendingFailoverMarkersEncoding = v.(string) } } if info.ClusterTransferAckLevel == nil { info.ClusterTransferAckLevel = map[string]int64{ currentCluster: info.TransferAckLevel, } } if info.ClusterTimerAckLevel == nil { info.ClusterTimerAckLevel = map[string]time.Time{ currentCluster: info.TimerAckLevel, } } if info.ClusterReplicationLevel == nil { info.ClusterReplicationLevel = make(map[string]int64) } if info.ReplicationDLQAckLevel == nil { info.ReplicationDLQAckLevel = make(map[string]int64) } info.PendingFailoverMarkers = persistence.NewDataBlob( pendingFailoverMarkersRawData, common.EncodingType(pendingFailoverMarkersEncoding), ) info.TransferProcessingQueueStates = persistence.NewDataBlob( transferProcessingQueueStatesRawData, common.EncodingType(transferProcessingQueueStatesEncoding), ) info.CrossClusterProcessingQueueStates = persistence.NewDataBlob( crossClusterProcessingQueueStatesRawData, common.EncodingType(crossClusterProcessingQueueStatesEncoding), ) info.TimerProcessingQueueStates = persistence.NewDataBlob( timerProcessingQueueStatesRawData, common.EncodingType(timerProcessingQueueStatesEncoding), ) return info } // UpdateRangeID updates the rangeID, return error is there is any // Return ShardOperationConditionFailure if the condition doesn't meet func (db *cdb) UpdateRangeID(ctx context.Context, shardID int, rangeID int64, previousRangeID int64) error { query := 
db.session.Query(templateUpdateRangeIDQuery, rangeID, shardID, rowTypeShard, rowTypeShardDomainID, rowTypeShardWorkflowID, rowTypeShardRunID, defaultVisibilityTimestamp, rowTypeShardTaskID, previousRangeID, ).WithContext(ctx) previous := make(map[string]interface{}) applied, err := query.MapScanCAS(previous) if err != nil { return err } if !applied { return convertToConflictedShardRow(previous) } return nil } // UpdateShard updates a shard, return error is there is any. // Return ShardOperationConditionFailure if the condition doesn't meet func (db *cdb) UpdateShard(ctx context.Context, row *nosqlplugin.ShardRow, previousRangeID int64) error { cqlNowTimestamp := persistence.UnixNanoToDBTimestamp(db.timeSrc.Now().UnixNano()) markerData, markerEncoding := persistence.FromDataBlob(row.PendingFailoverMarkers) transferPQS, transferPQSEncoding := persistence.FromDataBlob(row.TransferProcessingQueueStates) crossClusterPQS, crossClusterPQSEncoding := persistence.FromDataBlob(row.CrossClusterProcessingQueueStates) timerPQS, timerPQSEncoding := persistence.FromDataBlob(row.TimerProcessingQueueStates) query := db.session.Query(templateUpdateShardQuery, row.ShardID, row.Owner, row.RangeID, row.StolenSinceRenew, cqlNowTimestamp, row.ReplicationAckLevel, row.TransferAckLevel, row.TimerAckLevel, row.ClusterTransferAckLevel, row.ClusterTimerAckLevel, transferPQS, transferPQSEncoding, crossClusterPQS, crossClusterPQSEncoding, timerPQS, timerPQSEncoding, row.DomainNotificationVersion, row.ClusterReplicationLevel, row.ReplicationDLQAckLevel, markerData, markerEncoding, row.RangeID, row.ShardID, rowTypeShard, rowTypeShardDomainID, rowTypeShardWorkflowID, rowTypeShardRunID, defaultVisibilityTimestamp, rowTypeShardTaskID, previousRangeID, ).WithContext(ctx) previous := make(map[string]interface{}) applied, err := query.MapScanCAS(previous) if err != nil { return err } if !applied { return convertToConflictedShardRow(previous) } return nil }