common/persistence/sql/sql_queue_store.go (232 lines of code) (raw):

// Copyright (c) 2019 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package sql import ( "context" "database/sql" "fmt" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/persistence/sql/sqlplugin" "github.com/uber/cadence/common/types" ) type ( sqlQueueStore struct { queueType persistence.QueueType logger log.Logger sqlStore } ) func newQueueStore( db sqlplugin.DB, logger log.Logger, queueType persistence.QueueType, ) (persistence.Queue, error) { return &sqlQueueStore{ sqlStore: sqlStore{ db: db, logger: logger, }, queueType: queueType, logger: logger, }, nil } func (q *sqlQueueStore) EnqueueMessage( ctx context.Context, messagePayload []byte, ) error { return q.txExecute(ctx, sqlplugin.DbDefaultShard, "EnqueueMessage", func(tx sqlplugin.Tx) error { lastMessageID, err := tx.GetLastEnqueuedMessageIDForUpdate(ctx, q.queueType) if err != nil { if err == sql.ErrNoRows { lastMessageID = -1 } else { return err } } ackLevels, err := tx.GetAckLevels(ctx, q.queueType, true) if err != nil { return err } _, err = tx.InsertIntoQueue(ctx, newQueueRow(q.queueType, getNextID(ackLevels, lastMessageID), messagePayload)) return err }) } func (q *sqlQueueStore) ReadMessages( ctx context.Context, lastMessageID int64, maxCount int, ) ([]*persistence.InternalQueueMessage, error) { rows, err := q.db.GetMessagesFromQueue(ctx, q.queueType, lastMessageID, maxCount) if err != nil { return nil, convertCommonErrors(q.db, "ReadMessages", "", err) } var messages []*persistence.InternalQueueMessage for _, row := range rows { messages = append(messages, &persistence.InternalQueueMessage{ID: row.MessageID, Payload: row.MessagePayload}) } return messages, nil } func newQueueRow( queueType persistence.QueueType, messageID int64, payload []byte, ) *sqlplugin.QueueRow { return &sqlplugin.QueueRow{QueueType: queueType, MessageID: messageID, MessagePayload: payload} } func (q *sqlQueueStore) DeleteMessagesBefore( ctx context.Context, messageID int64, ) error { _, err := q.db.DeleteMessagesBefore(ctx, 
q.queueType, messageID) if err != nil { return convertCommonErrors(q.db, "DeleteMessagesBefore", "", err) } return nil } func (q *sqlQueueStore) UpdateAckLevel( ctx context.Context, messageID int64, clusterName string, ) error { return q.txExecute(ctx, sqlplugin.DbDefaultShard, "UpdateAckLevel", func(tx sqlplugin.Tx) error { clusterAckLevels, err := tx.GetAckLevels(ctx, q.queueType, true) if err != nil { return err } if clusterAckLevels == nil { return tx.InsertAckLevel(ctx, q.queueType, messageID, clusterName) } // Ignore possibly delayed message if ackLevel, ok := clusterAckLevels[clusterName]; ok && ackLevel >= messageID { return nil } clusterAckLevels[clusterName] = messageID return tx.UpdateAckLevels(ctx, q.queueType, clusterAckLevels) }) } func (q *sqlQueueStore) GetAckLevels( ctx context.Context, ) (map[string]int64, error) { result, err := q.db.GetAckLevels(ctx, q.queueType, false) if err != nil { return nil, convertCommonErrors(q.db, "GetAckLevels", "", err) } return result, nil } func (q *sqlQueueStore) EnqueueMessageToDLQ( ctx context.Context, messagePayload []byte, ) error { return q.txExecute(ctx, sqlplugin.DbDefaultShard, "EnqueueMessageToDLQ", func(tx sqlplugin.Tx) error { var err error lastMessageID, err := tx.GetLastEnqueuedMessageIDForUpdate(ctx, q.getDLQTypeFromQueueType()) if err != nil { if err == sql.ErrNoRows { lastMessageID = -1 } else { return err } } _, err = tx.InsertIntoQueue(ctx, newQueueRow(q.getDLQTypeFromQueueType(), lastMessageID+1, messagePayload)) return err }) } func (q *sqlQueueStore) ReadMessagesFromDLQ( ctx context.Context, firstMessageID int64, lastMessageID int64, pageSize int, pageToken []byte, ) ([]*persistence.InternalQueueMessage, []byte, error) { if len(pageToken) != 0 { lastReadMessageID, err := deserializePageToken(pageToken) if err != nil { return nil, nil, &types.InternalServiceError{ Message: fmt.Sprintf("invalid next page token %v", pageToken)} } firstMessageID = lastReadMessageID } rows, err := 
q.db.GetMessagesBetween(ctx, q.getDLQTypeFromQueueType(), firstMessageID, lastMessageID, pageSize) if err != nil { return nil, nil, convertCommonErrors(q.db, "ReadMessagesFromDLQ", "", err) } var messages []*persistence.InternalQueueMessage for _, row := range rows { messages = append(messages, &persistence.InternalQueueMessage{ID: row.MessageID, Payload: row.MessagePayload}) } var newPagingToken []byte if messages != nil && len(messages) >= pageSize { lastReadMessageID := messages[len(messages)-1].ID newPagingToken = serializePageToken(int64(lastReadMessageID)) } return messages, newPagingToken, nil } func (q *sqlQueueStore) DeleteMessageFromDLQ( ctx context.Context, messageID int64, ) error { _, err := q.db.DeleteMessage(ctx, q.getDLQTypeFromQueueType(), messageID) if err != nil { return convertCommonErrors(q.db, "DeleteMessageFromDLQ", "", err) } return nil } func (q *sqlQueueStore) RangeDeleteMessagesFromDLQ( ctx context.Context, firstMessageID int64, lastMessageID int64, ) error { _, err := q.db.RangeDeleteMessages(ctx, q.getDLQTypeFromQueueType(), firstMessageID, lastMessageID) if err != nil { return convertCommonErrors(q.db, "RangeDeleteMessagesFromDLQ", "", err) } return nil } func (q *sqlQueueStore) UpdateDLQAckLevel( ctx context.Context, messageID int64, clusterName string, ) error { return q.txExecute(ctx, sqlplugin.DbDefaultShard, "UpdateDLQAckLevel", func(tx sqlplugin.Tx) error { clusterAckLevels, err := tx.GetAckLevels(ctx, q.getDLQTypeFromQueueType(), true) if err != nil { return err } if clusterAckLevels == nil { return tx.InsertAckLevel(ctx, q.getDLQTypeFromQueueType(), messageID, clusterName) } // Ignore possibly delayed message if ackLevel, ok := clusterAckLevels[clusterName]; ok && ackLevel >= messageID { return nil } clusterAckLevels[clusterName] = messageID return tx.UpdateAckLevels(ctx, q.getDLQTypeFromQueueType(), clusterAckLevels) }) } func (q *sqlQueueStore) GetDLQAckLevels( ctx context.Context, ) (map[string]int64, error) { result, err := 
q.db.GetAckLevels(ctx, q.getDLQTypeFromQueueType(), false) if err != nil { return nil, convertCommonErrors(q.db, "GetDLQAckLevels", "", err) } return result, nil } func (q *sqlQueueStore) GetDLQSize( ctx context.Context, ) (int64, error) { result, err := q.db.GetQueueSize(ctx, q.getDLQTypeFromQueueType()) if err != nil { return 0, convertCommonErrors(q.db, "GetDLQSize", "", err) } return result, nil } func (q *sqlQueueStore) getDLQTypeFromQueueType() persistence.QueueType { return -q.queueType } // if, for whatever reason, the ack-levels get ahead of the actual messages // then ensure the next ID follows func getNextID(acks map[string]int64, lastMessageID int64) int64 { o := lastMessageID for _, v := range acks { if v > o { o = v } } return o + 1 }