in pkg/datasource/sql/exec/at/multi_update_excutor.go [233:343]
// buildBeforeImageSQL builds one "SELECT ... FOR UPDATE" statement that reads
// the before image of every UPDATE contained in the multi-statement SQL.
//
// The WHERE clauses of all UPDATE statements are joined with OR and applied to
// the table references and table hints of the first statement. This presumably
// relies on every statement targeting the same table — TODO confirm that
// isAstStmtValid (or the caller) enforces it.
//
// When undo.UndoConfig.OnlyCareUpdateColumns is set, the selected columns are
// the union of all SET columns plus the primary-key columns from meta;
// otherwise every column in meta.ColumnNames is selected.
//
// The returned bind values are those collected by buildSelectArgs for each
// statement, concatenated in statement order. Statements with LIMIT, ORDER BY
// or without a WHERE clause are rejected with an error.
func (u *multiUpdateExecutor) buildBeforeImageSQL(args []driver.NamedValue, meta *types.TableMeta) (string, []driver.NamedValue, error) {
	if !u.isAstStmtValid() {
		log.Errorf("invalid multi update stmt")
		return "", nil, errors.New("invalid muliti update stmt")
	}
	multiStmts := u.parserCtx.MultiStmt
	if len(multiStmts) == 0 {
		// multiStmts[0] is dereferenced below; fail explicitly instead of
		// relying on a parse error of an empty WHERE clause.
		return "", nil, errors.New("multi update stmt contains no update statement")
	}

	var (
		whereCondition strings.Builder
		newArgs        = make([]driver.NamedValue, 0, len(multiStmts))
		fields         = make([]*ast.SelectField, 0, len(meta.ColumnNames))
		fieldsExits    = make(map[string]struct{}, len(meta.ColumnNames))
		onlyCareUpdate = undo.UndoConfig.OnlyCareUpdateColumns
	)

	// Selecting every column does not depend on the individual statement,
	// so the field list is built once up front.
	if !onlyCareUpdate {
		for _, column := range meta.ColumnNames {
			fields = append(fields, &ast.SelectField{
				Expr: &ast.ColumnNameExpr{Name: &ast.ColumnName{Name: model.CIStr{O: column}}}})
		}
	}

	for _, multiStmt := range multiStmts {
		updateStmt := multiStmt.UpdateStmt
		if updateStmt.Limit != nil {
			return "", nil, fmt.Errorf("multi update SQL with limit condition is not support yet")
		}
		if updateStmt.Order != nil {
			return "", nil, fmt.Errorf("multi update SQL with orderBy condition is not support yet")
		}
		// The WHERE clause is restored into the combined condition below;
		// a nil clause would panic there.
		if updateStmt.Where == nil {
			return "", nil, fmt.Errorf("multi update SQL without where condition is not support yet")
		}

		if onlyCareUpdate {
			// select update columns, deduplicated across all statements
			for _, column := range updateStmt.List {
				if _, exist := fieldsExits[column.Column.String()]; exist {
					continue
				}
				fieldsExits[column.Column.String()] = struct{}{}
				fields = append(fields, &ast.SelectField{Expr: &ast.ColumnNameExpr{Name: column.Column}})
			}
			// select primary key columns, needed to locate the rows later
			for _, columnName := range meta.GetPrimaryKeyOnlyName() {
				if _, exist := fieldsExits[columnName]; exist {
					continue
				}
				fieldsExits[columnName] = struct{}{}
				fields = append(fields, &ast.SelectField{
					Expr: &ast.ColumnNameExpr{Name: &ast.ColumnName{Name: model.CIStr{O: columnName, L: columnName}}},
				})
			}
		}

		// Per-statement SELECT used only to collect this statement's bind
		// values. OrderBy/Limit are omitted: both were rejected above.
		tmpSelectStmt := ast.SelectStmt{
			SelectStmtOpts: &ast.SelectStmtOpts{},
			From:           updateStmt.TableRefs,
			Where:          updateStmt.Where,
			Fields:         &ast.FieldList{Fields: fields},
			TableHints:     updateStmt.TableHints,
			LockInfo: &ast.SelectLockInfo{
				LockType: ast.SelectLockForUpdate,
			},
		}
		newArgs = append(newArgs, u.buildSelectArgs(&tmpSelectStmt, args)...)

		in := bytes.NewByteBuffer([]byte{})
		if err := updateStmt.Where.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, in)); err != nil {
			return "", nil, fmt.Errorf("restore where condition of multi update stmt: %w", err)
		}
		if whereCondition.Len() > 0 {
			whereCondition.WriteString(" OR ")
		}
		whereCondition.Write(in.Bytes())
	}

	// Only the WHERE clause of this fake statement is used: re-parsing the
	// OR-joined text yields a single expression tree for the real SELECT.
	fakeSql := "select * from t where " + whereCondition.String()
	fakeStmt, err := parser.New().ParseOneStmt(fakeSql, "", "")
	if err != nil {
		log.Errorf("multi update parse fake sql error")
		return "", nil, err
	}
	fakeNode, ok := fakeStmt.Accept(&updateVisitor{})
	if !ok {
		log.Errorf("multi update accept update visitor error")
		// err is nil at this point; report the failure explicitly.
		return "", nil, errors.New("multi update accept update visitor error")
	}
	fakeSelectStmt, ok := fakeNode.(*ast.SelectStmt)
	if !ok {
		log.Errorf("multi update fake node is not select stmt")
		return "", nil, errors.New("multi update fake node is not select stmt")
	}

	selStmt := ast.SelectStmt{
		SelectStmtOpts: &ast.SelectStmtOpts{},
		From:           multiStmts[0].UpdateStmt.TableRefs,
		Where:          fakeSelectStmt.Where,
		Fields:         &ast.FieldList{Fields: fields},
		TableHints:     multiStmts[0].UpdateStmt.TableHints,
		LockInfo: &ast.SelectLockInfo{
			LockType: ast.SelectLockForUpdate,
		},
	}
	b := bytes.NewByteBuffer([]byte{})
	if err := selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)); err != nil {
		return "", nil, fmt.Errorf("restore multi update before image sql: %w", err)
	}
	selectSQL := string(b.Bytes())
	log.Infof("build select sql by update sourceQuery, sql %s", selectSQL)
	return selectSQL, newArgs, nil
}