dao/feature_view_hologres_dao.go (567 lines of code) (raw):

package dao import ( "database/sql" "fmt" "hash/crc32" "log" "strings" "sync" "time" "github.com/aliyun/aliyun-pai-featurestore-go-sdk/v2/api" "github.com/aliyun/aliyun-pai-featurestore-go-sdk/v2/datasource/hologres" "github.com/aliyun/aliyun-pai-featurestore-go-sdk/v2/utils" "github.com/expr-lang/expr" "github.com/expr-lang/expr/ast" "github.com/huandu/go-sqlbuilder" ) type FeatureViewHologresDao struct { UnimplementedFeatureViewDao db *sql.DB table string primaryKeyField string eventTimeField string ttl int mu sync.RWMutex stmtMap map[uint32]*sql.Stmt offlineTable string onlineTable string } func NewFeatureViewHologresDao(config DaoConfig) *FeatureViewHologresDao { dao := FeatureViewHologresDao{ table: config.HologresTableName, primaryKeyField: config.PrimaryKeyField, eventTimeField: config.EventTimeField, ttl: config.TTL, stmtMap: make(map[uint32]*sql.Stmt, 4), offlineTable: config.HologresOfflineTableName, onlineTable: config.HologresOnlineTableName, } hologres, err := hologres.GetHologres(config.HologresName) if err != nil { return nil } dao.db = hologres.DB return &dao } func (d *FeatureViewHologresDao) getStmt(key uint32) *sql.Stmt { d.mu.RLock() defer d.mu.RUnlock() return d.stmtMap[key] } func (d *FeatureViewHologresDao) GetFeatures(keys []interface{}, selectFields []string) ([]map[string]interface{}, error) { selector := make([]string, 0, len(selectFields)) for _, field := range selectFields { selector = append(selector, fmt.Sprintf("\"%s\"", field)) } builder := sqlbuilder.PostgreSQL.NewSelectBuilder() builder.Select(selector...) builder.From(d.table) builder.Where(builder.In(fmt.Sprintf("\"%s\"", d.primaryKeyField), keys...)) if d.ttl > 0 { t := time.Now().Add(time.Duration(-1 * d.ttl * int(time.Second))) builder.Where(builder.GreaterEqualThan(fmt.Sprintf("\"%s\"", d.eventTimeField), t)) } sql, args := builder.Build() stmtKey := crc32.ChecksumIEEE([]byte(sql)) //stmtKey := Md5(sql) stmt := d.getStmt(stmtKey) if stmt == nil { d.mu.Lock() stmt = d.stmtMap[stmtKey] if stmt == nil { stmt2, err := d.db.Prepare(sql) if err != nil { d.mu.Unlock() return nil, err } d.stmtMap[stmtKey] = stmt2 stmt = stmt2 d.mu.Unlock() } else { d.mu.Unlock() } } rows, err := stmt.Query(args...) if err != nil { return nil, err } defer rows.Close() result := make([]map[string]interface{}, 0, len(keys)) columns, _ := rows.ColumnTypes() values := ColumnValues(columns) for rows.Next() { if err := rows.Scan(values...); err == nil { properties := make(map[string]interface{}, len(values)) for i, column := range columns { name := column.Name() if value := ParseColumnValues(values[i]); value != nil { properties[name] = value } } result = append(result, properties) } } return result, nil } type sequenceInfo struct { itemId string event string playTime float64 timestamp int64 } func (d *FeatureViewHologresDao) GetUserSequenceFeature(keys []interface{}, userIdField string, sequenceConfig api.FeatureViewSeqConfig, onlineConfig []*api.SeqConfig) ([]map[string]interface{}, error) { var selectFields []string if sequenceConfig.PlayTimeField == "" { selectFields = []string{fmt.Sprintf("\"%s\"", sequenceConfig.ItemIdField), fmt.Sprintf("\"%s\"", sequenceConfig.EventField), fmt.Sprintf("\"%s\"", sequenceConfig.TimestampField)} } else { selectFields = []string{fmt.Sprintf("\"%s\"", sequenceConfig.ItemIdField), fmt.Sprintf("\"%s\"", sequenceConfig.EventField), fmt.Sprintf("\"%s\"", sequenceConfig.PlayTimeField), fmt.Sprintf("\"%s\"", sequenceConfig.TimestampField)} } currTime := time.Now().Unix() sequencePlayTimeMap := makePlayTimeMap(sequenceConfig.PlayTimeFilter) onlineFunc := func(seqEvent string, sequence_events []interface{}, seqLen int, key interface{}) []*sequenceInfo { onlineSequences := []*sequenceInfo{} builder := sqlbuilder.PostgreSQL.NewSelectBuilder() builder.Select(selectFields...) builder.From(d.onlineTable) where := []string{builder.Equal(fmt.Sprintf("\"%s\"", userIdField), key), builder.GreaterThan(fmt.Sprintf("\"%s\"", sequenceConfig.TimestampField), currTime-86400*5)} if len(sequence_events) > 1 { where = append(where, builder.In(fmt.Sprintf("\"%s\"", sequenceConfig.EventField), sequence_events...)) } else { where = append(where, builder.Equal(fmt.Sprintf("\"%s\"", sequenceConfig.EventField), seqEvent)) } builder.Where(where...) builder.Limit(seqLen) builder.OrderBy(fmt.Sprintf("\"%s\"", sequenceConfig.TimestampField)).Desc() sql, args := builder.Build() stmtKey := crc32.ChecksumIEEE([]byte(sql)) stmt := d.getStmt(stmtKey) if stmt == nil { d.mu.Lock() stmt = d.stmtMap[stmtKey] if stmt == nil { stmt2, err := d.db.Prepare(sql) if err != nil { d.mu.Unlock() log.Println(err) return nil } d.stmtMap[stmtKey] = stmt2 stmt = stmt2 d.mu.Unlock() } else { d.mu.Unlock() } } rows, err := stmt.Query(args...) if err != nil { log.Println(err) return nil } defer rows.Close() for rows.Next() { seq := new(sequenceInfo) var dst []interface{} if sequenceConfig.PlayTimeField == "" { dst = []interface{}{&seq.itemId, &seq.event, &seq.timestamp} } else { dst = []interface{}{&seq.itemId, &seq.event, &seq.playTime, &seq.timestamp} } if err := rows.Scan(dst...); err == nil { if seq.event == "" || seq.itemId == "" { continue } if t, exist := sequencePlayTimeMap[seq.event]; exist { if seq.playTime <= t { continue } } onlineSequences = append(onlineSequences, seq) } else { log.Println(err) return nil } } return onlineSequences } offlineFunc := func(seqEvent string, sequence_events []interface{}, seqLen int, key interface{}) []*sequenceInfo { offlineSequences := []*sequenceInfo{} builder := sqlbuilder.PostgreSQL.NewSelectBuilder() builder.Select(selectFields...) builder.From(d.offlineTable) where := []string{builder.Equal(fmt.Sprintf("\"%s\"", userIdField), key)} if len(sequence_events) > 1 { where = append(where, builder.In(fmt.Sprintf("\"%s\"", sequenceConfig.EventField), sequence_events...)) } else { where = append(where, builder.Equal(fmt.Sprintf("\"%s\"", sequenceConfig.EventField), seqEvent)) } builder.Where(where...) builder.Limit(seqLen) builder.OrderBy(fmt.Sprintf("\"%s\"", sequenceConfig.TimestampField)).Desc() sql, args := builder.Build() stmtKey := crc32.ChecksumIEEE([]byte(sql)) stmt := d.getStmt(stmtKey) if stmt == nil { d.mu.Lock() stmt = d.stmtMap[stmtKey] if stmt == nil { stmt2, err := d.db.Prepare(sql) if err != nil { d.mu.Unlock() log.Println(err) return nil } d.stmtMap[stmtKey] = stmt2 stmt = stmt2 d.mu.Unlock() } else { d.mu.Unlock() } } rows, err := stmt.Query(args...) if err != nil { log.Println(err) return nil } defer rows.Close() for rows.Next() { seq := new(sequenceInfo) var dst []interface{} if sequenceConfig.PlayTimeField == "" { dst = []interface{}{&seq.itemId, &seq.event, &seq.timestamp} } else { dst = []interface{}{&seq.itemId, &seq.event, &seq.playTime, &seq.timestamp} } if err := rows.Scan(dst...); err == nil { if seq.event == "" || seq.itemId == "" { continue } if t, exist := sequencePlayTimeMap[seq.event]; exist { if seq.playTime <= t { continue } } offlineSequences = append(offlineSequences, seq) } else { log.Println(err) return nil } } return offlineSequences } results := make([]map[string]interface{}, 0, len(keys)) var outmu sync.Mutex var wg sync.WaitGroup for _, key := range keys { wg.Add(1) go func(key interface{}) { defer wg.Done() properties := make(map[string]interface{}) var mu sync.Mutex var eventWg sync.WaitGroup for _, seqConfig := range onlineConfig { eventWg.Add(1) go func(seqConfig *api.SeqConfig) { defer eventWg.Done() var onlineSequences []*sequenceInfo var offlineSequences []*sequenceInfo origin_sequence_events := strings.Split(seqConfig.SeqEvent, "|") sequence_events := make([]interface{}, len(origin_sequence_events)) for i, v := range origin_sequence_events { sequence_events[i] = v } var innerWg sync.WaitGroup //get data from online table innerWg.Add(1) go func(seqEvent string, sequence_events []interface{}, seqLen int, key interface{}) { defer innerWg.Done() if onlineresult := onlineFunc(seqEvent, sequence_events, seqLen, key); onlineresult != nil { onlineSequences = onlineresult } }(seqConfig.SeqEvent, sequence_events, seqConfig.SeqLen, key) //get data from offline table innerWg.Add(1) go func(seqEvent string, sequence_events []interface{}, seqLen int, key interface{}) { defer innerWg.Done() if offlineresult := offlineFunc(seqEvent, sequence_events, seqLen, key); offlineresult != nil { offlineSequences = offlineresult } }(seqConfig.SeqEvent, sequence_events, seqConfig.SeqLen, key) innerWg.Wait() subproperties := makeSequenceFeatures(offlineSequences, onlineSequences, seqConfig, sequenceConfig, currTime) mu.Lock() defer mu.Unlock() for k, value := range subproperties { properties[k] = value } }(seqConfig) } eventWg.Wait() properties[userIdField] = key outmu.Lock() results = append(results, properties) outmu.Unlock() }(key) } wg.Wait() return results, nil } func (d *FeatureViewHologresDao) GetUserBehaviorFeature(userIds []interface{}, events []interface{}, selectFields []string, sequenceConfig api.FeatureViewSeqConfig) ([]map[string]interface{}, error) { selector := make([]string, 0, len(selectFields)) for _, field := range selectFields { selector = append(selector, fmt.Sprintf("\"%s\"", field)) } currTime := time.Now().Unix() sequencePlayTimeMap := makePlayTimeMap(sequenceConfig.PlayTimeFilter) onlineFunc := func(userId interface{}) []map[string]interface{} { builder := sqlbuilder.PostgreSQL.NewSelectBuilder() builder.Select(selector...) builder.From(d.onlineTable) where := []string{builder.Equal(fmt.Sprintf("\"%s\"", d.primaryKeyField), userId), builder.GreaterThan(fmt.Sprintf("\"%s\"", sequenceConfig.TimestampField), currTime-86400*5)} if len(events) > 0 { where = append(where, builder.In(fmt.Sprintf("\"%s\"", sequenceConfig.EventField), events...)) } builder.Where(where...) builder.OrderBy(fmt.Sprintf("\"%s\"", sequenceConfig.TimestampField)).Desc() sql, args := builder.Build() stmtKey := crc32.ChecksumIEEE([]byte(sql)) stmt := d.getStmt(stmtKey) if stmt == nil { d.mu.Lock() stmt = d.stmtMap[stmtKey] if stmt == nil { stmt2, err := d.db.Prepare(sql) if err != nil { d.mu.Unlock() log.Println(err) return nil } d.stmtMap[stmtKey] = stmt2 stmt = stmt2 d.mu.Unlock() } else { d.mu.Unlock() } } rows, err := stmt.Query(args...) if err != nil { log.Println(err) return nil } defer rows.Close() columns, _ := rows.ColumnTypes() values := ColumnValues(columns) result := make([]map[string]interface{}, 0, len(userIds)*len(events)*50) for rows.Next() { if err := rows.Scan(values...); err == nil { properties := make(map[string]interface{}, len(values)) for i, column := range columns { name := column.Name() if value := ParseColumnValues(values[i]); value != nil { properties[name] = value } } if t, exist := sequencePlayTimeMap[utils.ToString(properties[sequenceConfig.EventField], "")]; exist { if utils.ToFloat(properties[sequenceConfig.PlayTimeField], 0.0) <= t { continue } } result = append(result, properties) } } return result } offlineFunc := func(userId interface{}) []map[string]interface{} { builder := sqlbuilder.PostgreSQL.NewSelectBuilder() builder.Select(selector...) builder.From(d.offlineTable) where := []string{builder.Equal(fmt.Sprintf("\"%s\"", d.primaryKeyField), userId)} if len(events) > 0 { where = append(where, builder.In(fmt.Sprintf("\"%s\"", sequenceConfig.EventField), events...)) } builder.Where(where...) builder.OrderBy(fmt.Sprintf("\"%s\"", sequenceConfig.TimestampField)).Desc() sql, args := builder.Build() stmtKey := crc32.ChecksumIEEE([]byte(sql)) stmt := d.getStmt(stmtKey) if stmt == nil { d.mu.Lock() stmt = d.stmtMap[stmtKey] if stmt == nil { stmt2, err := d.db.Prepare(sql) if err != nil { d.mu.Unlock() log.Println(err) return nil } d.stmtMap[stmtKey] = stmt2 stmt = stmt2 d.mu.Unlock() } else { d.mu.Unlock() } } rows, err := stmt.Query(args...) if err != nil { log.Println(err) return nil } defer rows.Close() columns, _ := rows.ColumnTypes() values := ColumnValues(columns) result := make([]map[string]interface{}, 0, len(userIds)*len(events)*50) for rows.Next() { if err := rows.Scan(values...); err == nil { properties := make(map[string]interface{}, len(values)) for i, column := range columns { name := column.Name() if value := ParseColumnValues(values[i]); value != nil { properties[name] = value } } if t, exist := sequencePlayTimeMap[utils.ToString(properties[sequenceConfig.EventField], "")]; exist { if utils.ToFloat(properties[sequenceConfig.PlayTimeField], 0.0) <= t { continue } } result = append(result, properties) } } return result } results := make([]map[string]interface{}, 0, len(userIds)*(len(events)+1)*50) var outmu sync.Mutex var wg sync.WaitGroup for _, userId := range userIds { wg.Add(1) go func(userId interface{}) { defer wg.Done() var innerWg sync.WaitGroup var offlineResult []map[string]interface{} var onlineResult []map[string]interface{} // offline table innerWg.Add(1) go func(userId interface{}) { defer innerWg.Done() offlineResult = offlineFunc(userId) }(userId) // online table innerWg.Add(1) go func(userId interface{}) { defer innerWg.Done() onlineResult = onlineFunc(userId) }(userId) innerWg.Wait() if offlineResult == nil || onlineResult == nil { fmt.Println("get user behavior feature failed") return } combinedResult := combineBehaviorFeatures(offlineResult, onlineResult, sequenceConfig.TimestampField) outmu.Lock() results = append(results, combinedResult...) outmu.Unlock() }(userId) } wg.Wait() return results, nil } type Visitor struct { LastNode *ast.BinaryNode } func (v *Visitor) Visit(node *ast.Node) { switch n := (*node).(type) { case *ast.BinaryNode: v.LastNode = n } } func (v *Visitor) ConvertToSql(node *ast.BinaryNode) string { if node == nil { return "" } if node.Operator != "&&" && node.Operator != "||" { op := node.Operator if op == "==" { op = "=" } if leftNode, ok := node.Left.(*ast.IdentifierNode); ok { return fmt.Sprintf("%s %s '%s'", leftNode, op, strings.ReplaceAll(node.Right.String(), "\"", "")) } else { return fmt.Sprintf("'%s' %s %s", strings.ReplaceAll(node.Left.String(), "\"", ""), op, node.Right.String()) } } else if node.Operator == "&&" { left := v.ConvertToSql(node.Left.(*ast.BinaryNode)) right := v.ConvertToSql(node.Right.(*ast.BinaryNode)) return fmt.Sprintf("(%s) and (%s)", left, right) } else if node.Operator == "||" { left := v.ConvertToSql(node.Left.(*ast.BinaryNode)) right := v.ConvertToSql(node.Right.(*ast.BinaryNode)) return fmt.Sprintf("(%s) or (%s)", left, right) } return "" } func (d *FeatureViewHologresDao) RowCount(filterExpr string) int { builder := sqlbuilder.PostgreSQL.NewSelectBuilder() builder.Select("count(*)") builder.From(d.table) if filterExpr != "" { program, err := expr.Compile(filterExpr) if err != nil { fmt.Println(err) return 0 } node := program.Node() visitor := &Visitor{} ast.Walk(&node, visitor) sqlWhere := visitor.ConvertToSql(visitor.LastNode) builder.Where(sqlWhere) } sql, args := builder.Build() fmt.Println("row count sql:", sql) var count int retry := 3 for i := 0; i < retry; i++ { row := d.db.QueryRow(sql, args...) err := row.Scan(&count) if i == retry-1 { fmt.Println(err) return 0 } if err != nil { time.Sleep(100 * time.Millisecond) continue } return count } return count } func (d *FeatureViewHologresDao) RowCountIds(filterExpr string) ([]string, int, error) { builder := sqlbuilder.PostgreSQL.NewSelectBuilder() builder.Select(d.primaryKeyField) builder.From(d.table) if filterExpr != "" { program, err := expr.Compile(filterExpr) if err != nil { return nil, 0, err } node := program.Node() visitor := &Visitor{} ast.Walk(&node, visitor) sqlWhere := visitor.ConvertToSql(visitor.LastNode) builder.Where(sqlWhere) } sql, args := builder.Build() fmt.Println("sql:", sql) rows, err := d.db.Query(sql, args...) if err != nil { return nil, 0, err } defer rows.Close() ids := make([]string, 0, 1024) for rows.Next() { var id string if err := rows.Scan(&id); err != nil { return nil, 0, err } else { ids = append(ids, id) } } return ids, len(ids), nil }