in common/persistence/sql/sql_task_store.go [81:204]
// LeaseTaskList acquires or renews the lease on a task list by incrementing
// its range ID inside a transaction.
//
// If the task list row does not exist yet, it is first inserted with a zero
// range ID and a zero ack level. Sticky task lists are inserted with a TTL
// when the database supports TTLs. When request.RangeID is positive it must
// equal the stored range ID; otherwise a *persistence.ConditionFailedError is
// returned. On success the stored range ID is incremented by one, and the
// response carries the new range ID together with the stored ack level.
func (m *sqlTaskStore) LeaseTaskList(
	ctx context.Context,
	request *persistence.LeaseTaskListRequest,
) (*persistence.LeaseTaskListResponse, error) {
	dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskList, m.db.GetTotalNumDBShards())
	domainID := serialization.MustParseUUID(request.DomainID)
	rows, err := m.db.SelectFromTaskLists(ctx, &sqlplugin.TaskListsFilter{
		ShardID:  dbShardID,
		DomainID: &domainID,
		Name:     &request.TaskList,
		TaskType: common.Int64Ptr(int64(request.TaskType))})

	switch {
	// An empty result with a nil error is treated like sql.ErrNoRows so that
	// the rows[0] access below can never panic.
	case err == sql.ErrNoRows || (err == nil && len(rows) == 0):
		newInfo := &serialization.TaskListInfo{
			AckLevel:        0,
			Kind:            int16(request.TaskListKind),
			ExpiryTimestamp: time.Unix(0, 0),
			LastUpdated:     time.Now(),
		}
		blob, blobErr := m.parser.TaskListInfoToBlob(newInfo)
		if blobErr != nil {
			return nil, blobErr
		}
		// RangeID is left at its zero value; the lease transaction below
		// bumps it to 1.
		newRow := sqlplugin.TaskListsRow{
			ShardID:      dbShardID,
			DomainID:     domainID,
			Name:         request.TaskList,
			TaskType:     int64(request.TaskType),
			Data:         blob.Data,
			DataEncoding: string(blob.Encoding),
		}
		if m.db.SupportsTTL() && request.TaskListKind == persistence.TaskListKindSticky {
			rowWithTTL := sqlplugin.TaskListsRowWithTTL{
				TaskListsRow: newRow,
				TTL:          stickyTasksListsTTL,
			}
			if _, insertErr := m.db.InsertIntoTaskListsWithTTL(ctx, &rowWithTTL); insertErr != nil {
				return nil, convertCommonErrors(m.db, "LeaseTaskListWithTTL", fmt.Sprintf("Failed to make task list %v of type %v.", request.TaskList, request.TaskType), insertErr)
			}
		} else {
			if _, insertErr := m.db.InsertIntoTaskLists(ctx, &newRow); insertErr != nil {
				return nil, convertCommonErrors(m.db, "LeaseTaskList", fmt.Sprintf("Failed to make task list %v of type %v.", request.TaskList, request.TaskType), insertErr)
			}
		}
		rows = []sqlplugin.TaskListsRow{newRow}
	case err != nil:
		return nil, convertCommonErrors(m.db, "LeaseTaskList", "Failed to check if task list existed.", err)
	}

	current := rows[0]
	if request.RangeID > 0 && request.RangeID != current.RangeID {
		// haveRangeID is the caller's view of the range ID, gotRangeID is the
		// value currently stored in the database.
		return nil, &persistence.ConditionFailedError{
			Msg: fmt.Sprintf("leaseTaskList:renew failed:taskList:%v, taskListType:%v, haveRangeID:%v, gotRangeID:%v",
				request.TaskList, request.TaskType, request.RangeID, current.RangeID),
		}
	}

	tlInfo, err := m.parser.TaskListInfoFromBlob(current.Data, current.DataEncoding)
	if err != nil {
		return nil, err
	}
	rangeID := current.RangeID
	ackLevel := tlInfo.GetAckLevel()

	var resp *persistence.LeaseTaskListResponse
	err = m.txExecute(ctx, dbShardID, "LeaseTaskList", func(tx sqlplugin.Tx) error {
		// We need to separately check the condition and do the
		// update because we want to throw different error codes.
		// Since we need to do things separately (in a transaction), we need to take a lock.
		if txErr := lockTaskList(ctx, tx, dbShardID, domainID, request.TaskList, request.TaskType, rangeID); txErr != nil {
			return txErr
		}
		now := time.Now()
		tlInfo.LastUpdated = now
		blob, txErr := m.parser.TaskListInfoToBlob(tlInfo)
		if txErr != nil {
			return txErr
		}
		updated := &sqlplugin.TaskListsRow{
			ShardID:      dbShardID,
			DomainID:     current.DomainID,
			RangeID:      rangeID + 1,
			Name:         current.Name,
			TaskType:     current.TaskType,
			Data:         blob.Data,
			DataEncoding: string(blob.Encoding),
		}
		// The TTL decision uses the kind stored in the row, not the kind in
		// the request.
		var result sql.Result
		if tlInfo.GetKind() == persistence.TaskListKindSticky && m.db.SupportsTTL() {
			result, txErr = tx.UpdateTaskListsWithTTL(ctx, &sqlplugin.TaskListsRowWithTTL{
				TaskListsRow: *updated,
				TTL:          stickyTasksListsTTL,
			})
		} else {
			result, txErr = tx.UpdateTaskLists(ctx, updated)
		}
		if txErr != nil {
			return txErr
		}
		rowsAffected, txErr := result.RowsAffected()
		if txErr != nil {
			return fmt.Errorf("rowsAffected error: %v", txErr)
		}
		if rowsAffected == 0 {
			return fmt.Errorf("%v rows affected instead of 1", rowsAffected)
		}
		// NOTE(review): Kind is echoed from the request rather than read from
		// the stored row; confirm callers rely on this before changing it.
		resp = &persistence.LeaseTaskListResponse{TaskListInfo: &persistence.TaskListInfo{
			DomainID:    request.DomainID,
			Name:        request.TaskList,
			TaskType:    request.TaskType,
			RangeID:     rangeID + 1,
			AckLevel:    ackLevel,
			Kind:        request.TaskListKind,
			LastUpdated: now,
		}}
		return nil
	})
	// Never hand back a populated response alongside an error (e.g. when the
	// commit fails after the closure already built resp).
	if err != nil {
		return nil, err
	}
	return resp, nil
}