common/db_opertion.go (300 lines of code) (raw):
package utils
import (
"fmt"
LOG "github.com/vinllen/log4go"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
"math"
"strconv"
"strings"
"sort"
)
var (
QueryTs = "ts"
localDB = "local"
)
const (
DBRefRef = "$ref"
DBRefId = "$id"
DBRefDb = "$db"
CollectionCapped = "CollectionScan died due to position in capped" // bigger than 3.0
CollectionCappedLowVersion = "UnknownError" // <= 3.0 version
)
// for UT only
var (
GetAllTimestampInUTInput map[string]Pair // replicaSet/MongoS name => <oldest timestamp, newest timestamp>
)
/************************************************/
type MongoSource struct {
URL string
ReplicaName string
Gids []string
}
func (ms *MongoSource) String() string {
return fmt.Sprintf("url[%v], name[%v]", BlockMongoUrlPassword(ms.URL, "***"), ms.ReplicaName)
}
// get db version, return string with format like "3.0.1"
func GetDBVersion(conn *MongoCommunityConn) (string, error) {
res, err := conn.Client.Database("admin").
RunCommand(conn.ctx, bson.D{{"buildInfo", 1}}).DecodeBytes()
if err != nil {
return "", err
}
ver, ok := res.Lookup("version").StringValueOK()
if !ok {
return "", fmt.Errorf("buildInfo do not have version")
}
return ver, nil
}
// get current db version and compare to threshold. Return whether the result
// is bigger or equal to the input threshold.
func GetAndCompareVersion(conn *MongoCommunityConn, threshold string, compare string) (bool, error) {
var err error
if compare == "" {
if conn == nil {
return false, nil
}
compare, err = GetDBVersion(conn)
if err != nil {
return false, err
}
}
compareArr := strings.Split(compare, ".")
thresholdArr := strings.Split(threshold, ".")
if len(compareArr) < 2 || len(thresholdArr) < 2 {
return false, nil
}
for i := 0; i < 2; i++ {
compareEle, errC := strconv.Atoi(compareArr[i])
thresholdEle, errT := strconv.Atoi(thresholdArr[i])
if errC != nil || errT != nil {
return false, fmt.Errorf("errC:[%v], errT:[%v]", errC, errT)
}
if compareEle > thresholdEle {
return true, nil
} else if compareEle < thresholdEle {
return false, fmt.Errorf("compare[%v] < threshold[%v]", compare, threshold)
}
}
return true, nil
}
func ApplyOpsFilter(key string) bool {
// convert to map if has more later
k := strings.TrimSpace(key)
if k == "$db" {
// 40621, $db is not allowed in OP_QUERY requests
return true
} else if k == "ui" {
return true
}
return false
}
func getOplogTimestamp(conn *MongoCommunityConn, sortType int) (int64, error) {
var result bson.M
opts := options.FindOne().SetSort(bson.D{{"$natural", sortType}})
err := conn.Client.Database(localDB).Collection(OplogNS).FindOne(nil, bson.M{}, opts).Decode(&result)
if err != nil {
return 0, err
}
return TimeStampToInt64(result["ts"].(primitive.Timestamp)), nil
}
// get newest oplog
func GetNewestTimestampByConn(conn *MongoCommunityConn) (int64, error) {
return getOplogTimestamp(conn, -1)
}
// get oldest oplog
func GetOldestTimestampByConn(conn *MongoCommunityConn) (int64, error) {
return getOplogTimestamp(conn, 1)
}
func GetNewestTimestampByUrl(url string, fromMongoS bool, sslRootFile string) (int64, error) {
var conn *MongoCommunityConn
var err error
if conn, err = NewMongoCommunityConn(url, VarMongoConnectModeSecondaryPreferred, true,
ReadWriteConcernDefault, ReadWriteConcernDefault, sslRootFile); conn == nil || err != nil {
return 0, err
}
defer conn.Close()
if fromMongoS {
return TimeStampToInt64(conn.CurrentDate()), nil
}
return GetNewestTimestampByConn(conn)
}
func GetOldestTimestampByUrl(url string, fromMongoS bool, sslRootFile string) (int64, error) {
if fromMongoS {
return 0, nil
}
var conn *MongoCommunityConn
var err error
if conn, err = NewMongoCommunityConn(url, VarMongoConnectModeSecondaryPreferred, true,
ReadWriteConcernDefault, ReadWriteConcernDefault, sslRootFile); conn == nil || err != nil {
return 0, err
}
defer conn.Close()
return GetOldestTimestampByConn(conn)
}
// record the oldest and newest timestamp of each mongod
type TimestampNode struct {
Oldest int64
Newest int64
}
/*
* get all newest timestamp
* return:
* map: whole timestamp map, key: replset name, value: struct that includes the newest and oldest timestamp
* primitive.Timestamp: the biggest of the newest timestamp
* primitive.Timestamp: the smallest of the newest timestamp
* error: error
*/
func GetAllTimestamp(sources []*MongoSource, sslRootFile string) (map[string]TimestampNode, int64,
int64, int64, int64, error) {
smallestNew := int64(math.MaxInt64)
biggestNew := int64(0)
smallestOld := int64(math.MaxInt64)
biggestOld := int64(0)
tsMap := make(map[string]TimestampNode)
for _, src := range sources {
newest, err := GetNewestTimestampByUrl(src.URL, false, sslRootFile)
if err != nil {
return nil, 0, 0, 0, 0, err
} else if newest == 0 {
return nil, 0, 0, 0, 0, fmt.Errorf("illegal newest timestamp == 0")
}
oldest, err := GetOldestTimestampByUrl(src.URL, false, sslRootFile)
if err != nil {
return nil, 0, 0, 0, 0, err
}
tsMap[src.ReplicaName] = TimestampNode{
Oldest: oldest,
Newest: newest,
}
if newest > biggestNew {
biggestNew = newest
}
if newest < smallestNew {
smallestNew = newest
}
if oldest > biggestOld {
biggestOld = oldest
}
if oldest < smallestOld {
smallestOld = oldest
}
}
LOG.Info("GetAllTimestamp biggestNew:%v, smallestNew:%v, biggestOld:%v, smallestOld:%v,"+
" MongoSource:%v, tsMap:%v",
Int64ToTimestamp(biggestNew), Int64ToTimestamp(smallestNew),
Int64ToTimestamp(biggestOld), Int64ToTimestamp(smallestOld),
sources, tsMap)
return tsMap, biggestNew, smallestNew, biggestOld, smallestOld, nil
}
// only used in unit test
func GetAllTimestampInUT() (map[string]TimestampNode, int64,
int64, int64, int64, error) {
smallestNew := int64(math.MaxInt64)
biggestNew := int64(0)
smallestOld := int64(math.MaxInt64)
biggestOld := int64(0)
tsMap := make(map[string]TimestampNode)
for name, ele := range GetAllTimestampInUTInput {
oldest := ele.First.(int64)
newest := ele.Second.(int64)
tsMap[name] = TimestampNode{
Oldest: oldest,
Newest: newest,
}
if newest > biggestNew {
biggestNew = newest
}
if newest < smallestNew {
smallestNew = newest
}
if oldest > biggestOld {
biggestOld = oldest
}
if oldest < smallestOld {
smallestOld = oldest
}
}
return tsMap, biggestNew, smallestNew, biggestOld, smallestOld, nil
}
func IsCollectionCappedError(err error) bool {
errMsg := err.Error()
if strings.Contains(errMsg, CollectionCapped) || strings.Contains(errMsg, CollectionCappedLowVersion) {
return true
}
return false
}
func FindFirstErrorIndexAndMessageN(err error) (int, string, bool) {
if err == nil {
return 0, "", false
}
bwError, ok := err.(mongo.BulkWriteException)
if ok == false {
return 0, "", false
}
wError := bwError.WriteErrors
if len(wError) == 0 {
return 0, "", false
}
if wError[0].HasErrorCode(11000) {
return wError[0].Index, wError[0].Message, true
}
return wError[0].Index, wError[0].Message, false
}
func GetListCollectionQueryCondition(conn *MongoCommunityConn) bson.M {
// "collection", "timeseries", 3.4 start to support views
versionOk, _ := GetAndCompareVersion(conn, "3.4.0", "")
queryConditon := bson.M{}
if versionOk {
// 改成 not
queryConditon = bson.M{"type": bson.M{"$in": bson.A{"collection", "timeseries"}}}
}
return queryConditon
}
/**
* return db namespace. return:
* @[]NS: namespace list, e.g., []{"a.b", "a.c"}
* @map[string][]string: db->collection map. e.g., "a"->[]string{"b", "c"}
* @error: error info
*/
func GetDbNamespace(url string, filterFunc func(name string) bool, sslRootFile string) ([]NS, map[string][]string, error) {
var err error
var conn *MongoCommunityConn
if conn, err = NewMongoCommunityConn(url, VarMongoConnectModePrimary, true,
ReadWriteConcernLocal, ReadWriteConcernDefault, sslRootFile); conn == nil || err != nil {
return nil, nil, err
}
defer conn.Close()
queryConditon := GetListCollectionQueryCondition(conn)
var dbNames []string
if dbNames, err = conn.Client.ListDatabaseNames(nil, bson.M{}); err != nil {
err = fmt.Errorf("get database names of mongodb[%s] error: %v", url, err)
return nil, nil, err
}
// sort by db names
sort.Strings(dbNames)
LOG.Debug("dbNames:%v queryConditon:%v", dbNames, queryConditon)
nsList := make([]NS, 0, 128)
for _, db := range dbNames {
colNames, err := conn.Client.Database(db).ListCollectionNames(nil, queryConditon)
if err != nil {
err = fmt.Errorf("get collection names of mongodb[%s] db[%v] error: %v", url, db, err)
return nil, nil, err
}
LOG.Debug("db[%v] colNames: %v queryConditon:%v", db, colNames, queryConditon)
for _, col := range colNames {
ns := NS{Database: db, Collection: col}
if strings.HasPrefix(col, "system.") {
continue
}
if filterFunc != nil && filterFunc(ns.Str()) {
LOG.Debug("Namespace is filtered. %v", ns.Str())
continue
}
nsList = append(nsList, ns)
}
}
// copy, convert nsList to map
nsMap := make(map[string][]string, 0)
for _, ns := range nsList {
if _, ok := nsMap[ns.Database]; !ok {
nsMap[ns.Database] = make([]string, 0)
}
nsMap[ns.Database] = append(nsMap[ns.Database], ns.Collection)
}
return nsList, nsMap, nil
}
/**
* return all namespace. return:
* @map[NS]struct{}: namespace set where key is the namespace while value is useless, e.g., "a.b"->nil, "a.c"->nil
* @map[string][]string: db->collection map. e.g., "a"->[]string{"b", "c"}
* @error: error info
*/
func GetAllNamespace(sources []*MongoSource, filterFunc func(name string) bool,
sslRootFile string) (map[NS]struct{}, map[string][]string, error) {
nsSet := make(map[NS]struct{})
for _, src := range sources {
nsList, _, err := GetDbNamespace(src.URL, filterFunc, sslRootFile)
if err != nil {
return nil, nil, err
}
for _, ns := range nsList {
nsSet[ns] = struct{}{}
}
}
// copy
nsMap := make(map[string][]string, len(sources))
for ns := range nsSet {
if _, ok := nsMap[ns.Database]; !ok {
nsMap[ns.Database] = make([]string, 0)
}
nsMap[ns.Database] = append(nsMap[ns.Database], ns.Collection)
}
return nsSet, nsMap, nil
}