func()

in common/persistence/sql/sql_task_store.go [81:204]


// LeaseTaskList acquires (or renews) ownership of a task list by incrementing
// its stored range ID.
//
// If no row exists for the (domain, task list, task type) key, a new row is
// inserted first with range ID 0 and ack level 0. When the database supports
// TTLs and the request is for a sticky task list, the row is inserted with
// stickyTasksListsTTL.
//
// When request.RangeID is non-zero it must equal the stored range ID, otherwise
// a *persistence.ConditionFailedError is returned. The range ID is then bumped
// by one inside a transaction (after taking a lock via lockTaskList), and the
// new range ID is returned in the response. On any error the response is nil.
func (m *sqlTaskStore) LeaseTaskList(
	ctx context.Context,
	request *persistence.LeaseTaskListRequest,
) (*persistence.LeaseTaskListResponse, error) {
	dbShardID := sqlplugin.GetDBShardIDFromDomainIDAndTasklist(request.DomainID, request.TaskList, m.db.GetTotalNumDBShards())

	domainID := serialization.MustParseUUID(request.DomainID)
	rows, err := m.db.SelectFromTaskLists(ctx, &sqlplugin.TaskListsFilter{
		ShardID:  dbShardID,
		DomainID: &domainID,
		Name:     &request.TaskList,
		TaskType: common.Int64Ptr(int64(request.TaskType))})

	// Treat an empty result with no error the same as sql.ErrNoRows so that a
	// plugin reporting "not found" that way cannot make rows[0] panic below.
	if err == sql.ErrNoRows || (err == nil && len(rows) == 0) {
		// The task list does not exist yet: create it with range ID 0 so the
		// transaction below leases it at range ID 1.
		newInfo := &serialization.TaskListInfo{
			AckLevel:        0,
			Kind:            int16(request.TaskListKind),
			ExpiryTimestamp: time.Unix(0, 0),
			LastUpdated:     time.Now(),
		}
		blob, err := m.parser.TaskListInfoToBlob(newInfo)
		if err != nil {
			return nil, err
		}
		newRow := sqlplugin.TaskListsRow{
			ShardID:      dbShardID,
			DomainID:     domainID,
			Name:         request.TaskList,
			TaskType:     int64(request.TaskType),
			Data:         blob.Data,
			DataEncoding: string(blob.Encoding),
		}
		if m.db.SupportsTTL() && request.TaskListKind == persistence.TaskListKindSticky {
			rowWithTTL := sqlplugin.TaskListsRowWithTTL{
				TaskListsRow: newRow,
				TTL:          stickyTasksListsTTL,
			}
			if _, err := m.db.InsertIntoTaskListsWithTTL(ctx, &rowWithTTL); err != nil {
				return nil, convertCommonErrors(m.db, "LeaseTaskListWithTTL", fmt.Sprintf("Failed to make task list %v of type %v.", request.TaskList, request.TaskType), err)
			}
		} else {
			if _, err := m.db.InsertIntoTaskLists(ctx, &newRow); err != nil {
				return nil, convertCommonErrors(m.db, "LeaseTaskList", fmt.Sprintf("Failed to make task list %v of type %v.", request.TaskList, request.TaskType), err)
			}
		}
		rows = []sqlplugin.TaskListsRow{newRow}
	} else if err != nil {
		return nil, convertCommonErrors(m.db, "LeaseTaskList", "Failed to check if task list existed.", err)
	}

	current := rows[0]
	if request.RangeID > 0 && request.RangeID != current.RangeID {
		// haveRangeID is the range ID the caller holds; gotRangeID is the one
		// currently stored in the database.
		return nil, &persistence.ConditionFailedError{
			Msg: fmt.Sprintf("leaseTaskList:renew failed:taskList:%v, taskListType:%v, haveRangeID:%v, gotRangeID:%v",
				request.TaskList, request.TaskType, request.RangeID, current.RangeID),
		}
	}

	tlInfo, err := m.parser.TaskListInfoFromBlob(current.Data, current.DataEncoding)
	if err != nil {
		return nil, err
	}

	var resp *persistence.LeaseTaskListResponse
	err = m.txExecute(ctx, dbShardID, "LeaseTaskList", func(tx sqlplugin.Tx) error {
		rangeID := current.RangeID
		ackLevel := tlInfo.GetAckLevel()
		// We need to separately check the condition and do the
		// update because we want to throw different error codes.
		// Since we need to do things separately (in a transaction), we need to take a lock.
		if err1 := lockTaskList(ctx, tx, dbShardID, domainID, request.TaskList, request.TaskType, rangeID); err1 != nil {
			return err1
		}
		now := time.Now()
		tlInfo.LastUpdated = now
		blob, err1 := m.parser.TaskListInfoToBlob(tlInfo)
		if err1 != nil {
			return err1
		}
		updated := &sqlplugin.TaskListsRow{
			ShardID:      dbShardID,
			DomainID:     current.DomainID,
			RangeID:      rangeID + 1,
			Name:         current.Name,
			TaskType:     current.TaskType,
			Data:         blob.Data,
			DataEncoding: string(blob.Encoding),
		}
		var result sql.Result
		if tlInfo.GetKind() == persistence.TaskListKindSticky && m.db.SupportsTTL() {
			result, err1 = tx.UpdateTaskListsWithTTL(ctx, &sqlplugin.TaskListsRowWithTTL{
				TaskListsRow: *updated,
				TTL:          stickyTasksListsTTL,
			})
		} else {
			result, err1 = tx.UpdateTaskLists(ctx, updated)
		}
		if err1 != nil {
			return err1
		}
		rowsAffected, err1 := result.RowsAffected()
		if err1 != nil {
			return fmt.Errorf("rowsAffected error: %v", err1)
		}
		if rowsAffected == 0 {
			return fmt.Errorf("%v rows affected instead of 1", rowsAffected)
		}
		resp = &persistence.LeaseTaskListResponse{TaskListInfo: &persistence.TaskListInfo{
			DomainID:    request.DomainID,
			Name:        request.TaskList,
			TaskType:    request.TaskType,
			RangeID:     rangeID + 1,
			AckLevel:    ackLevel,
			Kind:        request.TaskListKind,
			LastUpdated: now,
		}}
		return nil
	})
	if err != nil {
		// resp may already be set if the closure succeeded but the commit
		// failed; never return a partially successful result with an error.
		return nil, err
	}
	return resp, nil
}