common/persistence/sql/sql_shard_store.go (249 lines of code) (raw):

// Copyright (c) 2018 Uber Technologies, Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package sql import ( "context" "database/sql" "fmt" "time" "github.com/uber/cadence/common" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/persistence/serialization" "github.com/uber/cadence/common/persistence/sql/sqlplugin" "github.com/uber/cadence/common/types" ) type sqlShardStore struct { sqlStore currentClusterName string } // NewShardPersistence creates an instance of ShardStore func NewShardPersistence( db sqlplugin.DB, currentClusterName string, log log.Logger, parser serialization.Parser, ) (persistence.ShardStore, error) { return &sqlShardStore{ sqlStore: sqlStore{ db: db, logger: log, parser: parser, }, currentClusterName: currentClusterName, }, nil } func (m *sqlShardStore) CreateShard( ctx context.Context, request *persistence.InternalCreateShardRequest, ) error { if _, err := m.GetShard(ctx, &persistence.InternalGetShardRequest{ ShardID: request.ShardInfo.ShardID, }); err == nil { return &persistence.ShardAlreadyExistError{ Msg: fmt.Sprintf("CreateShard operation failed. Shard with ID %v already exists.", request.ShardInfo.ShardID), } } row, err := shardInfoToShardsRow(*request.ShardInfo, m.parser) if err != nil { return &types.InternalServiceError{ Message: fmt.Sprintf("CreateShard operation failed. Error: %v", err), } } if _, err := m.db.InsertIntoShards(ctx, row); err != nil { return convertCommonErrors(m.db, "CreateShard", "Failed to insert into shards table.", err) } return nil } func (m *sqlShardStore) GetShard( ctx context.Context, request *persistence.InternalGetShardRequest, ) (*persistence.InternalGetShardResponse, error) { row, err := m.db.SelectFromShards(ctx, &sqlplugin.ShardsFilter{ShardID: int64(request.ShardID)}) if err != nil { return nil, convertCommonErrors(m.db, "GetShard", fmt.Sprintf("Failed to get shard, ShardId: %v.", request.ShardID), err) } shardInfo, err := m.parser.ShardInfoFromBlob(row.Data, row.DataEncoding) if err != nil { return nil, err } if len(shardInfo.ClusterTransferAckLevel) == 0 { shardInfo.ClusterTransferAckLevel = map[string]int64{ m.currentClusterName: shardInfo.GetTransferAckLevel(), } } timerAckLevel := make(map[string]time.Time, len(shardInfo.ClusterTimerAckLevel)) for k, v := range shardInfo.ClusterTimerAckLevel { timerAckLevel[k] = v } if len(timerAckLevel) == 0 { timerAckLevel = map[string]time.Time{ m.currentClusterName: shardInfo.GetTimerAckLevel(), } } if shardInfo.ClusterReplicationLevel == nil { shardInfo.ClusterReplicationLevel = make(map[string]int64) } if shardInfo.ReplicationDlqAckLevel == nil { shardInfo.ReplicationDlqAckLevel = make(map[string]int64) } var transferPQS *persistence.DataBlob if shardInfo.GetTransferProcessingQueueStates() != nil { transferPQS = &persistence.DataBlob{ Encoding: common.EncodingType(shardInfo.GetTransferProcessingQueueStatesEncoding()), Data: shardInfo.GetTransferProcessingQueueStates(), } } var crossClusterPQS *persistence.DataBlob if shardInfo.GetCrossClusterProcessingQueueStates() != nil { crossClusterPQS = &persistence.DataBlob{ Encoding: common.EncodingType(shardInfo.GetCrossClusterProcessingQueueStatesEncoding()), Data: shardInfo.GetCrossClusterProcessingQueueStates(), } } var timerPQS *persistence.DataBlob if shardInfo.GetTimerProcessingQueueStates() != nil { timerPQS = &persistence.DataBlob{ Encoding: common.EncodingType(shardInfo.GetTimerProcessingQueueStatesEncoding()), Data: shardInfo.GetTimerProcessingQueueStates(), } } resp := &persistence.InternalGetShardResponse{ShardInfo: &persistence.InternalShardInfo{ ShardID: int(row.ShardID), RangeID: row.RangeID, Owner: shardInfo.GetOwner(), StolenSinceRenew: int(shardInfo.GetStolenSinceRenew()), UpdatedAt: shardInfo.GetUpdatedAt(), ReplicationAckLevel: shardInfo.GetReplicationAckLevel(), TransferAckLevel: shardInfo.GetTransferAckLevel(), TimerAckLevel: shardInfo.GetTimerAckLevel(), ClusterTransferAckLevel: shardInfo.ClusterTransferAckLevel, ClusterTimerAckLevel: timerAckLevel, TransferProcessingQueueStates: transferPQS, CrossClusterProcessingQueueStates: crossClusterPQS, TimerProcessingQueueStates: timerPQS, DomainNotificationVersion: shardInfo.GetDomainNotificationVersion(), ClusterReplicationLevel: shardInfo.ClusterReplicationLevel, ReplicationDLQAckLevel: shardInfo.ReplicationDlqAckLevel, }} return resp, nil } func (m *sqlShardStore) UpdateShard( ctx context.Context, request *persistence.InternalUpdateShardRequest, ) error { row, err := shardInfoToShardsRow(*request.ShardInfo, m.parser) if err != nil { return &types.InternalServiceError{ Message: fmt.Sprintf("UpdateShard operation failed. Error: %v", err), } } dbShardID := sqlplugin.GetDBShardIDFromHistoryShardID(request.ShardInfo.ShardID, m.db.GetTotalNumDBShards()) return m.txExecute(ctx, dbShardID, "UpdateShard", func(tx sqlplugin.Tx) error { if err := lockShard(ctx, tx, request.ShardInfo.ShardID, request.PreviousRangeID); err != nil { return err } result, err := tx.UpdateShards(ctx, row) if err != nil { return err } rowsAffected, err := result.RowsAffected() if err != nil { return fmt.Errorf("rowsAffected returned error for shardID %v: %v", request.ShardInfo.ShardID, err) } if rowsAffected != 1 { return fmt.Errorf("rowsAffected returned %v shards instead of one", rowsAffected) } return nil }) } // initiated by the owning shard func lockShard(ctx context.Context, tx sqlplugin.Tx, shardID int, oldRangeID int64) error { rangeID, err := tx.WriteLockShards(ctx, &sqlplugin.ShardsFilter{ShardID: int64(shardID)}) if err != nil { if err == sql.ErrNoRows { return &types.InternalServiceError{ Message: fmt.Sprintf("Failed to lock shard with ID %v that does not exist.", shardID), } } return convertCommonErrors(tx, "lockShard", fmt.Sprintf("Failed to lock shard with ID: %v.", shardID), err) } if int64(rangeID) != oldRangeID { return &persistence.ShardOwnershipLostError{ ShardID: shardID, Msg: fmt.Sprintf("Failed to update shard. Previous range ID: %v; new range ID: %v", oldRangeID, rangeID), } } return nil } // initiated by the owning shard func readLockShard(ctx context.Context, tx sqlplugin.Tx, shardID int, oldRangeID int64) error { rangeID, err := tx.ReadLockShards(ctx, &sqlplugin.ShardsFilter{ShardID: int64(shardID)}) if err != nil { if err == sql.ErrNoRows { return &types.InternalServiceError{ Message: fmt.Sprintf("Failed to lock shard with ID %v that does not exist.", shardID), } } return convertCommonErrors(tx, "readLockShard", fmt.Sprintf("Failed to lock shard with ID: %v.", shardID), err) } if int64(rangeID) != oldRangeID { return &persistence.ShardOwnershipLostError{ ShardID: shardID, Msg: fmt.Sprintf("Failed to lock shard. Previous range ID: %v; new range ID: %v", oldRangeID, rangeID), } } return nil } func shardInfoToShardsRow(s persistence.InternalShardInfo, parser serialization.Parser) (*sqlplugin.ShardsRow, error) { var markerData []byte markerEncoding := string(common.EncodingTypeEmpty) if s.PendingFailoverMarkers != nil { markerData = s.PendingFailoverMarkers.Data markerEncoding = string(s.PendingFailoverMarkers.Encoding) } var transferPQSData []byte transferPQSEncoding := string(common.EncodingTypeEmpty) if s.TransferProcessingQueueStates != nil { transferPQSData = s.TransferProcessingQueueStates.Data transferPQSEncoding = string(s.TransferProcessingQueueStates.Encoding) } var crossClusterPQS []byte crossClusterPQSEncoding := string(common.EncodingTypeEmpty) if s.CrossClusterProcessingQueueStates != nil { crossClusterPQS = s.CrossClusterProcessingQueueStates.Data crossClusterPQSEncoding = string(s.CrossClusterProcessingQueueStates.Encoding) } var timerPQSData []byte timerPQSEncoding := string(common.EncodingTypeEmpty) if s.TimerProcessingQueueStates != nil { timerPQSData = s.TimerProcessingQueueStates.Data timerPQSEncoding = string(s.TimerProcessingQueueStates.Encoding) } shardInfo := &serialization.ShardInfo{ StolenSinceRenew: int32(s.StolenSinceRenew), UpdatedAt: s.UpdatedAt, ReplicationAckLevel: s.ReplicationAckLevel, TransferAckLevel: s.TransferAckLevel, TimerAckLevel: s.TimerAckLevel, ClusterTransferAckLevel: s.ClusterTransferAckLevel, ClusterTimerAckLevel: s.ClusterTimerAckLevel, TransferProcessingQueueStates: transferPQSData, TransferProcessingQueueStatesEncoding: transferPQSEncoding, CrossClusterProcessingQueueStates: crossClusterPQS, CrossClusterProcessingQueueStatesEncoding: crossClusterPQSEncoding, TimerProcessingQueueStates: timerPQSData, TimerProcessingQueueStatesEncoding: timerPQSEncoding, DomainNotificationVersion: s.DomainNotificationVersion, Owner: s.Owner, ClusterReplicationLevel: s.ClusterReplicationLevel, ReplicationDlqAckLevel: s.ReplicationDLQAckLevel, PendingFailoverMarkers: markerData, PendingFailoverMarkersEncoding: markerEncoding, } blob, err := parser.ShardInfoToBlob(shardInfo) if err != nil { return nil, err } return &sqlplugin.ShardsRow{ ShardID: int64(s.ShardID), RangeID: s.RangeID, Data: blob.Data, DataEncoding: string(blob.Encoding), }, nil }