common/persistence/sql/sql_execution_store_util.go (1,267 lines of code) (raw):
// Copyright (c) 2017-2020 Uber Technologies, Inc.
// Portions of the Software are attributed to Copyright (c) 2020 Temporal Technologies Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
package sql
import (
"bytes"
"context"
"database/sql"
"fmt"
"time"
"github.com/uber/cadence/common"
p "github.com/uber/cadence/common/persistence"
"github.com/uber/cadence/common/persistence/serialization"
"github.com/uber/cadence/common/persistence/sql/sqlplugin"
"github.com/uber/cadence/common/types"
)
func applyWorkflowMutationTx(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
workflowMutation *p.InternalWorkflowMutation,
parser serialization.Parser,
) error {
executionInfo := workflowMutation.ExecutionInfo
versionHistories := workflowMutation.VersionHistories
workflowChecksum := workflowMutation.ChecksumData
startVersion := workflowMutation.StartVersion
lastWriteVersion := workflowMutation.LastWriteVersion
domainID := serialization.MustParseUUID(executionInfo.DomainID)
workflowID := executionInfo.WorkflowID
runID := serialization.MustParseUUID(executionInfo.RunID)
// TODO Remove me if UPDATE holds the lock to the end of a transaction
if err := lockAndCheckNextEventID(
ctx,
tx,
shardID,
domainID,
workflowID,
runID,
workflowMutation.Condition); err != nil {
return err
}
if err := updateExecution(
ctx,
tx,
executionInfo,
versionHistories,
workflowChecksum,
startVersion,
lastWriteVersion,
shardID,
parser); err != nil {
return err
}
if err := applyTasks(
ctx,
tx,
shardID,
domainID,
workflowID,
runID,
workflowMutation.TransferTasks,
workflowMutation.CrossClusterTasks,
workflowMutation.ReplicationTasks,
workflowMutation.TimerTasks,
parser,
); err != nil {
return err
}
if err := updateActivityInfos(
ctx,
tx,
workflowMutation.UpsertActivityInfos,
workflowMutation.DeleteActivityInfos,
shardID,
domainID,
workflowID,
runID,
parser,
); err != nil {
return err
}
if err := updateTimerInfos(
ctx,
tx,
workflowMutation.UpsertTimerInfos,
workflowMutation.DeleteTimerInfos,
shardID,
domainID,
workflowID,
runID,
parser,
); err != nil {
return err
}
if err := updateChildExecutionInfos(
ctx,
tx,
workflowMutation.UpsertChildExecutionInfos,
workflowMutation.DeleteChildExecutionInfos,
shardID,
domainID,
workflowID,
runID,
parser,
); err != nil {
return err
}
if err := updateRequestCancelInfos(
ctx,
tx,
workflowMutation.UpsertRequestCancelInfos,
workflowMutation.DeleteRequestCancelInfos,
shardID,
domainID,
workflowID,
runID,
parser,
); err != nil {
return err
}
if err := updateSignalInfos(
ctx,
tx,
workflowMutation.UpsertSignalInfos,
workflowMutation.DeleteSignalInfos,
shardID,
domainID,
workflowID,
runID,
parser,
); err != nil {
return err
}
if err := updateSignalsRequested(
ctx,
tx,
workflowMutation.UpsertSignalRequestedIDs,
workflowMutation.DeleteSignalRequestedIDs,
shardID,
domainID,
workflowID,
runID,
); err != nil {
return err
}
if workflowMutation.ClearBufferedEvents {
if err := deleteBufferedEvents(
ctx,
tx,
shardID,
domainID,
workflowID,
runID,
); err != nil {
return err
}
}
return updateBufferedEvents(
ctx,
tx,
workflowMutation.NewBufferedEvents,
shardID,
domainID,
workflowID,
runID,
)
}
func applyWorkflowSnapshotTxAsReset(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
workflowSnapshot *p.InternalWorkflowSnapshot,
parser serialization.Parser,
) error {
executionInfo := workflowSnapshot.ExecutionInfo
versionHistories := workflowSnapshot.VersionHistories
workflowChecksum := workflowSnapshot.ChecksumData
startVersion := workflowSnapshot.StartVersion
lastWriteVersion := workflowSnapshot.LastWriteVersion
domainID := serialization.MustParseUUID(executionInfo.DomainID)
workflowID := executionInfo.WorkflowID
runID := serialization.MustParseUUID(executionInfo.RunID)
// TODO Is there a way to modify the various map tables without fear of other people adding rows after we delete, without locking the executions row?
if err := lockAndCheckNextEventID(
ctx,
tx,
shardID,
domainID,
workflowID,
runID,
workflowSnapshot.Condition); err != nil {
return err
}
if err := updateExecution(
ctx,
tx,
executionInfo,
versionHistories,
workflowChecksum,
startVersion,
lastWriteVersion,
shardID,
parser); err != nil {
return err
}
if err := applyTasks(
ctx,
tx,
shardID,
domainID,
workflowID,
runID,
workflowSnapshot.TransferTasks,
workflowSnapshot.CrossClusterTasks,
workflowSnapshot.ReplicationTasks,
workflowSnapshot.TimerTasks,
parser,
); err != nil {
return err
}
if err := deleteActivityInfoMap(
ctx,
tx,
shardID,
domainID,
workflowID,
runID); err != nil {
return err
}
if err := updateActivityInfos(
ctx,
tx,
workflowSnapshot.ActivityInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := deleteTimerInfoMap(
ctx,
tx,
shardID,
domainID,
workflowID,
runID); err != nil {
return err
}
if err := updateTimerInfos(
ctx,
tx,
workflowSnapshot.TimerInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := deleteChildExecutionInfoMap(
ctx,
tx,
shardID,
domainID,
workflowID,
runID); err != nil {
return err
}
if err := updateChildExecutionInfos(
ctx,
tx,
workflowSnapshot.ChildExecutionInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := deleteRequestCancelInfoMap(
ctx,
tx,
shardID,
domainID,
workflowID,
runID); err != nil {
return err
}
if err := updateRequestCancelInfos(
ctx,
tx,
workflowSnapshot.RequestCancelInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := deleteSignalInfoMap(
ctx,
tx,
shardID,
domainID,
workflowID,
runID); err != nil {
return err
}
if err := updateSignalInfos(
ctx,
tx,
workflowSnapshot.SignalInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := deleteSignalsRequestedSet(
ctx,
tx,
shardID,
domainID,
workflowID,
runID); err != nil {
return err
}
if err := updateSignalsRequested(
ctx,
tx,
workflowSnapshot.SignalRequestedIDs,
nil,
shardID,
domainID,
workflowID,
runID); err != nil {
return err
}
return deleteBufferedEvents(
ctx,
tx,
shardID,
domainID,
workflowID,
runID)
}
func applyWorkflowSnapshotTxAsNew(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
workflowSnapshot *p.InternalWorkflowSnapshot,
parser serialization.Parser,
) error {
executionInfo := workflowSnapshot.ExecutionInfo
versionHistories := workflowSnapshot.VersionHistories
workflowChecksum := workflowSnapshot.ChecksumData
startVersion := workflowSnapshot.StartVersion
lastWriteVersion := workflowSnapshot.LastWriteVersion
domainID := serialization.MustParseUUID(executionInfo.DomainID)
workflowID := executionInfo.WorkflowID
runID := serialization.MustParseUUID(executionInfo.RunID)
if err := createExecution(
ctx,
tx,
executionInfo,
versionHistories,
workflowChecksum,
startVersion,
lastWriteVersion,
shardID,
parser); err != nil {
return err
}
if err := applyTasks(
ctx,
tx,
shardID,
domainID,
workflowID,
runID,
workflowSnapshot.TransferTasks,
workflowSnapshot.CrossClusterTasks,
workflowSnapshot.ReplicationTasks,
workflowSnapshot.TimerTasks,
parser,
); err != nil {
return err
}
if err := updateActivityInfos(
ctx,
tx,
workflowSnapshot.ActivityInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := updateTimerInfos(
ctx,
tx,
workflowSnapshot.TimerInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := updateChildExecutionInfos(
ctx,
tx,
workflowSnapshot.ChildExecutionInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := updateRequestCancelInfos(
ctx,
tx,
workflowSnapshot.RequestCancelInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := updateSignalInfos(
ctx,
tx,
workflowSnapshot.SignalInfos,
nil,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
return updateSignalsRequested(
ctx,
tx,
workflowSnapshot.SignalRequestedIDs,
nil,
shardID,
domainID,
workflowID,
runID)
}
func applyTasks(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
transferTasks []p.Task,
crossClusterTasks []p.Task,
replicationTasks []p.Task,
timerTasks []p.Task,
parser serialization.Parser,
) error {
if err := createTransferTasks(
ctx,
tx,
transferTasks,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := createCrossClusterTasks(
ctx,
tx,
crossClusterTasks,
shardID,
domainID,
workflowID,
runID,
parser); err != nil {
return err
}
if err := createReplicationTasks(
ctx,
tx,
replicationTasks,
shardID,
domainID,
workflowID,
runID,
parser,
); err != nil {
return err
}
return createTimerTasks(
ctx,
tx,
timerTasks,
shardID,
domainID,
workflowID,
runID,
parser,
)
}
// lockCurrentExecutionIfExists returns current execution or nil if none is found for the workflowID
// locking it in the DB
func lockCurrentExecutionIfExists(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
domainID serialization.UUID,
workflowID string,
) (*sqlplugin.CurrentExecutionsRow, error) {
rows, err := tx.LockCurrentExecutionsJoinExecutions(ctx, &sqlplugin.CurrentExecutionsFilter{
ShardID: int64(shardID),
DomainID: domainID,
WorkflowID: workflowID,
})
if err != nil {
if err != sql.ErrNoRows {
return nil, convertCommonErrors(tx, "lockCurrentExecutionIfExists", fmt.Sprintf("Failed to get current_executions row for (shard,domain,workflow) = (%v, %v, %v).", shardID, domainID, workflowID), err)
}
}
size := len(rows)
if size > 1 {
return nil, &types.InternalServiceError{
Message: fmt.Sprintf("lockCurrentExecutionIfExists failed. Multiple current_executions rows for (shard,domain,workflow) = (%v, %v, %v).", shardID, domainID, workflowID),
}
}
if size == 0 {
return nil, nil
}
return &rows[0], nil
}
func createOrUpdateCurrentExecution(
ctx context.Context,
tx sqlplugin.Tx,
createMode p.CreateWorkflowMode,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
state int,
closeStatus int,
createRequestID string,
startVersion int64,
lastWriteVersion int64,
) error {
row := sqlplugin.CurrentExecutionsRow{
ShardID: int64(shardID),
DomainID: domainID,
WorkflowID: workflowID,
RunID: runID,
CreateRequestID: createRequestID,
State: state,
CloseStatus: closeStatus,
StartVersion: startVersion,
LastWriteVersion: lastWriteVersion,
}
switch createMode {
case p.CreateWorkflowModeContinueAsNew,
p.CreateWorkflowModeWorkflowIDReuse:
if err := updateCurrentExecution(
ctx,
tx,
shardID,
domainID,
workflowID,
runID,
createRequestID,
state,
closeStatus,
row.StartVersion,
row.LastWriteVersion); err != nil {
return err
}
case p.CreateWorkflowModeBrandNew:
if _, err := tx.InsertIntoCurrentExecutions(ctx, &row); err != nil {
return convertCommonErrors(tx, "createOrUpdateCurrentExecution", "Failed to insert into current_executions table.", err)
}
case p.CreateWorkflowModeZombie:
// noop
default:
return fmt.Errorf("createOrUpdateCurrentExecution failed. Unknown workflow creation mode: %v", createMode)
}
return nil
}
func lockAndCheckNextEventID(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
condition int64,
) error {
nextEventID, err := lockNextEventID(
ctx,
tx,
shardID,
domainID,
workflowID,
runID,
)
if err != nil {
return err
}
if *nextEventID != condition {
return &p.ConditionFailedError{
Msg: fmt.Sprintf("lockAndCheckNextEventID failed. Next_event_id was %v when it should have been %v.", nextEventID, condition),
}
}
return nil
}
func lockNextEventID(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
) (*int64, error) {
nextEventID, err := tx.WriteLockExecutions(ctx, &sqlplugin.ExecutionsFilter{
ShardID: shardID,
DomainID: domainID,
WorkflowID: workflowID,
RunID: runID,
})
if err != nil {
if err == sql.ErrNoRows {
return nil, &types.EntityNotExistsError{
Message: fmt.Sprintf(
"lockNextEventID failed. Unable to lock executions row with (shard, domain, workflow, run) = (%v,%v,%v,%v) which does not exist.",
shardID,
domainID,
workflowID,
runID,
),
}
}
return nil, convertCommonErrors(tx, "lockNextEventID", "", err)
}
result := int64(nextEventID)
return &result, nil
}
func createCrossClusterTasks(
ctx context.Context,
tx sqlplugin.Tx,
crossClusterTasks []p.Task,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
parser serialization.Parser,
) error {
if len(crossClusterTasks) == 0 {
return nil
}
crossClusterTasksRows := make([]sqlplugin.CrossClusterTasksRow, len(crossClusterTasks))
for i, task := range crossClusterTasks {
info := &serialization.CrossClusterTaskInfo{
DomainID: domainID,
WorkflowID: workflowID,
RunID: runID,
TaskType: int16(task.GetType()),
TargetDomainID: domainID,
TargetWorkflowID: p.TransferTaskTransferTargetWorkflowID,
TargetRunID: serialization.UUID(p.CrossClusterTaskDefaultTargetRunID),
ScheduleID: 0,
Version: task.GetVersion(),
VisibilityTimestamp: task.GetVisibilityTimestamp(),
}
crossClusterTasksRows[i].ShardID = shardID
crossClusterTasksRows[i].TaskID = task.GetTaskID()
switch task.GetType() {
case p.CrossClusterTaskTypeStartChildExecution:
crossClusterTasksRows[i].TargetCluster = task.(*p.CrossClusterStartChildExecutionTask).TargetCluster
info.TargetDomainID = serialization.MustParseUUID(task.(*p.CrossClusterStartChildExecutionTask).TargetDomainID)
info.TargetWorkflowID = task.(*p.CrossClusterStartChildExecutionTask).TargetWorkflowID
info.ScheduleID = task.(*p.CrossClusterStartChildExecutionTask).InitiatedID
case p.CrossClusterTaskTypeCancelExecution:
crossClusterTasksRows[i].TargetCluster = task.(*p.CrossClusterCancelExecutionTask).TargetCluster
info.TargetDomainID = serialization.MustParseUUID(task.(*p.CrossClusterCancelExecutionTask).TargetDomainID)
info.TargetWorkflowID = task.(*p.CrossClusterCancelExecutionTask).TargetWorkflowID
if targetRunID := task.(*p.CrossClusterCancelExecutionTask).TargetRunID; targetRunID != "" {
info.TargetRunID = serialization.MustParseUUID(targetRunID)
}
info.TargetChildWorkflowOnly = task.(*p.CrossClusterCancelExecutionTask).TargetChildWorkflowOnly
info.ScheduleID = task.(*p.CrossClusterCancelExecutionTask).InitiatedID
case p.CrossClusterTaskTypeSignalExecution:
crossClusterTasksRows[i].TargetCluster = task.(*p.CrossClusterSignalExecutionTask).TargetCluster
info.TargetDomainID = serialization.MustParseUUID(task.(*p.CrossClusterSignalExecutionTask).TargetDomainID)
info.TargetWorkflowID = task.(*p.CrossClusterSignalExecutionTask).TargetWorkflowID
if targetRunID := task.(*p.CrossClusterSignalExecutionTask).TargetRunID; targetRunID != "" {
info.TargetRunID = serialization.MustParseUUID(targetRunID)
}
info.TargetChildWorkflowOnly = task.(*p.CrossClusterSignalExecutionTask).TargetChildWorkflowOnly
info.ScheduleID = task.(*p.CrossClusterSignalExecutionTask).InitiatedID
case p.CrossClusterTaskTypeRecordChildExeuctionCompleted:
crossClusterTasksRows[i].TargetCluster = task.(*p.CrossClusterRecordChildExecutionCompletedTask).TargetCluster
info.TargetDomainID = serialization.MustParseUUID(task.(*p.CrossClusterRecordChildExecutionCompletedTask).TargetDomainID)
info.TargetWorkflowID = task.(*p.CrossClusterRecordChildExecutionCompletedTask).TargetWorkflowID
if targetRunID := task.(*p.CrossClusterRecordChildExecutionCompletedTask).TargetRunID; targetRunID != "" {
info.TargetRunID = serialization.MustParseUUID(targetRunID)
}
case p.CrossClusterTaskTypeApplyParentClosePolicy:
crossClusterTasksRows[i].TargetCluster = task.(*p.CrossClusterApplyParentClosePolicyTask).TargetCluster
for domainID := range task.(*p.CrossClusterApplyParentClosePolicyTask).TargetDomainIDs {
info.TargetDomainIDs = append(info.TargetDomainIDs, serialization.MustParseUUID(domainID))
}
default:
return &types.InternalServiceError{
Message: fmt.Sprintf("Unknown cross-cluster task type: %v", task.GetType()),
}
}
blob, err := parser.CrossClusterTaskInfoToBlob(info)
if err != nil {
return err
}
crossClusterTasksRows[i].Data = blob.Data
crossClusterTasksRows[i].DataEncoding = string(blob.Encoding)
}
result, err := tx.InsertIntoCrossClusterTasks(ctx, crossClusterTasksRows)
if err != nil {
return convertCommonErrors(tx, "createCrossClusterTasks", "", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return &types.InternalServiceError{
Message: fmt.Sprintf("createTransferTasks failed. Could not verify number of rows inserted. Error: %v", err),
}
}
if int(rowsAffected) != len(crossClusterTasks) {
return &types.InternalServiceError{
Message: fmt.Sprintf("createCrossClusterTasks failed. Inserted %v instead of %v rows into transfer_tasks. Error: %v", rowsAffected, len(crossClusterTasks), err),
}
}
return nil
}
func createTransferTasks(
ctx context.Context,
tx sqlplugin.Tx,
transferTasks []p.Task,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
parser serialization.Parser,
) error {
if len(transferTasks) == 0 {
return nil
}
transferTasksRows := make([]sqlplugin.TransferTasksRow, len(transferTasks))
for i, task := range transferTasks {
info := &serialization.TransferTaskInfo{
DomainID: domainID,
WorkflowID: workflowID,
RunID: runID,
TaskType: int16(task.GetType()),
TargetDomainID: domainID,
TargetWorkflowID: p.TransferTaskTransferTargetWorkflowID,
ScheduleID: 0,
Version: task.GetVersion(),
VisibilityTimestamp: task.GetVisibilityTimestamp(),
}
transferTasksRows[i].ShardID = shardID
transferTasksRows[i].TaskID = task.GetTaskID()
switch task.GetType() {
case p.TransferTaskTypeActivityTask:
info.TargetDomainID = serialization.MustParseUUID(task.(*p.ActivityTask).DomainID)
info.TaskList = task.(*p.ActivityTask).TaskList
info.ScheduleID = task.(*p.ActivityTask).ScheduleID
case p.TransferTaskTypeDecisionTask:
info.TargetDomainID = serialization.MustParseUUID(task.(*p.DecisionTask).DomainID)
info.TaskList = task.(*p.DecisionTask).TaskList
info.ScheduleID = task.(*p.DecisionTask).ScheduleID
case p.TransferTaskTypeCancelExecution:
info.TargetDomainID = serialization.MustParseUUID(task.(*p.CancelExecutionTask).TargetDomainID)
info.TargetWorkflowID = task.(*p.CancelExecutionTask).TargetWorkflowID
if targetRunID := task.(*p.CancelExecutionTask).TargetRunID; targetRunID != "" {
info.TargetRunID = serialization.MustParseUUID(targetRunID)
}
info.TargetChildWorkflowOnly = task.(*p.CancelExecutionTask).TargetChildWorkflowOnly
info.ScheduleID = task.(*p.CancelExecutionTask).InitiatedID
case p.TransferTaskTypeSignalExecution:
info.TargetDomainID = serialization.MustParseUUID(task.(*p.SignalExecutionTask).TargetDomainID)
info.TargetWorkflowID = task.(*p.SignalExecutionTask).TargetWorkflowID
if targetRunID := task.(*p.SignalExecutionTask).TargetRunID; targetRunID != "" {
info.TargetRunID = serialization.MustParseUUID(targetRunID)
}
info.TargetChildWorkflowOnly = task.(*p.SignalExecutionTask).TargetChildWorkflowOnly
info.ScheduleID = task.(*p.SignalExecutionTask).InitiatedID
case p.TransferTaskTypeStartChildExecution:
info.TargetDomainID = serialization.MustParseUUID(task.(*p.StartChildExecutionTask).TargetDomainID)
info.TargetWorkflowID = task.(*p.StartChildExecutionTask).TargetWorkflowID
info.ScheduleID = task.(*p.StartChildExecutionTask).InitiatedID
case p.TransferTaskTypeRecordChildExecutionCompleted:
info.TargetDomainID = serialization.MustParseUUID(task.(*p.RecordChildExecutionCompletedTask).TargetDomainID)
info.TargetWorkflowID = task.(*p.RecordChildExecutionCompletedTask).TargetWorkflowID
if targetRunID := task.(*p.RecordChildExecutionCompletedTask).TargetRunID; targetRunID != "" {
info.TargetRunID = serialization.MustParseUUID(targetRunID)
}
case p.TransferTaskTypeApplyParentClosePolicy:
for targetDomainID := range task.(*p.ApplyParentClosePolicyTask).TargetDomainIDs {
info.TargetDomainIDs = append(info.TargetDomainIDs, serialization.MustParseUUID(targetDomainID))
}
case p.TransferTaskTypeCloseExecution,
p.TransferTaskTypeRecordWorkflowStarted,
p.TransferTaskTypeResetWorkflow,
p.TransferTaskTypeUpsertWorkflowSearchAttributes,
p.TransferTaskTypeRecordWorkflowClosed:
// No explicit property needs to be set
default:
return &types.InternalServiceError{
Message: fmt.Sprintf("createTransferTasks failed. Unknown transfer type: %v", task.GetType()),
}
}
blob, err := parser.TransferTaskInfoToBlob(info)
if err != nil {
return err
}
transferTasksRows[i].Data = blob.Data
transferTasksRows[i].DataEncoding = string(blob.Encoding)
}
result, err := tx.InsertIntoTransferTasks(ctx, transferTasksRows)
if err != nil {
return convertCommonErrors(tx, "createTransferTasks", "", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return &types.InternalServiceError{
Message: fmt.Sprintf("createTransferTasks failed. Could not verify number of rows inserted. Error: %v", err),
}
}
if int(rowsAffected) != len(transferTasks) {
return &types.InternalServiceError{
Message: fmt.Sprintf("createTransferTasks failed. Inserted %v instead of %v rows into transfer_tasks. Error: %v", rowsAffected, len(transferTasks), err),
}
}
return nil
}
func createReplicationTasks(
ctx context.Context,
tx sqlplugin.Tx,
replicationTasks []p.Task,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
parser serialization.Parser,
) error {
if len(replicationTasks) == 0 {
return nil
}
replicationTasksRows := make([]sqlplugin.ReplicationTasksRow, len(replicationTasks))
for i, task := range replicationTasks {
firstEventID := common.EmptyEventID
nextEventID := common.EmptyEventID
version := common.EmptyVersion
activityScheduleID := common.EmptyEventID
var branchToken, newRunBranchToken []byte
switch task.GetType() {
case p.ReplicationTaskTypeHistory:
historyReplicationTask, ok := task.(*p.HistoryReplicationTask)
if !ok {
return &types.InternalServiceError{
Message: fmt.Sprintf("createReplicationTasks failed. Failed to cast %v to HistoryReplicationTask", task),
}
}
firstEventID = historyReplicationTask.FirstEventID
nextEventID = historyReplicationTask.NextEventID
version = task.GetVersion()
branchToken = historyReplicationTask.BranchToken
newRunBranchToken = historyReplicationTask.NewRunBranchToken
case p.ReplicationTaskTypeSyncActivity:
version = task.GetVersion()
activityScheduleID = task.(*p.SyncActivityTask).ScheduledID
case p.ReplicationTaskTypeFailoverMarker:
version = task.GetVersion()
default:
return &types.InternalServiceError{
Message: fmt.Sprintf("Unknown replication task: %v", task.GetType()),
}
}
blob, err := parser.ReplicationTaskInfoToBlob(&serialization.ReplicationTaskInfo{
DomainID: domainID,
WorkflowID: workflowID,
RunID: runID,
TaskType: int16(task.GetType()),
FirstEventID: firstEventID,
NextEventID: nextEventID,
Version: version,
ScheduledID: activityScheduleID,
EventStoreVersion: p.EventStoreVersion,
NewRunEventStoreVersion: p.EventStoreVersion,
BranchToken: branchToken,
NewRunBranchToken: newRunBranchToken,
CreationTimestamp: task.GetVisibilityTimestamp(),
})
if err != nil {
return err
}
replicationTasksRows[i].ShardID = shardID
replicationTasksRows[i].TaskID = task.GetTaskID()
replicationTasksRows[i].Data = blob.Data
replicationTasksRows[i].DataEncoding = string(blob.Encoding)
}
result, err := tx.InsertIntoReplicationTasks(ctx, replicationTasksRows)
if err != nil {
return convertCommonErrors(tx, "createReplicationTasks", "", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return &types.InternalServiceError{
Message: fmt.Sprintf("createReplicationTasks failed. Could not verify number of rows inserted. Error: %v", err),
}
}
if int(rowsAffected) != len(replicationTasks) {
return &types.InternalServiceError{
Message: fmt.Sprintf("createReplicationTasks failed. Inserted %v instead of %v rows into transfer_tasks. Error: %v", rowsAffected, len(replicationTasks), err),
}
}
return nil
}
func createTimerTasks(
ctx context.Context,
tx sqlplugin.Tx,
timerTasks []p.Task,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
parser serialization.Parser,
) error {
if len(timerTasks) == 0 {
return nil
}
timerTasksRows := make([]sqlplugin.TimerTasksRow, len(timerTasks))
for i, task := range timerTasks {
info := &serialization.TimerTaskInfo{
DomainID: domainID,
WorkflowID: workflowID,
RunID: runID,
TaskType: int16(task.GetType()),
Version: task.GetVersion(),
EventID: common.EmptyEventID,
ScheduleAttempt: 0,
}
switch t := task.(type) {
case *p.DecisionTimeoutTask:
info.EventID = t.EventID
info.TimeoutType = common.Int16Ptr(int16(t.TimeoutType))
info.ScheduleAttempt = t.ScheduleAttempt
case *p.ActivityTimeoutTask:
info.EventID = t.EventID
info.TimeoutType = common.Int16Ptr(int16(t.TimeoutType))
info.ScheduleAttempt = t.Attempt
case *p.UserTimerTask:
info.EventID = t.EventID
case *p.ActivityRetryTimerTask:
info.EventID = t.EventID
info.ScheduleAttempt = int64(t.Attempt)
case *p.WorkflowBackoffTimerTask:
info.EventID = t.EventID
info.TimeoutType = common.Int16Ptr(int16(t.TimeoutType))
case *p.WorkflowTimeoutTask:
// noop
case *p.DeleteHistoryEventTask:
// noop
default:
return &types.InternalServiceError{
Message: fmt.Sprintf("createTimerTasks failed. Unknown timer task: %v", task.GetType()),
}
}
blob, err := parser.TimerTaskInfoToBlob(info)
if err != nil {
return err
}
timerTasksRows[i].ShardID = shardID
timerTasksRows[i].VisibilityTimestamp = task.GetVisibilityTimestamp()
timerTasksRows[i].TaskID = task.GetTaskID()
timerTasksRows[i].Data = blob.Data
timerTasksRows[i].DataEncoding = string(blob.Encoding)
}
result, err := tx.InsertIntoTimerTasks(ctx, timerTasksRows)
if err != nil {
return convertCommonErrors(tx, "createTimerTasks", "", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return &types.InternalServiceError{
Message: fmt.Sprintf("createTimerTasks failed. Could not verify number of rows inserted. Error: %v", err),
}
}
if int(rowsAffected) != len(timerTasks) {
return &types.InternalServiceError{
Message: fmt.Sprintf("createTimerTasks failed. Inserted %v instead of %v rows into timer_tasks.", rowsAffected, len(timerTasks)),
}
}
return nil
}
func assertNotCurrentExecution(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
) error {
currentRow, err := tx.LockCurrentExecutions(ctx, &sqlplugin.CurrentExecutionsFilter{
ShardID: int64(shardID),
DomainID: domainID,
WorkflowID: workflowID,
})
if err != nil {
if err == sql.ErrNoRows {
// allow bypassing no current record
return nil
}
return convertCommonErrors(tx, "assertCurrentExecution", "Unable to load current record.", err)
}
return assertRunIDMismatch(runID, currentRow.RunID)
}
func assertRunIDAndUpdateCurrentExecution(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
domainID serialization.UUID,
workflowID string,
newRunID serialization.UUID,
previousRunID serialization.UUID,
createRequestID string,
state int,
closeStatus int,
startVersion int64,
lastWriteVersion int64,
) error {
assertFn := func(currentRow *sqlplugin.CurrentExecutionsRow) error {
if !bytes.Equal(currentRow.RunID, previousRunID) {
return &p.ConditionFailedError{Msg: fmt.Sprintf(
"assertRunIDAndUpdateCurrentExecution failed. Current run ID was %v, expected %v",
currentRow.RunID,
previousRunID,
)}
}
return nil
}
if err := assertCurrentExecution(ctx, tx, shardID, domainID, workflowID, assertFn); err != nil {
return err
}
return updateCurrentExecution(ctx, tx, shardID, domainID, workflowID, newRunID, createRequestID, state, closeStatus, startVersion, lastWriteVersion)
}
func assertCurrentExecution(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
domainID serialization.UUID,
workflowID string,
assertFn func(currentRow *sqlplugin.CurrentExecutionsRow) error,
) error {
currentRow, err := tx.LockCurrentExecutions(ctx, &sqlplugin.CurrentExecutionsFilter{
ShardID: int64(shardID),
DomainID: domainID,
WorkflowID: workflowID,
})
if err != nil {
return convertCommonErrors(tx, "assertCurrentExecution", "Unable to load current record.", err)
}
return assertFn(currentRow)
}
func assertRunIDMismatch(runID serialization.UUID, currentRunID serialization.UUID) error {
// zombie workflow creation with existence of current record, this is a noop
if bytes.Equal(currentRunID, runID) {
return &p.ConditionFailedError{Msg: fmt.Sprintf(
"assertRunIDMismatch failed. Current run ID was %v, input %v",
currentRunID,
runID,
)}
}
return nil
}
func updateCurrentExecution(
ctx context.Context,
tx sqlplugin.Tx,
shardID int,
domainID serialization.UUID,
workflowID string,
runID serialization.UUID,
createRequestID string,
state int,
closeStatus int,
startVersion int64,
lastWriteVersion int64,
) error {
result, err := tx.UpdateCurrentExecutions(ctx, &sqlplugin.CurrentExecutionsRow{
ShardID: int64(shardID),
DomainID: domainID,
WorkflowID: workflowID,
RunID: runID,
CreateRequestID: createRequestID,
State: state,
CloseStatus: closeStatus,
StartVersion: startVersion,
LastWriteVersion: lastWriteVersion,
})
if err != nil {
return convertCommonErrors(tx, "updateCurrentExecution", "", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return &types.InternalServiceError{
Message: fmt.Sprintf("updateCurrentExecution failed. Failed to check number of rows updated in current_executions table. Error: %v", err),
}
}
if rowsAffected != 1 {
return &types.InternalServiceError{
Message: fmt.Sprintf("updateCurrentExecution failed. %v rows of current_executions updated instead of 1.", rowsAffected),
}
}
return nil
}
func buildExecutionRow(
executionInfo *p.InternalWorkflowExecutionInfo,
versionHistories *p.DataBlob,
workflowChecksum *p.DataBlob,
startVersion int64,
lastWriteVersion int64,
shardID int,
parser serialization.Parser,
) (row *sqlplugin.ExecutionsRow, err error) {
info := serialization.FromInternalWorkflowExecutionInfo(executionInfo)
info.StartVersion = startVersion
if versionHistories == nil {
// this is allowed
} else {
info.VersionHistories = versionHistories.Data
info.VersionHistoriesEncoding = string(versionHistories.GetEncoding())
}
if workflowChecksum != nil {
info.Checksum = workflowChecksum.Data
info.ChecksumEncoding = string(workflowChecksum.GetEncoding())
}
blob, err := parser.WorkflowExecutionInfoToBlob(info)
if err != nil {
return nil, err
}
return &sqlplugin.ExecutionsRow{
ShardID: shardID,
DomainID: serialization.MustParseUUID(executionInfo.DomainID),
WorkflowID: executionInfo.WorkflowID,
RunID: serialization.MustParseUUID(executionInfo.RunID),
NextEventID: int64(executionInfo.NextEventID),
LastWriteVersion: lastWriteVersion,
Data: blob.Data,
DataEncoding: string(blob.Encoding),
}, nil
}
func createExecution(
ctx context.Context,
tx sqlplugin.Tx,
executionInfo *p.InternalWorkflowExecutionInfo,
versionHistories *p.DataBlob,
workflowChecksum *p.DataBlob,
startVersion int64,
lastWriteVersion int64,
shardID int,
parser serialization.Parser,
) error {
// validate workflow state & close status
if err := p.ValidateCreateWorkflowStateCloseStatus(
executionInfo.State,
executionInfo.CloseStatus); err != nil {
return err
}
now := time.Now()
// TODO: this case seems to be always false
if executionInfo.StartTimestamp.IsZero() {
executionInfo.StartTimestamp = now
}
row, err := buildExecutionRow(
executionInfo,
versionHistories,
workflowChecksum,
startVersion,
lastWriteVersion,
shardID,
parser,
)
if err != nil {
return err
}
result, err := tx.InsertIntoExecutions(ctx, row)
if err != nil {
if tx.IsDupEntryError(err) {
return &p.WorkflowExecutionAlreadyStartedError{
Msg: fmt.Sprintf("Workflow execution already running. WorkflowId: %v", executionInfo.WorkflowID),
StartRequestID: executionInfo.CreateRequestID,
RunID: executionInfo.RunID,
State: executionInfo.State,
CloseStatus: executionInfo.CloseStatus,
LastWriteVersion: row.LastWriteVersion,
}
}
return convertCommonErrors(tx, "createExecution", "", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return &types.InternalServiceError{
Message: fmt.Sprintf("createExecution failed. Failed to verify number of rows affected. Erorr: %v", err),
}
}
if rowsAffected != 1 {
return &types.EntityNotExistsError{
Message: fmt.Sprintf("createExecution failed. Affected %v rows updated instead of 1.", rowsAffected),
}
}
return nil
}
func updateExecution(
ctx context.Context,
tx sqlplugin.Tx,
executionInfo *p.InternalWorkflowExecutionInfo,
versionHistories *p.DataBlob,
workflowChecksum *p.DataBlob,
startVersion int64,
lastWriteVersion int64,
shardID int,
parser serialization.Parser,
) error {
// validate workflow state & close status
if err := p.ValidateUpdateWorkflowStateCloseStatus(
executionInfo.State,
executionInfo.CloseStatus); err != nil {
return err
}
row, err := buildExecutionRow(
executionInfo,
versionHistories,
workflowChecksum,
startVersion,
lastWriteVersion,
shardID,
parser,
)
if err != nil {
return err
}
result, err := tx.UpdateExecutions(ctx, row)
if err != nil {
return convertCommonErrors(tx, "updateExecution", "", err)
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return &types.InternalServiceError{
Message: fmt.Sprintf("updateExecution failed. Failed to verify number of rows affected. Erorr: %v", err),
}
}
if rowsAffected != 1 {
return &types.EntityNotExistsError{
Message: fmt.Sprintf("updateExecution failed. Affected %v rows updated instead of 1.", rowsAffected),
}
}
return nil
}