internal/master/persistence/task_dao.go (406 lines of code) (raw):
/*
* Copyright (c) 2023 Alibaba Group Holding Ltd.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package persistence
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/alibaba/schedulerx-worker-go/internal/common"
"github.com/alibaba/schedulerx-worker-go/internal/proto/schedulerx"
"github.com/alibaba/schedulerx-worker-go/internal/utils"
"github.com/alibaba/schedulerx-worker-go/processor/taskstatus"
)
// TaskDao is the data-access object for the "task" table. Every method sends
// its SQL through the wrapped connection pool.
type TaskDao struct {
	// h2 is the connection pool that all statements are executed against.
	h2 *H2ConnectionPool
}
// NewTaskDao returns a TaskDao that runs its statements on h2CP.
func NewTaskDao(h2CP *H2ConnectionPool) *TaskDao {
	dao := new(TaskDao)
	dao.h2 = h2CP
	return dao
}
// CreateTable creates the task table and its two secondary indexes
// (idx_job_instance_id, idx_status) when they are not present yet.
//
// Every statement carries IF NOT EXISTS so the method can be invoked again on
// a database that already holds the schema. Previously only the table was
// guarded; the unconditional CREATE INDEX statements could make a repeated
// call fail with an "index already exists" error.
//
// NOTE(review): assumes the engine behind the pool accepts
// "CREATE INDEX IF NOT EXISTS" (SQLite does) — confirm for the driver in use.
//
// The whole script is sent in a single Exec call; the returned error is
// whatever that call reports.
func (d *TaskDao) CreateTable() error {
	ddl := "CREATE TABLE IF NOT EXISTS task (" +
		"job_id unsigned bigint(20) NOT NULL," +
		"job_instance_id unsigned bigint(20) NOT NULL," +
		"task_id unsigned bigint(20) NOT NULL," +
		"task_name varchar(100) NOT NULL DEFAULT ''," +
		"status int(11) NOT NULL," +
		"progress float NOT NULL DEFAULT '0'," +
		"gmt_create datetime NOT NULL," +
		"gmt_modified datetime NOT NULL," +
		"worker_addr varchar(30) NOT NULL DEFAULT ''," +
		"worker_id varchar(30) NOT NULL DEFAULT ''," +
		"task_body blob DEFAULT NULL," +
		"CONSTRAINT uk_instance_and_task UNIQUE (job_instance_id,task_id));" +
		"CREATE INDEX IF NOT EXISTS idx_job_instance_id ON task (job_instance_id);" +
		"CREATE INDEX IF NOT EXISTS idx_status ON task (status);"
	_, err := d.h2.Exec(ddl)
	return err
}
// DropTable removes the task table if it exists and returns the error, if
// any, reported by the statement.
func (d *TaskDao) DropTable() error {
	if _, err := d.h2.Exec("DROP TABLE IF EXISTS task"); err != nil {
		return err
	}
	return nil
}
// BatchDeleteTasks deletes, inside one transaction, the rows of jobInstanceId
// whose task_id is listed in taskIds, issuing one prepared DELETE per id.
//
// Deletion is best-effort: an id whose DELETE fails is skipped without
// aborting the loop and without being reported. The first return value is the
// sum of affected rows over the successful statements; the second is the
// error returned by WithTransaction (e.g. when the statement cannot be
// prepared).
func (d *TaskDao) BatchDeleteTasks(jobInstanceId int64, taskIds []int64) (int64, error) {
	const query = "delete from task where job_instance_id=? and task_id=?"

	ctx := context.Background()
	var deleted int64
	txErr := WithTransaction(ctx, d.h2.DB, func(tx *sql.Tx) error {
		stmt, err := tx.Prepare(query)
		if err != nil {
			return err
		}
		defer stmt.Close()

		for i := range taskIds {
			res, execErr := stmt.ExecContext(ctx, jobInstanceId, taskIds[i])
			if execErr == nil {
				// The RowsAffected error is intentionally ignored.
				n, _ := res.RowsAffected()
				deleted += n
			}
		}
		return nil
	})
	return deleted, txErr
}
// BatchDeleteTasks2 deletes, inside one transaction, every row of
// jobInstanceId that matches both workerId and workerAddr.
//
// It returns the number of rows the DELETE reports as affected together with
// the error returned by WithTransaction; prepare and exec failures are
// propagated through that error.
func (d *TaskDao) BatchDeleteTasks2(jobInstanceId int64, workerId string, workerAddr string) (int64, error) {
	const query = "delete from task where job_instance_id=? and worker_id=? and worker_addr=?"

	ctx := context.Background()
	var deleted int64
	txErr := WithTransaction(ctx, d.h2.DB, func(tx *sql.Tx) error {
		stmt, prepErr := tx.Prepare(query)
		if prepErr != nil {
			return prepErr
		}
		defer stmt.Close()

		res, execErr := stmt.ExecContext(ctx, jobInstanceId, workerId, workerAddr)
		if execErr != nil {
			return execErr
		}
		// The RowsAffected error is intentionally ignored.
		n, _ := res.RowsAffected()
		deleted += n
		return nil
	})
	return deleted, txErr
}
// BatchInsert stores one task row per element of containers inside a single
// transaction. Each row is written with status TaskStatusPulled, the given
// workerId/workerAddr, and the current Unix time in milliseconds as both
// gmt_create and gmt_modified.
//
// Insertion is best-effort: a container whose INSERT fails is skipped without
// aborting the loop and without being reported. The result is the sum of
// affected rows over the successful inserts, or (0, err) when WithTransaction
// itself returns an error.
func (d *TaskDao) BatchInsert(containers []*schedulerx.MasterStartContainerRequest, workerId string, workerAddr string) (int64, error) {
	const query = "insert into task(job_id,job_instance_id,task_id,task_name,status,gmt_create,gmt_modified,task_body,worker_id,worker_addr) " +
		"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)"

	ctx := context.Background()
	var inserted int64
	txErr := WithTransaction(ctx, d.h2.DB, func(tx *sql.Tx) error {
		stmt, prepErr := tx.Prepare(query)
		if prepErr != nil {
			return prepErr
		}
		defer stmt.Close()

		for _, c := range containers {
			// One timestamp per row, shared by gmt_create and gmt_modified.
			nowMs := time.Now().UnixMilli()
			res, execErr := stmt.ExecContext(ctx,
				c.GetJobId(),
				c.GetJobInstanceId(),
				c.GetTaskId(),
				c.GetTaskName(),
				int(taskstatus.TaskStatusPulled),
				nowMs,
				nowMs,
				c.GetTask(),
				workerId,
				workerAddr)
			if execErr == nil {
				// The RowsAffected error is intentionally ignored.
				n, _ := res.RowsAffected()
				inserted += n
			}
		}
		return nil
	})
	if txErr != nil {
		return 0, txErr
	}
	return inserted, nil
}
// BatchUpdateStatus sets status on every task of jobInstanceId whose task_id
// appears in taskIdList, using a single UPDATE inside one transaction.
//
// It returns the number of rows the statement reports as affected. An empty
// taskIdList is a no-op that returns (0, nil) without opening a transaction:
// it would otherwise render the predicate as "task_id in ()", which matches
// nothing at best and is rejected as a syntax error by engines that do not
// allow an empty IN list. When WithTransaction returns an error, the result
// is (0, err).
func (d *TaskDao) BatchUpdateStatus(jobInstanceId int64, taskIdList []int64, status int) (int64, error) {
	if len(taskIdList) == 0 {
		return 0, nil
	}

	var (
		totalAffectCnt int64
		ctx            = context.Background()
	)
	err := WithTransaction(ctx, d.h2.DB, func(tx *sql.Tx) error {
		// The ids are int64 values rendered by utils.Int64SliceToStringSlice
		// and spliced into the statement text rather than bound as parameters.
		query := fmt.Sprintf("update task set status=? where job_instance_id=? and task_id in (%s)",
			strings.Join(utils.Int64SliceToStringSlice(taskIdList), ","))
		stmt, err := tx.Prepare(query)
		if err != nil {
			return err
		}
		defer stmt.Close()

		ret, err := stmt.ExecContext(ctx, status, jobInstanceId)
		if err != nil {
			return err
		}
		// The RowsAffected error is intentionally ignored.
		affectCnt, _ := ret.RowsAffected()
		totalAffectCnt += affectCnt
		return nil
	})
	if err != nil {
		return 0, err
	}
	return totalAffectCnt, nil
}
// BatchUpdateStatus2 sets status and gmt_modified (current Unix time in
// milliseconds) on the tasks of jobInstanceId inside one transaction.
//
// The row filter is widened or narrowed by the arguments:
//   - a non-empty workerId restricts the update to rows matching both
//     workerId and workerAddr;
//   - when status equals TaskStatusPulled, only rows whose current status is
//     the literal 3 are touched.
//     NOTE(review): the meaning of status 3 is not visible here — confirm
//     against the taskstatus package.
//
// It returns the affected-row count together with the error returned by
// WithTransaction.
func (d *TaskDao) BatchUpdateStatus2(jobInstanceId int64, status int, workerId string, workerAddr string) (int64, error) {
	ctx := context.Background()
	var updated int64
	txErr := WithTransaction(ctx, d.h2.DB, func(tx *sql.Tx) error {
		byWorker := workerId != ""

		query := "update task set status=?,gmt_modified=? where job_instance_id=?"
		if byWorker {
			query += " and worker_id=? and worker_addr=?"
		}
		if status == int(taskstatus.TaskStatusPulled) {
			query += " and status = 3"
		}

		stmt, err := tx.Prepare(query)
		if err != nil {
			return err
		}
		defer stmt.Close()

		// Bind values in placeholder order; the worker pair is only present
		// when the worker predicate was added above.
		args := []interface{}{status, time.Now().UnixMilli(), jobInstanceId}
		if byWorker {
			args = append(args, workerId, workerAddr)
		}
		res, err := stmt.ExecContext(ctx, args...)
		if err != nil {
			return err
		}
		// The RowsAffected error is intentionally ignored.
		n, _ := res.RowsAffected()
		updated += n
		return nil
	})
	return updated, txErr
}
// DeleteByJobInstanceId removes every task row belonging to jobInstanceId and
// returns the number of rows deleted. A prepare or exec failure yields
// (0, err); otherwise the result is whatever RowsAffected reports.
func (d *TaskDao) DeleteByJobInstanceId(jobInstanceId int64) (int64, error) {
	stmt, err := d.h2.Prepare("delete from task where job_instance_id=?")
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	res, execErr := stmt.Exec(jobInstanceId)
	if execErr != nil {
		return 0, execErr
	}
	return res.RowsAffected()
}
// Exist reports whether at least one task row is stored for jobInstanceId.
// On a prepare failure it returns (false, err); a scan failure is returned
// alongside whatever value was scanned.
func (d *TaskDao) Exist(jobInstanceId int64) (bool, error) {
	stmt, err := d.h2.Prepare("select EXISTS (select * from task where job_instance_id=?)")
	if err != nil {
		return false, err
	}
	defer stmt.Close()

	found := false
	scanErr := stmt.QueryRow(jobInstanceId).Scan(&found)
	return found, scanErr
}
// GetDistinctInstanceIds returns each job_instance_id that occurs in the task
// table exactly once. The slice is nil when the table is empty. A query or
// scan failure yields (nil, err); an iteration error reported by rows.Err is
// returned together with the ids collected so far.
func (d *TaskDao) GetDistinctInstanceIds() ([]int64, error) {
	rows, err := d.h2.Query("select distinct job_instance_id from task")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var ids []int64
	for rows.Next() {
		var id int64
		if scanErr := rows.Scan(&id); scanErr != nil {
			return nil, scanErr
		}
		ids = append(ids, id)
	}
	return ids, rows.Err()
}
// GetTaskStatistics runs two aggregate queries over the task table — the
// number of distinct job instances and the total number of rows — and returns
// them in a TaskStatistics value. If either query fails, it returns (nil, err).
func (d *TaskDao) GetTaskStatistics() (*common.TaskStatistics, error) {
	stats := new(common.TaskStatistics)

	var instanceCnt int64
	if err := d.h2.QueryRow("select count(distinct job_instance_id) from task").Scan(&instanceCnt); err != nil {
		return nil, err
	}
	stats.SetDistinctInstanceCount(instanceCnt)

	var taskCnt int64
	if err := d.h2.QueryRow("select count(*) from task").Scan(&taskCnt); err != nil {
		return nil, err
	}
	stats.SetTaskCount(taskCnt)

	return stats, nil
}
// Insert stores a single task row with status TaskStatusPulled. gmt_create
// and gmt_modified are both set to the current Unix time in milliseconds;
// worker_id and worker_addr are not part of the statement and therefore take
// their column defaults. The prepare or exec error, if any, is returned.
func (d *TaskDao) Insert(jobId int64, jobInstanceId int64, taskId int64, taskName string, taskBody []byte) error {
	stmt, err := d.h2.Prepare("insert into task(job_id,job_instance_id,task_id,task_name,status,gmt_create,gmt_modified,task_body) VALUES (?, ?, ?, ?, ?, ?, ?, ?)")
	if err != nil {
		return err
	}
	defer stmt.Close()

	nowMs := time.Now().UnixMilli()
	if _, execErr := stmt.Exec(jobId, jobInstanceId, taskId, taskName, int(taskstatus.TaskStatusPulled), nowMs, nowMs, taskBody); execErr != nil {
		return execErr
	}
	return nil
}
// QueryStatus returns the distinct status values found among the tasks of
// jobInstanceId. The slice is nil when no row matches. A prepare, query or
// scan failure yields (nil, err); an iteration error reported by rows.Err is
// returned together with the values collected so far.
func (d *TaskDao) QueryStatus(jobInstanceId int64) ([]int32, error) {
	stmt, err := d.h2.Prepare("select distinct(status) from task where job_instance_id=?")
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	rows, queryErr := stmt.Query(jobInstanceId)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var statuses []int32
	for rows.Next() {
		var s int32
		if scanErr := rows.Scan(&s); scanErr != nil {
			return nil, scanErr
		}
		statuses = append(statuses, s)
	}
	return statuses, rows.Err()
}
// QueryTaskCount returns how many task rows are stored for jobInstanceId.
// A prepare failure yields (0, err); a scan failure is returned alongside
// whatever value was scanned.
func (d *TaskDao) QueryTaskCount(jobInstanceId int64) (int64, error) {
	stmt, err := d.h2.Prepare("select count(*) from task where job_instance_id=?")
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	var count int64
	scanErr := stmt.QueryRow(jobInstanceId).Scan(&count)
	return count, scanErr
}
// QueryTaskList returns up to pageSize tasks of jobInstanceId that are in the
// given status.
//
// The statement selects "*", so the scan targets below follow the column
// order of the task table as declared in CreateTable. The slice is nil when
// no row matches. A prepare, query or scan failure yields (nil, err); an
// iteration error reported by rows.Err is returned together with the rows
// collected so far.
func (d *TaskDao) QueryTaskList(jobInstanceId int64, status int, pageSize int32) ([]*TaskSnapshot, error) {
	stmt, err := d.h2.Prepare("select * from task where job_instance_id=? and status=? limit ?")
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	rows, queryErr := stmt.Query(jobInstanceId, status, pageSize)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var tasks []*TaskSnapshot
	for rows.Next() {
		row := &TaskSnapshot{}
		scanErr := rows.Scan(
			&row.JobId,
			&row.JobInstanceId,
			&row.TaskId,
			&row.TaskName,
			&row.Status,
			&row.Progress,
			&row.GmtCreate,
			&row.GmtModified,
			&row.WorkerAddr,
			&row.WorkerId,
			&row.TaskBody,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		tasks = append(tasks, row)
	}
	return tasks, rows.Err()
}
// QueryTasks returns up to pageSize tasks of jobInstanceId regardless of
// their status.
//
// The statement selects "*", so the scan targets below follow the column
// order of the task table as declared in CreateTable. The slice is nil when
// no row matches. A prepare, query or scan failure yields (nil, err); an
// iteration error reported by rows.Err is returned together with the rows
// collected so far.
func (d *TaskDao) QueryTasks(jobInstanceId int64, pageSize int32) ([]*TaskSnapshot, error) {
	stmt, err := d.h2.Prepare("select * from task where job_instance_id=? limit ?")
	if err != nil {
		return nil, err
	}
	defer stmt.Close()

	rows, queryErr := stmt.Query(jobInstanceId, pageSize)
	if queryErr != nil {
		return nil, queryErr
	}
	defer rows.Close()

	var tasks []*TaskSnapshot
	for rows.Next() {
		row := &TaskSnapshot{}
		scanErr := rows.Scan(
			&row.JobId,
			&row.JobInstanceId,
			&row.TaskId,
			&row.TaskName,
			&row.Status,
			&row.Progress,
			&row.GmtCreate,
			&row.GmtModified,
			&row.WorkerAddr,
			&row.WorkerId,
			&row.TaskBody,
		)
		if scanErr != nil {
			return nil, scanErr
		}
		tasks = append(tasks, row)
	}
	return tasks, rows.Err()
}
// UpdateStatus sets status, worker_addr and gmt_modified (current Unix time
// in milliseconds) on the task identified by jobInstanceId and taskId.
//
// It returns the number of rows affected. A prepare or exec failure yields
// (0, err). The exec error is checked before the result is used: when Exec
// fails the returned sql.Result is nil, and calling RowsAffected on it would
// panic instead of surfacing the error. The RowsAffected error is propagated
// as well, matching UpdateWorker.
func (d *TaskDao) UpdateStatus(jobInstanceId int64, taskId int64, status int, workerAddr string) (int64, error) {
	query := "update task set status=?,worker_addr=?,gmt_modified=? where job_instance_id=? and task_id=?"
	stmt, err := d.h2.Prepare(query)
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	ret, err := stmt.Exec(status, workerAddr, time.Now().UnixMilli(), jobInstanceId, taskId)
	if err != nil {
		return 0, err
	}
	return ret.RowsAffected()
}
// UpdateStatus2 sets status, worker_id and worker_addr on each task of
// jobInstanceId listed in taskIds, executing one prepared UPDATE per id
// (no surrounding transaction is opened here).
//
// When status equals TaskStatusPulled, only rows whose current status is the
// literal 3 are touched.
// NOTE(review): the meaning of status 3 is not visible here — confirm against
// the taskstatus package.
//
// The update is best-effort: an id whose UPDATE fails is skipped and the
// failure is not reported, so once the statement has been prepared the
// returned error is always nil. The count is the sum of affected rows over
// the successful statements. A prepare failure yields (0, err).
func (d *TaskDao) UpdateStatus2(jobInstanceId int64, taskIds []int64, status int, workerId string, workerAddr string) (int64, error) {
	query := "update task set status=?, worker_id=?, worker_addr=? WHERE job_instance_id=? and task_id =?"
	if status == int(taskstatus.TaskStatusPulled) {
		query += " and status = 3"
	}

	stmt, err := d.h2.Prepare(query)
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	var updated int64
	for i := range taskIds {
		res, execErr := stmt.Exec(status, workerId, workerAddr, jobInstanceId, taskIds[i])
		if execErr == nil {
			// The RowsAffected error is intentionally ignored.
			n, _ := res.RowsAffected()
			updated += n
		}
	}
	// Per-id exec errors are deliberately not surfaced.
	return updated, nil
}
// UpdateWorker reassigns the task identified by jobInstanceId and taskId to
// the given workerId/workerAddr and sets gmt_modified to the current Unix
// time in milliseconds. It returns the number of rows affected; a prepare or
// exec failure yields (0, err).
func (d *TaskDao) UpdateWorker(jobInstanceId int64, taskId int64, workerId string, workerAddr string) (int64, error) {
	stmt, err := d.h2.Prepare("update task set worker_id=?,worker_addr=?,gmt_modified=? where job_instance_id=? and task_id=?")
	if err != nil {
		return 0, err
	}
	defer stmt.Close()

	res, execErr := stmt.Exec(workerId, workerAddr, time.Now().UnixMilli(), jobInstanceId, taskId)
	if execErr != nil {
		return 0, execErr
	}
	return res.RowsAffected()
}