in pkg/datasource/sql/undo/builder/mysql_multi_update_undo_log_builder.go [149:236]
// buildBeforeImageSQL builds one "SELECT ... FOR UPDATE" statement used to read the
// before image of the rows touched by a batch of UPDATE statements.
//
// Construction:
//   - The select list is the de-duplicated union of all SET columns, in order of
//     first appearance.
//   - The WHERE clause is every statement's WHERE clause joined with OR.
//   - FROM and table hints are taken from the first statement only.
//     NOTE(review): this presumes every statement targets the same table(s) — confirm
//     with callers.
//
// It returns:
//   - The restored SQL text.
//   - The bind arguments collected via buildSelectArgs for each statement, in order.
//   - An error if the batch is empty, if any statement has LIMIT, ORDER BY or no
//     WHERE clause, or if the SQL cannot be restored or re-parsed.
func (u *MySQLMultiUpdateUndoLogBuilder) buildBeforeImageSQL(updateStmts []*ast.UpdateStmt, args []driver.Value) (string, []driver.Value, error) {
	if len(updateStmts) == 0 {
		log.Errorf("invalid multi update stmt")
		return "", nil, fmt.Errorf("invalid multi update stmt")
	}

	var (
		newArgs        []driver.Value
		fields         []*ast.SelectField
		fieldsExist    = make(map[string]struct{})
		whereCondition strings.Builder
	)
	for _, updateStmt := range updateStmts {
		if updateStmt.Limit != nil {
			return "", nil, fmt.Errorf("multi update SQL with limit condition is not support yet")
		}
		if updateStmt.Order != nil {
			return "", nil, fmt.Errorf("multi update SQL with orderBy condition is not support yet")
		}
		// Without a WHERE clause there is no condition to OR together, and calling
		// Restore on a nil expression would panic, so reject it like LIMIT / ORDER BY.
		if updateStmt.Where == nil {
			return "", nil, fmt.Errorf("multi update SQL without where condition is not support yet")
		}

		// todo use ONLY_CARE_UPDATE_COLUMNS to judge select all columns or not
		for _, column := range updateStmt.List {
			key := column.Column.String()
			if _, exist := fieldsExist[key]; exist {
				continue
			}
			fieldsExist[key] = struct{}{}
			fields = append(fields, &ast.SelectField{
				Expr: &ast.ColumnNameExpr{
					Name: column.Column,
				},
			})
		}

		// Per-statement select, used only to collect this statement's bind args.
		tmpSelectStmt := ast.SelectStmt{
			SelectStmtOpts: &ast.SelectStmtOpts{},
			From:           updateStmt.TableRefs,
			Where:          updateStmt.Where,
			Fields:         &ast.FieldList{Fields: fields},
			OrderBy:        updateStmt.Order,
			Limit:          updateStmt.Limit,
			TableHints:     updateStmt.TableHints,
			LockInfo: &ast.SelectLockInfo{
				LockType: ast.SelectLockForUpdate,
			},
		}
		newArgs = append(newArgs, u.buildSelectArgs(&tmpSelectStmt, args)...)

		in := bytes.NewByteBuffer([]byte{})
		if err := updateStmt.Where.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, in)); err != nil {
			return "", nil, errors.Wrap(err, "multi update restore where condition error")
		}
		if whereCondition.Len() > 0 {
			whereCondition.WriteString(" OR ")
		}
		whereCondition.Write(in.Bytes())
	}

	// Re-parse the combined condition text through a throwaway SELECT so that only
	// its WHERE expression tree is kept.
	fakeSQL := "select * from t where " + whereCondition.String()
	fakeStmt, err := parser.New().ParseOneStmt(fakeSQL, "", "")
	if err != nil {
		return "", nil, errors.Wrap(err, "multi update parse fake sql error")
	}
	fakeNode, ok := fakeStmt.Accept(&updateVisitor{})
	if !ok {
		// err is nil here; wrapping it would yield a nil error and hide the failure.
		return "", nil, fmt.Errorf("multi update accept update visitor error")
	}
	fakeSelectStmt, ok := fakeNode.(*ast.SelectStmt)
	if !ok {
		return "", nil, fmt.Errorf("multi update fake node is not select stmt")
	}

	selStmt := ast.SelectStmt{
		SelectStmtOpts: &ast.SelectStmtOpts{},
		From:           updateStmts[0].TableRefs,
		Where:          fakeSelectStmt.Where,
		Fields:         &ast.FieldList{Fields: fields},
		TableHints:     updateStmts[0].TableHints,
		LockInfo: &ast.SelectLockInfo{
			LockType: ast.SelectLockForUpdate,
		},
	}
	b := bytes.NewByteBuffer([]byte{})
	if err := selStmt.Restore(format.NewRestoreCtx(format.RestoreKeyWordUppercase, b)); err != nil {
		return "", nil, errors.Wrap(err, "multi update restore select sql error")
	}
	sql := string(b.Bytes())
	log.Infof("build select sql by update sourceQuery, sql %s", sql)
	return sql, newArgs, nil
}