pkg/datasource/sql/exec/at/multi_executor.go (124 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package at import ( "context" "fmt" "seata.apache.org/seata-go/pkg/datasource/sql/exec" "seata.apache.org/seata-go/pkg/datasource/sql/types" "seata.apache.org/seata-go/pkg/util/log" ) type multiExecutor struct { baseExecutor parserCtx *types.ParseContext execContext *types.ExecContext } // NewMultiExecutor get new multi executor func NewMultiExecutor(parserCtx *types.ParseContext, execContext *types.ExecContext, hooks []exec.SQLHook) executor { return &multiExecutor{parserCtx: parserCtx, execContext: execContext, baseExecutor: baseExecutor{hooks: hooks}} } // ExecContext exec SQL, and generate before image and after image func (m *multiExecutor) ExecContext(ctx context.Context, f exec.CallbackWithNamedValue) (types.ExecResult, error) { m.beforeHooks(ctx, m.execContext) defer func() { m.afterHooks(ctx, m.execContext) }() beforeImages, err := m.beforeImage(ctx, m.parserCtx) if err != nil { return nil, err } res, err := f(ctx, m.execContext.Query, m.execContext.NamedValues) if err != nil { return nil, err } afterImages, err := m.afterImage(ctx, m.parserCtx, beforeImages) if err != nil { return nil, err } for _, beforeImage := range beforeImages { m.execContext.TxCtx.RoundImages.AppendBeofreImage(beforeImage) } for _, afterImage := range afterImages { m.execContext.TxCtx.RoundImages.AppendAfterImage(afterImage) } return res, nil } func (m *multiExecutor) beforeImage(ctx context.Context, parseContext *types.ParseContext) ([]*types.RecordImage, error) { if len(parseContext.MultiStmt) == 0 { return nil, nil } tableParsers, err := m.groupParsersByTableName(parseContext) if err != nil { log.Infof("group parsers by table name failed, %s", err) return nil, err } var beforeImages = make([]*types.RecordImage, 0) for _, multiParser := range tableParsers { var images []*types.RecordImage switch multiParser.ExecutorType { case types.UpdateExecutor: multiUpdateExec := NewMultiUpdateExecutor(multiParser, m.execContext, m.hooks) images, err = multiUpdateExec.beforeImage(ctx) case types.DeleteExecutor: multiDeleteExec := NewMultiDeleteExecutor(multiParser, m.execContext, m.hooks) images, err = multiDeleteExec.beforeImage(ctx) default: return nil, fmt.Errorf("not support multi sql %s", m.execContext.Query) } if err != nil { return nil, err } beforeImages = append(beforeImages, images...) } return beforeImages, err } func (m *multiExecutor) afterImage(ctx context.Context, parseContext *types.ParseContext, beforeImages []*types.RecordImage) ([]*types.RecordImage, error) { if len(parseContext.MultiStmt) == 0 { return nil, nil } tableParsers, err := m.groupParsersByTableName(parseContext) if err != nil { log.Infof("group parsers by table name failed, %s", err) return nil, err } var afterImages = make([]*types.RecordImage, 0) for _, multiParser := range tableParsers { var images []*types.RecordImage switch multiParser.ExecutorType { case types.UpdateExecutor: multiUpdateExec := NewMultiUpdateExecutor(multiParser, m.execContext, m.hooks) images, err = multiUpdateExec.afterImage(ctx, beforeImages) case types.DeleteExecutor: multiDeleteExec := NewMultiDeleteExecutor(multiParser, m.execContext, m.hooks) images, err = multiDeleteExec.afterImage(ctx) default: return nil, fmt.Errorf("not support multi sql %s", m.execContext.Query) } if err != nil { return nil, err } afterImages = append(afterImages, images...) } return afterImages, err } func (m *multiExecutor) groupParsersByTableName(parseContext *types.ParseContext) (map[string]*types.ParseContext, error) { var ( err error tableName string tableParsers = make(map[string]*types.ParseContext, len(parseContext.MultiStmt)) ) for _, parser := range parseContext.MultiStmt { tempParser := *parser tableName, err = parser.GetTableName() if err != nil { return nil, err } if stmtList, ok := tableParsers[tableName]; ok { sts := append(stmtList.MultiStmt, &tempParser) tableParsers[tableName].MultiStmt = sts } else { tableParsers[tableName] = &types.ParseContext{ SQLType: parser.SQLType, ExecutorType: parser.ExecutorType, MultiStmt: []*types.ParseContext{&tempParser}, } } } return tableParsers, err }