internal/master/persistence/utils.go (59 lines of code) (raw):
/*
* Copyright (c) 2023 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package persistence
import (
"context"
"database/sql"
"fmt"
"github.com/alibaba/schedulerx-worker-go/internal/proto/schedulerx"
"github.com/alibaba/schedulerx-worker-go/logger"
)
// SQLTxFunc is the unit of work executed by WithTransaction. It receives the
// open *sql.Tx and should run its statements and queries through it; returning
// a non-nil error makes WithTransaction roll the transaction back instead of
// committing it.
type SQLTxFunc func(*sql.Tx) error
// WithTransaction begins a transaction on db (using ctx and default options),
// runs fn with it, and then finishes the transaction according to how fn ended:
//   - fn panicked: the transaction is rolled back and the panic is re-raised;
//   - fn returned a non-nil error: the transaction is rolled back and that same
//     error is returned;
//   - fn returned nil: the transaction is committed and the Commit error, if
//     any, is returned.
//
// If the transaction cannot be started, the BeginTx error is returned and fn is
// never called. Rollback failures are only logged; they never replace the
// returned error.
func WithTransaction(ctx context.Context, db *sql.DB, fn SQLTxFunc) (err error) {
	tx, beginErr := db.BeginTx(ctx, nil)
	if beginErr != nil {
		logger.Errorf("db begin error.%v", beginErr)
		return beginErr
	}

	// err must stay a named result: the deferred closure reads fn's outcome
	// from it and overwrites it with the Commit result on the success path.
	defer func() {
		// recover only works when called directly by the deferred function.
		p := recover()
		switch {
		case p != nil:
			// fn panicked: undo its work, then let the panic keep unwinding.
			if rbErr := tx.Rollback(); rbErr != nil {
				logger.Errorf("rollback error.%v", rbErr)
			}
			panic(p)
		case err != nil:
			// fn failed: undo its work and hand its error back unchanged.
			logger.Errorf("fn error.%v", err)
			if rbErr := tx.Rollback(); rbErr != nil {
				logger.Errorf("rollback error.%v", rbErr)
			}
		default:
			// fn succeeded: the outcome of Commit becomes the result.
			if err = tx.Commit(); err != nil {
				logger.Errorf("commit error.%v", err)
			}
		}
	}()

	return fn(tx)
}
// getTaskStatusMap groups the reported task statuses first by status and then
// by worker, where the worker key is formatted as "<workerId>@<workerAddr>":
//
//	{status -> {workerIdAddr -> []taskId}}
//
// Task ids inside each worker list keep the order in which they appear in
// taskStatusInfos.
func getTaskStatusMap(taskStatusInfos []*schedulerx.ContainerReportTaskStatusRequest) map[int32]map[string][]int64 {
	byStatus := make(map[int32]map[string][]int64)
	for i := range taskStatusInfos {
		info := taskStatusInfos[i]
		st, id := info.GetStatus(), info.GetTaskId()
		worker := fmt.Sprintf("%s@%s", info.GetWorkerId(), info.GetWorkerAddr())
		addTaskStatusInfo(byStatus, st, worker, id)
	}
	return byStatus
}
// addTaskStatusInfo appends taskId to
// status2WorkIdAddr2TaskIds[status][workerIdAddr], creating the per-status
// worker map when it is missing (or nil). Task ids are kept in call order.
//
// status2WorkIdAddr2TaskIds itself must be non-nil: a new entry is written into
// it, and writing to a nil map panics.
func addTaskStatusInfo(status2WorkIdAddr2TaskIds map[int32]map[string][]int64, status int32, workerIdAddr string, taskId int64) {
	workerIdAddr2TaskIds := status2WorkIdAddr2TaskIds[status]
	if workerIdAddr2TaskIds == nil {
		// First task for this status (or a nil placeholder): allocate its
		// worker map so the write below cannot panic.
		workerIdAddr2TaskIds = make(map[string][]int64)
		status2WorkIdAddr2TaskIds[status] = workerIdAddr2TaskIds
	}
	// Appending to a missing key's nil slice yields a fresh one-element slice,
	// so no separate "first task for this worker" branch is needed.
	workerIdAddr2TaskIds[workerIdAddr] = append(workerIdAddr2TaskIds[workerIdAddr], taskId)
}