pkg/datasource/sql/undo/executor/sql.go (145 lines of code) (raw):
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package executor
import (
"database/sql"
"strings"
"seata.apache.org/seata-go/pkg/datasource/sql/types"
"seata.apache.org/seata-go/pkg/datasource/sql/undo"
)
// Tokens used when adding or removing identifier escapes on column/table names.
const (
// Dot separates a qualifier (e.g. a schema) from the name that follows it.
Dot = "."
// EscapeStandard is the double-quote escape; DelEscape strips it for every
// database type and AddEscape uses it for every type other than MySQL.
EscapeStandard = "\""
// EscapeMysql is the backtick escape, applied only when the database type is MySQL.
EscapeMysql = "`"
)
// DelEscape removes identifier escapes from colName according to dbType.
// The standard double-quote escape is always stripped; for MySQL the backtick
// escape is stripped from that result as well.
func DelEscape(colName string, dbType types.DBType) string {
	stripped := delEscape(colName, EscapeStandard)
	if dbType != types.DBTypeMySQL {
		return stripped
	}
	return delEscape(stripped, EscapeMysql)
}
// delEscape strips one layer of the escape token from colName. The name may be
// a bare identifier or a "qualifier.name" pair in which either part, or both,
// is wrapped in escape, e.g. "scheme"."id", "scheme".id or scheme."id".
//
// colName is returned unchanged when it is empty, when escape is empty, when
// no escape is found in a recognised position, or when the name is too short
// or malformed to strip safely (a lone escape character, `".id`, `a."`, ...).
// Such inputs used to trigger a slice-bounds panic.
func delEscape(colName string, escape string) string {
	if colName == "" || escape == "" {
		return colName
	}

	escLen := len(escape)
	leading := strings.HasPrefix(colName, escape)
	trailing := strings.HasSuffix(colName, escape)

	if leading && trailing {
		// The opening and closing escape must be distinct characters,
		// otherwise there is nothing between them to return.
		if len(colName) < 2*escLen {
			return colName
		}
		inner := colName[escLen : len(colName)-escLen]
		// like "scheme"."id" `scheme`.`id`
		sep := escape + Dot + escape
		if i := strings.Index(inner, sep); i > -1 {
			return inner[:i] + Dot + inner[i+len(sep):]
		}
		return inner
	}

	if leading {
		// like "scheme".id `scheme`.id
		// Search after the opening escape so that it can never be mistaken
		// for the escape that closes the qualifier.
		sep := escape + Dot
		rest := colName[escLen:]
		if i := strings.Index(rest, sep); i > -1 {
			return rest[:i] + Dot + rest[i+len(sep):]
		}
		return colName
	}

	if trailing {
		// like scheme."id" scheme.`id`
		// Search before the closing escape so that it can never be mistaken
		// for the escape that opens the name.
		sep := Dot + escape
		head := colName[:len(colName)-escLen]
		if i := strings.Index(head, sep); i > -1 {
			return head[:i] + Dot + head[i+len(sep):]
		}
	}
	return colName
}
// AddEscape wraps colName in the escape token appropriate for dbType when
// escaping is required: the backtick for MySQL, the double quote otherwise.
func AddEscape(colName string, dbType types.DBType) string {
	escape := EscapeStandard
	if dbType == types.DBTypeMySQL {
		escape = EscapeMysql
	}
	return addEscape(colName, dbType, escape)
}
// addEscape wraps colName in escape when checkEscape reports that the name
// needs escaping for dbType.
//
// colName is returned unchanged when it is empty, when it already starts and
// ends with escape, or when checkEscape returns false. For a dotted name the
// two parts are escaped individually, leaving a part that is already escaped
// untouched:
//
//	"scheme".id -> "scheme"."id"
//	scheme."id" -> "scheme"."id"
//	scheme.id   -> "scheme"."id"
//
// A name without a dot is wrapped as a whole.
func addEscape(colName string, dbType types.DBType, escape string) string {
	if colName == "" {
		return colName
	}
	// Already wrapped on both ends: nothing to add.
	if strings.HasPrefix(colName, escape) && strings.HasSuffix(colName, escape) {
		return colName
	}
	if !checkEscape(colName, dbType) {
		return colName
	}

	// like "scheme".id `scheme`.id
	if i := strings.Index(colName, escape+Dot); i > -1 {
		split := i + len(escape) + len(Dot)
		return colName[:split] + escape + colName[split:] + escape
	}
	// like scheme."id" scheme.`id`
	if i := strings.Index(colName, Dot+escape); i > -1 {
		return escape + colName[:i] + escape + colName[i:]
	}
	// like scheme.id
	if i := strings.Index(colName, Dot); i > -1 {
		return escape + colName[:i] + escape + Dot + escape + colName[i+len(Dot):] + escape
	}

	// Plain concatenation keeps every byte of the name. The former
	// `for key := range colName` byte copy visited only rune-start offsets,
	// leaving the continuation bytes of multi-byte characters as zero.
	return escape + colName + escape
}
// checkEscape reports whether colName must be escaped for the given dbType.
// For MySQL this is true only when the upper-cased name is found in the MySQL
// keyword table; every other database type currently always yields true.
func checkEscape(colName string, dbType types.DBType) bool {
	// TODO impl Oracle PG SQLServer ...
	if dbType != types.DBTypeMySQL {
		return true
	}
	_, isKeyword := types.GetMysqlKeyWord()[strings.ToUpper(colName)]
	return isKeyword
}
// BuildWhereConditionByPKs renders one "<pk> = ? " placeholder condition per
// primary key, with each name passed through AddEscape, and joins the
// conditions with " and ". An empty pkNameList yields an empty string.
func BuildWhereConditionByPKs(pkNameList []string, dbType types.DBType) string {
	var cond strings.Builder
	for idx, pk := range pkNameList {
		if idx != 0 {
			cond.WriteString(" and ")
		}
		cond.WriteString(AddEscape(pk, dbType))
		cond.WriteString(" = ? ")
	}
	return cond.String()
}
// DataValidationAndGoOn is meant to check whether the data is valid.
// The check is not implemented yet: both arguments are ignored and the
// function always returns true.
// TODO: implement dataValidationAndGoOn.
func DataValidationAndGoOn(sqlUndoLog undo.SQLUndoLog, conn *sql.Conn) bool {
return true
}
// GetOrderedPkList returns the primary-key columns of row, arranged in the
// order reported by image.TableMeta.GetPrimaryKeyOnlyName().
//
// Each primary-key column of the row has its escapes removed via DelEscape
// before it is compared. A column is selected for a primary-key name when its
// unescaped name equals that name, or is a qualified name whose final segment
// (after a Dot) equals it. The earlier substring test also selected unrelated
// columns, e.g. "user_id" for the key "id".
//
// The returned error is always nil.
func GetOrderedPkList(image *types.RecordImage, row types.RowImage, dbType types.DBType) ([]types.ColumnImage, error) {
	pkNamesInOrder := image.TableMeta.GetPrimaryKeyOnlyName()

	rowPKs := row.PrimaryKeys(row.Columns)
	unescaped := make([]types.ColumnImage, 0, len(rowPKs))
	for _, column := range rowPKs {
		// column is a per-iteration copy, so the row itself keeps its original name.
		column.ColumnName = DelEscape(column.ColumnName, dbType)
		unescaped = append(unescaped, column)
	}

	ordered := make([]types.ColumnImage, 0, len(unescaped))
	for _, pkName := range pkNamesInOrder {
		for _, col := range unescaped {
			if col.ColumnName == pkName || strings.HasSuffix(col.ColumnName, Dot+pkName) {
				ordered = append(ordered, col)
			}
		}
	}
	return ordered, nil
}