func()

in pkg/datasource/sql/exec/at/insert_executor.go [326:419]


// parsePkValuesFromStatement extracts the primary-key values of every row of
// an INSERT statement, grouped by primary-key column name.
//
// Two paths are supported:
//   - Prepared statement (nameValues non-empty): rows come from getInsertRows.
//     A PK cell equal (case-insensitively) to sqlPlaceholder is resolved to the
//     bound argument at that placeholder's position across the whole
//     statement; any other cell value is used as-is.
//   - Plain statement: the PK cells of insertStmt.Lists are read directly and
//     only ast.ValueExpr cells are recorded.
//
// It returns (nil, nil) when insertStmt is nil or getInsertRows yields no rows.
// It returns an error when no PK index is found, the statement has no value
// lists, a PK index is out of range for a plain-statement row, or a
// placeholder has no matching bound argument.
func (i *insertExecutor) parsePkValuesFromStatement(insertStmt *ast.InsertStmt, meta types.TableMeta, nameValues []driver.NamedValue) (map[string][]interface{}, error) {
	if insertStmt == nil {
		return nil, nil
	}
	pkIndexMap := i.getPkIndex(insertStmt, meta)
	if len(pkIndexMap) == 0 {
		return nil, fmt.Errorf("pkIndex is not found")
	}
	if len(insertStmt.Lists) == 0 {
		return nil, fmt.Errorf("parCtx is nil, perhaps InsertStmt is empty")
	}

	pkValuesMap := make(map[string][]interface{}, len(pkIndexMap))

	if len(nameValues) == 0 {
		// Plain (non-prepared) statement: PK values are literals in the AST.
		for _, list := range insertStmt.Lists {
			for pkName, pkIndex := range pkIndexMap {
				if pkIndex >= len(list) {
					return nil, fmt.Errorf("pkIndex out of range")
				}
				if node, ok := list[pkIndex].(ast.ValueExpr); ok {
					pkValuesMap[pkName] = append(pkValuesMap[pkName], node.GetValue())
				}
			}
		}
		return pkValuesMap, nil
	}

	// Prepared statement: resolve placeholders against the bound arguments.
	pkIndexArray := make([]int, 0, len(pkIndexMap))
	for _, idx := range pkIndexMap {
		pkIndexArray = append(pkIndexArray, idx)
	}
	insertRows, err := getInsertRows(insertStmt, pkIndexArray)
	if err != nil {
		return nil, err
	}
	if len(insertRows) == 0 {
		return nil, nil
	}

	// A cell is a placeholder only if it is a string equal to sqlPlaceholder;
	// every other cell (including non-string literals) is not.
	isPlaceholder := func(v interface{}) bool {
		s, ok := v.(string)
		return ok && strings.EqualFold(s, sqlPlaceholder)
	}

	// placeholdersBefore is the number of placeholders in all preceding rows,
	// i.e. the position in nameValues of this row's first bound argument.
	placeholdersBefore := 0
	for _, row := range insertRows {
		for pkName, pkIndex := range pkIndexMap {
			if pkIndex < 0 || pkIndex >= len(row) {
				continue
			}
			pkValue := row[pkIndex]
			if isPlaceholder(pkValue) {
				argIdx := placeholdersBefore
				for _, cell := range row[:pkIndex] {
					if isPlaceholder(cell) {
						argIdx++
					}
				}
				if argIdx >= len(nameValues) {
					return nil, fmt.Errorf("no bound argument for placeholder of pk %q: index %d, got %d arguments", pkName, argIdx, len(nameValues))
				}
				pkValue = nameValues[argIdx].Value
			}
			// Always store the result back: append may reallocate, and the map
			// holds its own slice header whose length would otherwise not grow.
			pkValuesMap[pkName] = append(pkValuesMap[pkName], pkValue)
		}
		for _, cell := range row {
			if isPlaceholder(cell) {
				placeholdersBefore++
			}
		}
	}

	return pkValuesMap, nil
}