pkg/datasource/sql/undo/base/undo.go (430 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package base import ( "context" "database/sql" "database/sql/driver" "encoding/json" "fmt" "strconv" "strings" "github.com/arana-db/parser/mysql" "seata.apache.org/seata-go/pkg/compressor" "seata.apache.org/seata-go/pkg/datasource/sql/datasource" "seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/datasource/sql/undo" "seata.apache.org/seata-go/pkg/datasource/sql/undo/factor" "seata.apache.org/seata-go/pkg/datasource/sql/undo/parser" "seata.apache.org/seata-go/pkg/util/collection" "seata.apache.org/seata-go/pkg/util/log" ) const ( compressorTypeKey = "compressorTypeKey" serializerKey = "serializerKey" defaultUndoLogTableName = " undo_log " ) func getUndoLogTableName() string { if undo.UndoConfig.LogTable != "" { return undo.UndoConfig.LogTable } return defaultUndoLogTableName } func getCheckUndoLogTableExistSql() string { return "SELECT 1 FROM " + getUndoLogTableName() + " LIMIT 1" } func getInsertUndoLogSql() string { return "INSERT INTO " + getUndoLogTableName() + "(branch_id,xid,context,rollback_info,log_status,log_created,log_modified) VALUES (?, ?, ?, ?, ?, now(6), now(6))" } func getSelectUndoLogSql() string { return "SELECT `branch_id`,`xid`,`context`,`rollback_info`,`log_status` FROM " + getUndoLogTableName() + " WHERE branch_id = ? AND xid = ? FOR UPDATE" } func getDeleteUndoLogSql() string { return "DELETE FROM " + getUndoLogTableName() + " WHERE branch_id = ? AND xid = ?" } // undo log status const ( // UndoLogStatusNormal This state can be properly rolled back by services UndoLogStatusNormal = iota // UndoLogStatusGlobalFinished This state prevents the branch transaction from inserting undo_log after the global transaction is rolled back. UndoLogStatusGlobalFinished ) // BaseUndoLogManager type BaseUndoLogManager struct{} func NewBaseUndoLogManager() *BaseUndoLogManager { return &BaseUndoLogManager{} } // Init func (m *BaseUndoLogManager) Init() { } // InsertUndoLog func (m *BaseUndoLogManager) InsertUndoLog(record undo.UndologRecord, conn driver.Conn) error { log.Infof("begin to insert undo log, xid %v, branch id %v", record.XID, record.BranchID) stmt, err := conn.Prepare(getInsertUndoLogSql()) if err != nil { return err } _, err = stmt.Exec([]driver.Value{record.BranchID, record.XID, record.Context, record.RollbackInfo, int64(record.LogStatus)}) if err != nil { return err } return nil } func (m *BaseUndoLogManager) InsertUndoLogWithSqlConn(ctx context.Context, record undo.UndologRecord, conn *sql.Conn) error { stmt, err := conn.PrepareContext(ctx, getInsertUndoLogSql()) if err != nil { return err } defer stmt.Close() _, err = stmt.Exec(record.BranchID, record.XID, record.Context, record.RollbackInfo, int64(record.LogStatus)) if err != nil { return err } return nil } // DeleteUndoLog exec delete single undo log operate func (m *BaseUndoLogManager) DeleteUndoLog(ctx context.Context, xid string, branchID int64, conn *sql.Conn) error { stmt, err := conn.PrepareContext(ctx, getDeleteUndoLogSql()) if err != nil { log.Errorf("[DeleteUndoLog] prepare sql fail, err: %v", err) return err } defer stmt.Close() if _, err = stmt.Exec(branchID, xid); err != nil { log.Errorf("[DeleteUndoLog] exec delete undo log fail, err: %v", err) return err } return nil } // BatchDeleteUndoLog exec delete undo log operate func (m *BaseUndoLogManager) BatchDeleteUndoLog(xid []string, branchID []int64, conn *sql.Conn) error { // build delete undo log sql batchDeleteSql, err := m.getBatchDeleteUndoLogSql(xid, branchID) if err != nil { log.Errorf("get undo sql log fail, err: %v", err) return err } ctx := context.Background() // prepare deal sql stmt, err := conn.PrepareContext(ctx, batchDeleteSql) if err != nil { log.Errorf("prepare sql fail, err: %v", err) return err } defer stmt.Close() branchIDStr, err := Int64Slice2Str(branchID, ",") if err != nil { log.Errorf("slice to string transfer fail, err: %v", err) return err } // exec sql stmt if _, err = stmt.ExecContext(ctx, branchIDStr, strings.Join(xid, ",")); err != nil { log.Errorf("exec delete undo log fail, err: %v", err) return err } return nil } // FlushUndoLog flush undo log func (m *BaseUndoLogManager) FlushUndoLog(tranCtx *types.TransactionContext, conn driver.Conn) error { if tranCtx.RoundImages.IsEmpty() { return nil } sqlUndoLogs := make([]undo.SQLUndoLog, 0) beforeImages := tranCtx.RoundImages.BeofreImages() afterImages := tranCtx.RoundImages.AfterImages() if beforeImages.IsEmptyImage() && afterImages.IsEmptyImage() { return nil } size := len(beforeImages) if size < len(afterImages) { size = len(afterImages) } for i := 0; i < size; i++ { var ( tableName string sqlType types.SQLType beforeImage *types.RecordImage afterImage *types.RecordImage ) if i < len(beforeImages) && beforeImages[i] != nil { tableName = beforeImages[i].TableName sqlType = beforeImages[i].SQLType } else if i < len(afterImages) && afterImages[i] != nil { tableName = afterImages[i].TableName sqlType = afterImages[i].SQLType } else { continue } if i < len(beforeImages) { beforeImage = beforeImages[i] } if i < len(afterImages) { afterImage = afterImages[i] } undoLog := undo.SQLUndoLog{ SQLType: sqlType, TableName: tableName, BeforeImage: beforeImage, AfterImage: afterImage, } sqlUndoLogs = append(sqlUndoLogs, undoLog) } branchUndoLog := undo.BranchUndoLog{ Xid: tranCtx.XID, BranchID: tranCtx.BranchID, Logs: sqlUndoLogs, } parseContext := make(map[string]string, 0) parseContext[serializerKey] = undo.UndoConfig.LogSerialization parseContext[compressorTypeKey] = undo.UndoConfig.CompressConfig.Type undoLogContent := m.encodeUndoLogCtx(parseContext) rollbackInfo, err := m.serializeBranchUndoLog(&branchUndoLog, parseContext[serializerKey]) if err != nil { return err } return m.InsertUndoLog(undo.UndologRecord{ BranchID: tranCtx.BranchID, XID: tranCtx.XID, Context: undoLogContent, RollbackInfo: rollbackInfo, LogStatus: undo.UndoLogStatueNormnal, }, conn) } // RunUndo undo sql func (m *BaseUndoLogManager) RunUndo(ctx context.Context, xid string, branchID int64, conn *sql.DB, dbName string) error { return nil } // Undo undo sql func (m *BaseUndoLogManager) Undo(ctx context.Context, dbType types.DBType, xid string, branchID int64, db *sql.DB, dbName string) (err error) { conn, err := db.Conn(ctx) if err != nil { return err } tx, err := conn.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return err } defer func() { if err != nil { if err = tx.Rollback(); err != nil { log.Errorf("rollback fail, xid: %s, branchID:%s err:%v", xid, branchID, err) return } } }() stmt, err := conn.PrepareContext(ctx, getSelectUndoLogSql()) if err != nil { log.Errorf("prepare sql fail, err: %v", err) return err } defer func() { if err = stmt.Close(); err != nil { log.Errorf("stmt close fail, xid: %s, branchID:%s err:%v", xid, branchID, err) return } }() rows, err := stmt.Query(branchID, xid) if err != nil { log.Errorf("query sql fail, err: %v", err) return err } defer func() { if err = rows.Close(); err != nil { log.Errorf("rows close fail, xid: %s, branchID:%s err:%v", xid, branchID, err) return } }() var undoLogRecords []undo.UndologRecord for rows.Next() { var record undo.UndologRecord err = rows.Scan(&record.BranchID, &record.XID, &record.Context, &record.RollbackInfo, &record.LogStatus) if err != nil { return err } undoLogRecords = append(undoLogRecords, record) } if err := rows.Err(); err != nil { log.Errorf("read rows next fail, xid: %s, branchID:%s err:%v", xid, branchID, err) return err } var exists bool for _, record := range undoLogRecords { exists = true if !record.CanUndo() { log.Infof("xid %v branch %v, ignore %v undo_log", record.XID, record.BranchID, record.LogStatus) return nil } var logCtx map[string]string if record.Context != nil && string(record.Context) != "" { logCtx = m.decodeUndoLogCtx(record.Context) } if logCtx == nil { return fmt.Errorf("undo log context not exist in record %+v", record) } rollbackInfo, err := m.getRollbackInfo(record.RollbackInfo, logCtx) if err != nil { return err } var branchUndoLog *undo.BranchUndoLog if branchUndoLog, err = m.deserializeBranchUndoLog(rollbackInfo, logCtx); err != nil { return err } sqlUndoLogs := branchUndoLog.Logs if len(sqlUndoLogs) == 0 { return nil } branchUndoLog.Reverse() for _, undoLog := range sqlUndoLogs { tableMeta, err := datasource.GetTableCache(types.DBTypeMySQL).GetTableMeta(ctx, dbName, undoLog.TableName) if err != nil { log.Errorf("get table meta fail, err: %v", err) return err } undoLog.SetTableMeta(tableMeta) undoExecutor, err := factor.GetUndoExecutor(dbType, undoLog) if err != nil { log.Errorf("get undo executor, err: %v", err) return err } if err = undoExecutor.ExecuteOn(ctx, dbType, conn); err != nil { log.Errorf("execute on fail, err: %v", err) return err } } } if exists { if err = m.DeleteUndoLog(ctx, xid, branchID, conn); err != nil { log.Errorf("[Undo] delete undo fail, err: %v", err) return err } log.Infof("xid %v branch %v, undo_log deleted with %v", xid, branchID, undo.UndoLogStatueGlobalFinished) } else { if err = m.insertUndoLogWithGlobalFinished(ctx, xid, uint64(branchID), conn); err != nil { log.Errorf("[Undo] insert undo with global finished fail, err: %v", err) return err } log.Infof("xid %v branch %v, undo_log added with %v", xid, branchID, undo.UndoLogStatueGlobalFinished) } if err = tx.Commit(); err != nil { log.Errorf("[Undo] execute on fail, err: %v", err) return nil } return nil } func (m *BaseUndoLogManager) insertUndoLogWithGlobalFinished(ctx context.Context, xid string, branchID uint64, conn *sql.Conn) error { // todo use config to replace parseContext := make(map[string]string, 0) parseContext[serializerKey] = undo.UndoConfig.LogSerialization parseContext[compressorTypeKey] = undo.UndoConfig.CompressConfig.Type undoLogContent := m.encodeUndoLogCtx(parseContext) logParse, err := parser.GetCache().Load(parseContext[serializerKey]) if err != nil { return err } rbInfo := logParse.GetDefaultContent() record := undo.UndologRecord{ BranchID: branchID, XID: xid, RollbackInfo: rbInfo, LogStatus: UndoLogStatusGlobalFinished, Context: undoLogContent, } err = m.InsertUndoLogWithSqlConn(ctx, record, conn) if err != nil { log.Errorf("insert undo log fail, err: %v", err) return err } return nil } // DBType func (m *BaseUndoLogManager) DBType() types.DBType { panic("implement me") } // HasUndoLogTable check undo log table if exist func (m *BaseUndoLogManager) HasUndoLogTable(ctx context.Context, conn *sql.Conn) (res bool, err error) { if _, err = conn.QueryContext(ctx, getCheckUndoLogTableExistSql()); err != nil { //nolint:rowserrcheck,sqlclosecheck // 1146 mysql table not exist fault code if e, ok := err.(*mysql.SQLError); ok && e.Code == mysql.ErrNoSuchTable { return false, nil } log.Errorf("[HasUndoLogTable] query sql fail, err: %v", err) return } return true, nil } // getBatchDeleteUndoLogSql build batch delete undo log func (m *BaseUndoLogManager) getBatchDeleteUndoLogSql(xid []string, branchID []int64) (string, error) { if len(xid) == 0 || len(branchID) == 0 { return "", fmt.Errorf("xid or branch_id can't nil") } var undoLogDeleteSql strings.Builder undoLogDeleteSql.WriteString(" DELETE FROM ") undoLogDeleteSql.WriteString(getUndoLogTableName()) undoLogDeleteSql.WriteString(" WHERE branch_id IN ") m.appendInParam(len(branchID), &undoLogDeleteSql) undoLogDeleteSql.WriteString(" AND xid IN ") m.appendInParam(len(xid), &undoLogDeleteSql) return undoLogDeleteSql.String(), nil } // appendInParam build in param func (m *BaseUndoLogManager) appendInParam(size int, str *strings.Builder) { if size <= 0 { return } str.WriteString(" (") for i := 0; i < size; i++ { str.WriteString("?") if i < size-1 { str.WriteString(",") } } str.WriteString(") ") } // Int64Slice2Str func Int64Slice2Str(values interface{}, sep string) (string, error) { v, ok := values.([]int64) if !ok { return "", fmt.Errorf("param type is fault") } var valuesText []string for i := range v { text := strconv.FormatInt(v[i], 10) valuesText = append(valuesText, text) } return strings.Join(valuesText, sep), nil } // canUndo check if it can undo func (m *BaseUndoLogManager) canUndo(state int32) bool { return state == UndoLogStatusNormal } func (m *BaseUndoLogManager) UnmarshalContext(undoContext []byte) (map[string]string, error) { res := make(map[string]string) if err := json.Unmarshal(undoContext, &res); err != nil { return nil, err } return res, nil } // getRollbackInfo parser rollback info func (m *BaseUndoLogManager) getRollbackInfo(rollbackInfo []byte, undoContext map[string]string) ([]byte, error) { var err error res := rollbackInfo // get compress type if v, ok := undoContext[compressorTypeKey]; ok { res, err = compressor.CompressorType(v).GetCompressor().Decompress(rollbackInfo) if err != nil { log.Errorf("[getRollbackInfo] decompress fail, err: %+v", err) return nil, err } } return res, nil } // getSerializer get serializer from undo context func (m *BaseUndoLogManager) getSerializer(undoLogContext map[string]string) (serializer string) { if undoLogContext == nil { return } serializer, _ = undoLogContext[serializerKey] return } func (m *BaseUndoLogManager) deserializeBranchUndoLog(rbInfo []byte, logCtx map[string]string) (*undo.BranchUndoLog, error) { var ( err error logParser parser.UndoLogParser ) if serialzerType := m.getSerializer(logCtx); serialzerType != "" { if logParser, err = parser.GetCache().Load(serialzerType); err != nil { return nil, err } } var branchUndoLog *undo.BranchUndoLog if branchUndoLog, err = logParser.Decode(rbInfo); err != nil { return nil, err } return branchUndoLog, nil } func (m *BaseUndoLogManager) serializeBranchUndoLog(log *undo.BranchUndoLog, serializerType string) ([]byte, error) { logParser, err := parser.GetCache().Load(serializerType) if err != nil { return nil, err } return logParser.Encode(log) } func (m *BaseUndoLogManager) encodeUndoLogCtx(undoLogCtx map[string]string) []byte { return collection.EncodeMap(undoLogCtx) } func (m *BaseUndoLogManager) decodeUndoLogCtx(undoLogCtx []byte) map[string]string { return collection.DecodeMap(undoLogCtx) }