common/persistence/sql/sql_task_store.go (526 lines of code) (raw):

// Copyright (c) 2020 Uber Technologies, Inc. // Portions of the Software are attributed to Copyright (c) 2020 Temporal Technologies Inc. // // Permission is hereby granted, free of charge, to any person obtaining a copy // of this software and associated documentation files (the "Software"), to deal // in the Software without restriction, including without limitation the rights // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell // copies of the Software, and to permit persons to whom the Software is // furnished to do so, subject to the following conditions: // // The above copyright notice and this permission notice shall be included in // all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. package sql import ( "context" "database/sql" "fmt" "math" "time" "github.com/uber/cadence/common" "github.com/uber/cadence/common/log" "github.com/uber/cadence/common/persistence" "github.com/uber/cadence/common/persistence/serialization" "github.com/uber/cadence/common/persistence/sql/sqlplugin" "github.com/uber/cadence/common/types" ) type sqlTaskStore struct { sqlStore nShards int } var ( stickyTasksListsTTL = time.Hour * 24 ) // newTaskPersistence creates a new instance of TaskManager func newTaskPersistence( db sqlplugin.DB, nShards int, log log.Logger, parser serialization.Parser, ) (persistence.TaskStore, error) { return &sqlTaskStore{ sqlStore: sqlStore{ db: db, logger: log, parser: parser, }, nShards: nShards, }, nil } func (m *sqlTaskStore) GetTaskListSize(ctx context.Context, request *persistence.GetTaskListSizeRequest) (*persistence.GetTaskListSizeResponse, error) { dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskListName, m.db.GetTotalNumDBShards()) domainID := serialization.MustParseUUID(request.DomainID) size, err := m.db.GetTasksCount(ctx, &sqlplugin.TasksFilter{ ShardID: dbShardID, DomainID: domainID, TaskListName: request.TaskListName, TaskType: int64(request.TaskListType), MinTaskID: &request.AckLevel, }) if err != nil { return nil, convertCommonErrors(m.db, "GetTaskListSize", "", err) } return &persistence.GetTaskListSizeResponse{Size: size}, nil } func (m *sqlTaskStore) LeaseTaskList( ctx context.Context, request *persistence.LeaseTaskListRequest, ) (*persistence.LeaseTaskListResponse, error) { var rangeID int64 var ackLevel int64 dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskList, m.db.GetTotalNumDBShards()) domainID := serialization.MustParseUUID(request.DomainID) rows, err := m.db.SelectFromTaskLists(ctx, &sqlplugin.TaskListsFilter{ ShardID: dbShardID, DomainID: &domainID, Name: &request.TaskList, TaskType: common.Int64Ptr(int64(request.TaskType))}) if err != nil { if err == sql.ErrNoRows { tlInfo := &serialization.TaskListInfo{ AckLevel: ackLevel, Kind: int16(request.TaskListKind), ExpiryTimestamp: time.Unix(0, 0), LastUpdated: time.Now(), } blob, err := m.parser.TaskListInfoToBlob(tlInfo) if err != nil { return nil, err } row := sqlplugin.TaskListsRow{ ShardID: dbShardID, DomainID: domainID, Name: request.TaskList, TaskType: int64(request.TaskType), Data: blob.Data, DataEncoding: string(blob.Encoding), } rows = []sqlplugin.TaskListsRow{row} if m.db.SupportsTTL() && request.TaskListKind == persistence.TaskListKindSticky { rowWithTTL := sqlplugin.TaskListsRowWithTTL{ TaskListsRow: row, TTL: stickyTasksListsTTL, } if _, err := m.db.InsertIntoTaskListsWithTTL(ctx, &rowWithTTL); err != nil { return nil, convertCommonErrors(m.db, "LeaseTaskListWithTTL", fmt.Sprintf("Failed to make task list %v of type %v.", request.TaskList, request.TaskType), err) } } else { if _, err := m.db.InsertIntoTaskLists(ctx, &row); err != nil { return nil, convertCommonErrors(m.db, "LeaseTaskList", fmt.Sprintf("Failed to make task list %v of type %v.", request.TaskList, request.TaskType), err) } } } else { return nil, convertCommonErrors(m.db, "LeaseTaskList", "Failed to check if task list existed.", err) } } row := rows[0] if request.RangeID > 0 && request.RangeID != row.RangeID { return nil, &persistence.ConditionFailedError{ Msg: fmt.Sprintf("leaseTaskList:renew failed:taskList:%v, taskListType:%v, haveRangeID:%v, gotRangeID:%v", request.TaskList, request.TaskType, rangeID, row.RangeID), } } tlInfo, err := m.parser.TaskListInfoFromBlob(row.Data, row.DataEncoding) if err != nil { return nil, err } var resp *persistence.LeaseTaskListResponse err = m.txExecute(ctx, dbShardID, "LeaseTaskList", func(tx sqlplugin.Tx) error { rangeID = row.RangeID ackLevel = tlInfo.GetAckLevel() // We need to separately check the condition and do the // update because we want to throw different error codes. // Since we need to do things separately (in a transaction), we need to take a lock. err1 := lockTaskList(ctx, tx, dbShardID, domainID, request.TaskList, request.TaskType, rangeID) if err1 != nil { return err1 } now := time.Now() tlInfo.LastUpdated = now blob, err1 := m.parser.TaskListInfoToBlob(tlInfo) if err1 != nil { return err1 } row := &sqlplugin.TaskListsRow{ ShardID: dbShardID, DomainID: row.DomainID, RangeID: row.RangeID + 1, Name: row.Name, TaskType: row.TaskType, Data: blob.Data, DataEncoding: string(blob.Encoding), } var result sql.Result if tlInfo.GetKind() == persistence.TaskListKindSticky && m.db.SupportsTTL() { result, err1 = tx.UpdateTaskListsWithTTL(ctx, &sqlplugin.TaskListsRowWithTTL{ TaskListsRow: *row, TTL: stickyTasksListsTTL, }) } else { result, err1 = tx.UpdateTaskLists(ctx, row) } if err1 != nil { return err1 } rowsAffected, err1 := result.RowsAffected() if err1 != nil { return fmt.Errorf("rowsAffected error: %v", err1) } if rowsAffected == 0 { return fmt.Errorf("%v rows affected instead of 1", rowsAffected) } resp = &persistence.LeaseTaskListResponse{TaskListInfo: &persistence.TaskListInfo{ DomainID: request.DomainID, Name: request.TaskList, TaskType: request.TaskType, RangeID: rangeID + 1, AckLevel: ackLevel, Kind: request.TaskListKind, LastUpdated: now, }} return nil }) return resp, err } func (m *sqlTaskStore) UpdateTaskList( ctx context.Context, request *persistence.UpdateTaskListRequest, ) (*persistence.UpdateTaskListResponse, error) { dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.TaskListInfo.DomainID, request.TaskListInfo.Name, m.db.GetTotalNumDBShards()) domainID := serialization.MustParseUUID(request.TaskListInfo.DomainID) tlInfo := &serialization.TaskListInfo{ AckLevel: request.TaskListInfo.AckLevel, Kind: int16(request.TaskListInfo.Kind), ExpiryTimestamp: time.Unix(0, 0), LastUpdated: time.Now(), } if request.TaskListInfo.Kind == persistence.TaskListKindSticky { tlInfo.ExpiryTimestamp = stickyTaskListExpiry() } var resp *persistence.UpdateTaskListResponse blob, err := m.parser.TaskListInfoToBlob(tlInfo) if err != nil { return nil, err } err = m.txExecute(ctx, dbShardID, "UpdateTaskList", func(tx sqlplugin.Tx) error { err1 := lockTaskList( ctx, tx, dbShardID, domainID, request.TaskListInfo.Name, request.TaskListInfo.TaskType, request.TaskListInfo.RangeID) if err1 != nil { return err1 } var result sql.Result row := &sqlplugin.TaskListsRow{ ShardID: dbShardID, DomainID: domainID, RangeID: request.TaskListInfo.RangeID, Name: request.TaskListInfo.Name, TaskType: int64(request.TaskListInfo.TaskType), Data: blob.Data, DataEncoding: string(blob.Encoding), } if m.db.SupportsTTL() && request.TaskListInfo.Kind == persistence.TaskListKindSticky { result, err1 = tx.UpdateTaskListsWithTTL(ctx, &sqlplugin.TaskListsRowWithTTL{ TaskListsRow: *row, TTL: stickyTasksListsTTL, }) } else { result, err1 = tx.UpdateTaskLists(ctx, row) } if err1 != nil { return err1 } rowsAffected, err1 := result.RowsAffected() if err1 != nil { return err1 } if rowsAffected != 1 { return fmt.Errorf("%v rows were affected instead of 1", rowsAffected) } resp = &persistence.UpdateTaskListResponse{} return nil }) return resp, err } type taskListPageToken struct { ShardID int DomainID serialization.UUID Name string TaskType int64 } // ListTaskList lists tasklist from DB // DomainID translates into byte array in SQL. The minUUID is not the minimum byte array. func (m *sqlTaskStore) ListTaskList( ctx context.Context, request *persistence.ListTaskListRequest, ) (*persistence.ListTaskListResponse, error) { pageToken := taskListPageToken{DomainID: serialization.UUID{}} if len(request.PageToken) > 0 { if err := gobDeserialize(request.PageToken, &pageToken); err != nil { return nil, &types.InternalServiceError{Message: fmt.Sprintf("error deserializing page token: %v", err)} } } else { pageToken = taskListPageToken{TaskType: math.MinInt16, DomainID: serialization.UUID{}} } var err error var rows []sqlplugin.TaskListsRow for pageToken.ShardID < m.nShards { rows, err = m.db.SelectFromTaskLists(ctx, &sqlplugin.TaskListsFilter{ ShardID: pageToken.ShardID, DomainIDGreaterThan: &pageToken.DomainID, NameGreaterThan: &pageToken.Name, TaskTypeGreaterThan: &pageToken.TaskType, PageSize: &request.PageSize, }) if err != nil { return nil, convertCommonErrors(m.db, "ListTaskList", "", err) } if len(rows) > 0 { break } pageToken = taskListPageToken{ShardID: pageToken.ShardID + 1, TaskType: math.MinInt16, DomainID: serialization.UUID{}} } var nextPageToken []byte switch { case len(rows) >= request.PageSize: lastRow := &rows[request.PageSize-1] nextPageToken, err = gobSerialize(&taskListPageToken{ ShardID: pageToken.ShardID, DomainID: lastRow.DomainID, Name: lastRow.Name, TaskType: lastRow.TaskType, }) case pageToken.ShardID+1 < m.nShards: nextPageToken, err = gobSerialize(&taskListPageToken{ShardID: pageToken.ShardID + 1, TaskType: math.MinInt16, DomainID: serialization.UUID{}}) } if err != nil { return nil, &types.InternalServiceError{Message: fmt.Sprintf("error serializing nextPageToken:%v", err)} } resp := &persistence.ListTaskListResponse{ Items: make([]persistence.TaskListInfo, len(rows)), NextPageToken: nextPageToken, } for i := range rows { info, err := m.parser.TaskListInfoFromBlob(rows[i].Data, rows[i].DataEncoding) if err != nil { return nil, err } resp.Items[i].DomainID = rows[i].DomainID.String() resp.Items[i].Name = rows[i].Name resp.Items[i].TaskType = int(rows[i].TaskType) resp.Items[i].RangeID = rows[i].RangeID resp.Items[i].Kind = int(info.GetKind()) resp.Items[i].AckLevel = info.GetAckLevel() resp.Items[i].Expiry = info.GetExpiryTimestamp() resp.Items[i].LastUpdated = info.GetLastUpdated() } return resp, nil } func (m *sqlTaskStore) DeleteTaskList( ctx context.Context, request *persistence.DeleteTaskListRequest, ) error { shardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskListName, m.db.GetTotalNumDBShards()) domainID := serialization.MustParseUUID(request.DomainID) result, err := m.db.DeleteFromTaskLists(ctx, &sqlplugin.TaskListsFilter{ ShardID: shardID, DomainID: &domainID, Name: &request.TaskListName, TaskType: common.Int64Ptr(int64(request.TaskListType)), RangeID: &request.RangeID, }) if err != nil { return convertCommonErrors(m.db, "DeleteTaskList", "", err) } nRows, err := result.RowsAffected() if err != nil { return &types.InternalServiceError{Message: fmt.Sprintf("rowsAffected returned error:%v", err)} } if nRows != 1 { return &types.InternalServiceError{Message: fmt.Sprintf("delete failed: %v rows affected instead of 1", nRows)} } return nil } func (m *sqlTaskStore) CreateTasks( ctx context.Context, request *persistence.InternalCreateTasksRequest, ) (*persistence.CreateTasksResponse, error) { var tasksRows []sqlplugin.TasksRow var tasksRowsWithTTL []sqlplugin.TasksRowWithTTL if m.db.SupportsTTL() { tasksRowsWithTTL = make([]sqlplugin.TasksRowWithTTL, len(request.Tasks)) } else { tasksRows = make([]sqlplugin.TasksRow, len(request.Tasks)) } dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.TaskListInfo.DomainID, request.TaskListInfo.Name, m.db.GetTotalNumDBShards()) for i, v := range request.Tasks { var expiryTime time.Time var ttl time.Duration if v.Data.ScheduleToStartTimeout.Seconds() > 0 { ttl = v.Data.ScheduleToStartTimeout if m.db.SupportsTTL() { maxAllowedTTL, err := m.db.MaxAllowedTTL() if err != nil { return nil, err } if ttl > *maxAllowedTTL { ttl = *maxAllowedTTL } } expiryTime = time.Now().Add(ttl) } blob, err := m.parser.TaskInfoToBlob(&serialization.TaskInfo{ WorkflowID: v.Data.WorkflowID, RunID: serialization.MustParseUUID(v.Data.RunID), ScheduleID: v.Data.ScheduleID, ExpiryTimestamp: expiryTime, CreatedTimestamp: time.Now(), PartitionConfig: v.Data.PartitionConfig, }) if err != nil { return nil, err } currTasksRow := sqlplugin.TasksRow{ ShardID: dbShardID, DomainID: serialization.MustParseUUID(v.Data.DomainID), TaskListName: request.TaskListInfo.Name, TaskType: int64(request.TaskListInfo.TaskType), TaskID: v.TaskID, Data: blob.Data, DataEncoding: string(blob.Encoding), } if m.db.SupportsTTL() { currTasksRowWithTTL := sqlplugin.TasksRowWithTTL{ TasksRow: currTasksRow, } if ttl > 0 { currTasksRowWithTTL.TTL = &ttl } tasksRowsWithTTL[i] = currTasksRowWithTTL } else { tasksRows[i] = currTasksRow } } var resp *persistence.CreateTasksResponse err := m.txExecute(ctx, dbShardID, "CreateTasks", func(tx sqlplugin.Tx) error { if m.db.SupportsTTL() { if _, err := tx.InsertIntoTasksWithTTL(ctx, tasksRowsWithTTL); err != nil { return err } } else { if _, err := tx.InsertIntoTasks(ctx, tasksRows); err != nil { return err } } // Lock task list before committing. err1 := lockTaskList(ctx, tx, dbShardID, serialization.MustParseUUID(request.TaskListInfo.DomainID), request.TaskListInfo.Name, request.TaskListInfo.TaskType, request.TaskListInfo.RangeID) if err1 != nil { return err1 } resp = &persistence.CreateTasksResponse{} return nil }) return resp, err } func (m *sqlTaskStore) GetTasks( ctx context.Context, request *persistence.GetTasksRequest, ) (*persistence.InternalGetTasksResponse, error) { shardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskList, m.db.GetTotalNumDBShards()) rows, err := m.db.SelectFromTasks(ctx, &sqlplugin.TasksFilter{ ShardID: shardID, DomainID: serialization.MustParseUUID(request.DomainID), TaskListName: request.TaskList, TaskType: int64(request.TaskType), MinTaskID: &request.ReadLevel, MaxTaskID: request.MaxReadLevel, PageSize: &request.BatchSize, }) if err != nil { return nil, convertCommonErrors(m.db, "GetTasks", "", err) } var tasks = make([]*persistence.InternalTaskInfo, len(rows)) for i, v := range rows { info, err := m.parser.TaskInfoFromBlob(v.Data, v.DataEncoding) if err != nil { return nil, err } tasks[i] = &persistence.InternalTaskInfo{ DomainID: request.DomainID, WorkflowID: info.GetWorkflowID(), RunID: info.RunID.String(), TaskID: v.TaskID, ScheduleID: info.GetScheduleID(), Expiry: info.GetExpiryTimestamp(), CreatedTime: info.GetCreatedTimestamp(), PartitionConfig: info.GetPartitionConfig(), } } return &persistence.InternalGetTasksResponse{Tasks: tasks}, nil } func (m *sqlTaskStore) CompleteTask( ctx context.Context, request *persistence.CompleteTaskRequest, ) error { taskID := request.TaskID taskList := request.TaskList shardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(taskList.DomainID, taskList.Name, m.db.GetTotalNumDBShards()) _, err := m.db.DeleteFromTasks(ctx, &sqlplugin.TasksFilter{ ShardID: shardID, DomainID: serialization.MustParseUUID(taskList.DomainID), TaskListName: taskList.Name, TaskType: int64(taskList.TaskType), TaskID: &taskID}) if err != nil { return convertCommonErrors(m.db, "CompleteTask", "", err) } return nil } func (m *sqlTaskStore) CompleteTasksLessThan( ctx context.Context, request *persistence.CompleteTasksLessThanRequest, ) (*persistence.CompleteTasksLessThanResponse, error) { shardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskListName, m.db.GetTotalNumDBShards()) result, err := m.db.DeleteFromTasks(ctx, &sqlplugin.TasksFilter{ ShardID: shardID, DomainID: serialization.MustParseUUID(request.DomainID), TaskListName: request.TaskListName, TaskType: int64(request.TaskType), TaskIDLessThanEquals: &request.TaskID, Limit: &request.Limit, }) if err != nil { return nil, convertCommonErrors(m.db, "CompleteTasksLessThan", "", err) } nRows, err := result.RowsAffected() if err != nil { return nil, &types.InternalServiceError{ Message: fmt.Sprintf("rowsAffected returned error: %v", err), } } return &persistence.CompleteTasksLessThanResponse{TasksCompleted: int(nRows)}, nil } // GetOrphanTasks gets tasks from the tasks table that belong to a task_list no longer present // in the task_lists table. // TODO: Limit this query to a specific shard at a time. See https://github.com/uber/cadence/issues/4064 func (m *sqlTaskStore) GetOrphanTasks(ctx context.Context, request *persistence.GetOrphanTasksRequest) (*persistence.GetOrphanTasksResponse, error) { rows, err := m.db.GetOrphanTasks(ctx, &sqlplugin.OrphanTasksFilter{ Limit: &request.Limit, }) if err != nil { return nil, convertCommonErrors(m.db, "GetOrphanTasks", "", err) } var tasks = make([]*persistence.TaskKey, len(rows)) for i, v := range rows { tasks[i] = &persistence.TaskKey{ DomainID: v.DomainID.String(), TaskListName: v.TaskListName, TaskType: int(v.TaskType), TaskID: v.TaskID, } } return &persistence.GetOrphanTasksResponse{Tasks: tasks}, nil } func lockTaskList(ctx context.Context, tx sqlplugin.Tx, shardID int, domainID serialization.UUID, name string, taskListType int, oldRangeID int64) error { rangeID, err := tx.LockTaskLists(ctx, &sqlplugin.TaskListsFilter{ ShardID: shardID, DomainID: &domainID, Name: &name, TaskType: common.Int64Ptr(int64(taskListType))}) switch err { case nil: if rangeID != oldRangeID { return &persistence.ConditionFailedError{ Msg: fmt.Sprintf("Task list range ID was %v when it was should have been %v", rangeID, oldRangeID), } } return nil case sql.ErrNoRows: return &persistence.ConditionFailedError{ Msg: "Task list does not exist.", } default: return convertCommonErrors(tx, "lockTaskList", "", err) } } func stickyTaskListExpiry() time.Time { return time.Now().Add(stickyTasksListsTTL) }