utils/ast/ast.go (312 lines of code) (raw):

package ast import ( "errors" "fmt" "math" "runtime/debug" "strconv" "strings" "sync" "github.com/alibaba/pairec/v2/log" valuate "github.com/bruceding/go-antlr-valuate" ) // 基础表达式节点接口 type ExprAST interface { toStr() string } // 数字表达式节点 type NumberExprAST struct { // 具体的值 Val float64 } type ParameterExprAST struct { Val string } // 操作表达式节点 type BinaryExprAST struct { // 操作符 Op string // 左右节点,可能是 数字表达式节点/操作表达式节点/nil Lhs, Rhs ExprAST } // 实现接口 func (n NumberExprAST) toStr() string { return fmt.Sprintf( "NumberExprAST:%s", strconv.FormatFloat(n.Val, 'f', 0, 64), ) } // 实现接口 func (n ParameterExprAST) toStr() string { return fmt.Sprintf( "ParameterExprAST:%s", n.Val, ) } // 实现接口 func (b BinaryExprAST) toStr() string { return fmt.Sprintf( "BinaryExprAST: (%s %s %s)", b.Op, b.Lhs.toStr(), b.Rhs.toStr(), ) } // AST 生成器结构体 type AST struct { // 词法分析的结果 Tokens []*Token // 源字符串 source string // 当前分析器分析的 Token currTok *Token // 当前分析器的位置 currIndex int // 错误收集 Err error } // 定义操作符优先级,value 越高,优先级越高 var precedence = map[string]int{"+": 20, "-": 20, "*": 40, "/": 40, "%": 40, "^": 60, "#": 80} // 语法分析器入口 func (a *AST) ParseExpression() ExprAST { lhs := a.parsePrimary() return a.parseBinOpRHS(0, lhs) } // 获取下一个 Token func (a *AST) getNextToken() *Token { a.currIndex++ if a.currIndex < len(a.Tokens) { a.currTok = a.Tokens[a.currIndex] return a.currTok } return nil } // 获取操作优先级 func (a *AST) getTokPrecedence() int { if p, ok := precedence[a.currTok.Tok]; ok { return p } return -1 } // 解析数字,并生成一个 NumberExprAST 节点 func (a *AST) parseNumber() NumberExprAST { f64, err := strconv.ParseFloat(a.currTok.Tok, 64) if err != nil { a.Err = errors.New( fmt.Sprintf("%v\nwant '(' or '0-9' but get '%s'\n%s", err.Error(), a.currTok.Tok, ErrPos(a.source, a.currTok.Offset))) return NumberExprAST{} } n := NumberExprAST{ Val: f64, } a.getNextToken() return n } // 解析参数 func (a *AST) parseParameter() ParameterExprAST { n := ParameterExprAST{ Val: a.currTok.Tok, } a.getNextToken() return n } // 获取一个节点,返回 ExprAST // 这里会处理所有可能出现的类型,并对相应类型做解析 func (a *AST) parsePrimary() ExprAST { switch a.currTok.Type { case Literal: return a.parseNumber() case Parameter: return a.parseParameter() case Operator: // 对 () 语法处理 if a.currTok.Tok == "(" { a.getNextToken() e := a.ParseExpression() if e == nil { return nil } if a.currTok.Tok != ")" { a.Err = errors.New( fmt.Sprintf("want ')' but get %s\n%s", a.currTok.Tok, ErrPos(a.source, a.currTok.Offset))) return nil } a.getNextToken() return e } else { return a.parseNumber() } default: return nil } } // 循环获取操作符的优先级,将高优先级的递归成较深的节点 // 这是生成正确的 AST 结构最重要的一个算法,一定要仔细阅读、理解 func (a *AST) parseBinOpRHS(execPrec int, lhs ExprAST) ExprAST { for { tokPrec := a.getTokPrecedence() if tokPrec < execPrec { return lhs } binOp := a.currTok.Tok if a.getNextToken() == nil { return lhs } rhs := a.parsePrimary() if rhs == nil { return nil } nextPrec := a.getTokPrecedence() if tokPrec < nextPrec { // 递归,将当前优先级+1 rhs = a.parseBinOpRHS(tokPrec+1, rhs) if rhs == nil { return nil } } lhs = BinaryExprAST{ Op: binOp, Lhs: lhs, Rhs: rhs, } } } // 生成一个 AST 结构指针 func NewAST(toks []*Token, s string) *AST { a := &AST{ Tokens: toks, source: s, } if a.Tokens == nil || len(a.Tokens) == 0 { a.Err = errors.New("empty token") } else { a.currIndex = 0 a.currTok = a.Tokens[0] } return a } // 一个典型的后序遍历求解算法 func ExprASTResult(expr ExprAST, exprDatas ...ParameterExprData) float64 { // 左右值 var l, r float64 switch expr.(type) { // 传入的根节点是 BinaryExprAST case BinaryExprAST: ast := expr.(BinaryExprAST) // 递归左节点 l = ExprASTResult(ast.Lhs, exprDatas...) // 递归右节点 r = ExprASTResult(ast.Rhs, exprDatas...) // 现在 l,r 都有具体的值了,可以根据运算符运算 switch ast.Op { case "#": if l != 0.0 { return l } else { return r } case "^": return math.Pow(l, r) case "+": return l + r case "-": return l - r case "*": return l * r case "/": if r == 0 { panic(errors.New( fmt.Sprintf("violation of arithmetic specification: a division by zero in ExprASTResult: [%g/%g]", l, r))) } return l / r case "%": return float64(int(l) % int(r)) default: } // 传入的根节点是 NumberExprAST,无需做任何事情,直接返回 Val 值 case NumberExprAST: return expr.(NumberExprAST).Val case ParameterExprAST: val := expr.(ParameterExprAST).Val for _, data := range exprDatas { if f, err := data.FloatExprData(val); err == nil { return f } } } return 0.0 } // should use sync map var caches = make(map[string]ExprAST) var cachesByAntlr = make(map[string]ExprAST) var mutex sync.RWMutex type exprAST struct { expression *valuate.EvaluableExpression } func (e *exprAST) Evaluate(data map[string]any) (result float64, err error) { defer func() { if r := recover(); r != nil { stack := string(debug.Stack()) log.Error(fmt.Sprintf("error=%v, stack=%s", err, strings.ReplaceAll(stack, "\n", "\t"))) result = float64(0) err = nil } }() ret, err1 := e.expression.Evaluate(data) if err1 != nil { err = err1 return } if r, ok := ret.(float64); ok { result = r return } else { result = float64(0) err = fmt.Errorf("expression invoke result:%v", ret) return } } func (e *exprAST) toStr() string { return "" } func GetExpASTByAntlr(source string) (ExprAST, error) { if source == "" { return nil, nil } var exprAst ExprAST mutex.RLock() exprAst, ok := cachesByAntlr[source] mutex.RUnlock() if !ok { expression, err := valuate.NewEvaluableExpression(source) if err != nil { return nil, err } exprAst = &exprAST{ expression: expression, } mutex.Lock() cachesByAntlr[source] = exprAst mutex.Unlock() } return exprAst, nil } func GetExpASTWithType(source, astType string) (ExprAST, error) { if astType == "antlr" { return GetExpASTByAntlr(source) } return GetExpAST(source) } func GetExpAST(source string) (ExprAST, error) { if source == "" { return nil, nil } var exprAst ExprAST mutex.RLock() exprAst, ok := caches[source] mutex.RUnlock() if !ok { tokens, err := Parse(source) if err != nil { return nil, err } ast := NewAST(tokens, source) exprAst = ast.ParseExpression() mutex.Lock() caches[source] = exprAst mutex.Unlock() } return exprAst, nil } func ExprASTResultWithType(expr ExprAST, exprDatas ParameterExprData, astType string) float64 { if astType == "antlr" { return ExprASTResultByAntlr(expr, exprDatas) } else { return ExprASTResult(expr, exprDatas) } } func ExprASTResultByAntlr(expr ExprAST, exprDatas ParameterExprData) float64 { switch ast := expr.(type) { // 传入的根节点是 BinaryExprAST case *exprAST: data := exprDatas.ExprData() result, err := ast.Evaluate(data) if err != nil { log.Error(fmt.Sprintf("expression invoke error:%v", err)) return float64(0) } return result } return float64(0) }