sources/cassandra/validations/count/count.go:

package main

import (
	"context"
	"flag"
	"fmt"
	"math/big"
	"os"
	"sort"
	"strings"
	"sync"
	"time"

	"cloud.google.com/go/spanner"
	"github.com/gocql/gocql"
	"google.golang.org/api/iterator"
)

// This script connects to both Cassandra and Spanner databases to count the number of rows in specified tables.
// It allows for parallel processing of the counting operation using multiple workers.
//
// Sample usage:
// go run count.go -host localhost -port 9042 -keyspace my_keyspace -spanner-uri projects/my_project/instances/my_instance/databases/my_database
func main() {
	// Define command line arguments
	host := flag.String("host", "localhost", "Cassandra host")
	port := flag.Int("port", 9042, "Cassandra port")
	keyspace := flag.String("keyspace", "", "Keyspace name")
	username := flag.String("username", "", "Cassandra username")
	password := flag.String("password", "", "Cassandra password")
	table := flag.String("table", "", "Table name (empty for all tables, or comma-separated list of tables)")
	workers := flag.Int("workers", 8, "Number of parallel workers")
	spannerURI := flag.String("spanner-uri", "", "Spanner database URI (projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID)")
	flag.Parse()

	if *keyspace == "" {
		fmt.Println("keyspace name is required")
		os.Exit(1)
	}
	if *spannerURI == "" {
		fmt.Println("spanner-uri is required")
		os.Exit(1)
	}

	// Connect to Cassandra
	cassSession, err := connectToCassandra(*host, *port, *keyspace, *username, *password)
	if err != nil {
		fmt.Printf("Failed to connect to Cassandra: %v\n", err)
		os.Exit(1)
	}
	defer cassSession.Close()

	// Connect to Spanner
	ctx := context.Background()
	spannerClient, err := connectToSpanner(ctx, *spannerURI)
	if err != nil {
		fmt.Printf("Failed to connect to Spanner: %v\n", err)
		os.Exit(1)
	}
	defer spannerClient.Close()

	tables := []string{}
	if *table != "" {
		tables = strings.Split(*table, ",")
		for i, t := range tables {
			tables[i] = strings.TrimSpace(t)
		}
	} else {
		cassandraTables, err := getCassandraTables(cassSession, *keyspace)
		if err != nil {
			fmt.Printf("Error fetching Cassandra tables: %v\n", err)
			os.Exit(1)
		}
		fmt.Printf("Found %d tables in Cassandra: %v\n", len(cassandraTables), cassandraTables)

		spannerTables, err := getSpannerTables(ctx, spannerClient)
		if err != nil {
			fmt.Printf("Error fetching Spanner tables: %v\n", err)
			os.Exit(1)
		}
		fmt.Printf("Found %d tables in Spanner: %v\n", len(spannerTables), spannerTables)

		spannerTableSet := make(map[string]struct{}, len(spannerTables))
		for _, t := range spannerTables {
			spannerTableSet[t] = struct{}{}
		}
		for _, cassTable := range cassandraTables {
			if _, found := spannerTableSet[cassTable]; found {
				tables = append(tables, cassTable)
			}
		}
		if len(tables) == 0 {
			fmt.Println("No common tables found between Cassandra and Spanner.")
			os.Exit(0)
		}
	}
	fmt.Printf("Getting counts for tables: %v\n", tables)

	// Get token ranges for the cluster
	// TODO: Consider generating custom token ranges based on size estimates instead of relying on cassandra partitions.
	tokenRanges, err := getClusterTokenRanges(cassSession)
	if err != nil {
		fmt.Printf("Error fetching token ranges: %v\n", err)
		os.Exit(1)
	}
	fmt.Printf("Found %d token ranges across the cluster\n", len(tokenRanges))

	// TODO: Consider parallelizing across tables.
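
	// Example output for each table, matching the Printf calls in the loop
	// below (the table name and counts shown here are illustrative):
	//
	//	Table: users
	//	----------------------------------------
	//	 Cassandra count: 1048576
	//	 Spanner count: 1048576
	//	----------------------------------------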
	for _, tableName := range tables {
		fmt.Printf("\nTable: %s\n", tableName)
		fmt.Printf("----------------------------------------\n")
		result := countBothDatabases(ctx, cassSession, spannerClient, *keyspace, tableName, tokenRanges, *workers)

		// Print Cassandra results
		if result.CassandraError != nil {
			fmt.Printf(" Cassandra count: ERROR - %v\n", result.CassandraError)
		} else {
			fmt.Printf(" Cassandra count: %d\n", result.CassandraCount)
		}

		// Print Spanner results
		if result.SpannerError != nil {
			fmt.Printf(" Spanner count: ERROR - %v\n", result.SpannerError)
		} else {
			fmt.Printf(" Spanner count: %d\n", result.SpannerCount)
		}
		fmt.Printf("----------------------------------------\n")
	}
}

// getCassandraTables fetches all user table names from the specified keyspace in Cassandra.
func getCassandraTables(session *gocql.Session, keyspace string) ([]string, error) {
	var tables []string
	iter := session.Query(`SELECT table_name FROM system_schema.tables WHERE keyspace_name = ?`, keyspace).Iter()
	var tableName string
	for iter.Scan(&tableName) {
		tables = append(tables, tableName)
	}
	if err := iter.Close(); err != nil {
		return nil, fmt.Errorf("querying system_schema.tables failed: %w", err)
	}
	return tables, nil
}

// getSpannerTables fetches all user table names from the connected Spanner database.
func getSpannerTables(ctx context.Context, client *spanner.Client) ([]string, error) {
	var tables []string
	stmt := spanner.Statement{SQL: `SELECT table_name FROM INFORMATION_SCHEMA.TABLES WHERE table_catalog = '' AND table_schema = ''`}
	iter := client.Single().Query(ctx, stmt)
	defer iter.Stop()
	for {
		row, err := iter.Next()
		if err == iterator.Done {
			break
		}
		if err != nil {
			return nil, fmt.Errorf("querying INFORMATION_SCHEMA.TABLES failed: %w", err)
		}
		var tableName string
		if err := row.Columns(&tableName); err != nil {
			return nil, fmt.Errorf("reading table name from Spanner result failed: %w", err)
		}
		tables = append(tables, tableName)
	}
	return tables, nil
}

// TokenRange represents a Cassandra token range
type TokenRange struct {
	Start string
	End   string
}

// TableCount holds row counts for both Cassandra and Spanner tables
type TableCount struct {
	TableName      string
	CassandraCount int64
	SpannerCount   int64
	CassandraError error
	SpannerError   error
}

// connectToCassandra establishes a connection to the Cassandra cluster
func connectToCassandra(host string, port int, keyspace, username, password string) (*gocql.Session, error) {
	cluster := gocql.NewCluster(host)
	cluster.Port = port
	cluster.Keyspace = keyspace
	cluster.Consistency = gocql.Quorum
	cluster.Timeout = 30 * time.Second

	// Add authentication if credentials are provided
	if username != "" {
		cluster.Authenticator = gocql.PasswordAuthenticator{
			Username: username,
			Password: password,
		}
	}
	return cluster.CreateSession()
}

// connectToSpanner establishes a connection to the Spanner database
func connectToSpanner(ctx context.Context, uri string) (*spanner.Client, error) {
	return spanner.NewClient(ctx, uri)
}

// countSpannerRows counts the total number of rows in a Spanner table
func countSpannerRows(ctx context.Context, client *spanner.Client, table string) (int64, error) {
	stmt := spanner.Statement{
		SQL: fmt.Sprintf("SELECT COUNT(*) FROM %s", table),
	}
	iter := client.Single().Query(ctx, stmt)
	defer iter.Stop()

	row, err := iter.Next()
	if err == iterator.Done {
		return 0, fmt.Errorf("no results returned from count query")
	}
	if err != nil {
		return 0, fmt.Errorf("error executing count query: %w", err)
	}
	var count int64
	if err := row.Columns(&count); err != nil {
		return 0, fmt.Errorf("error scanning count result: %w", err)
	}
	return count, nil
}
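
// A note on the ring math below: with Cassandra's default Murmur3Partitioner
// (tokens span [-2^63, 2^63-1]), N distinct tokens t1 < t2 < ... < tN yield
// N+1 ranges: (nil, t1), [t1, t2), ..., [t(N-1), tN), [tN, nil), where "nil"
// marks an unbounded edge of the ring. Queried together with >= and < bounds
// (see countInRange), these ranges cover the whole ring exactly once.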

// getClusterTokenRanges retrieves and processes token ranges from all nodes in the Cassandra cluster.
// It combines tokens from both local and peer nodes, deduplicates them, and creates a sorted list of
// token ranges that cover the entire token ring.
// TODO: Consider creating configurable number of partitions instead of system ranges.
func getClusterTokenRanges(session *gocql.Session) ([]TokenRange, error) {
	localTokens, err := getLocalNodeTokens(session)
	if err != nil {
		return nil, fmt.Errorf("failed to get local node tokens: %w", err)
	}
	peerTokens, err := getPeerNodeTokens(session)
	if err != nil {
		return nil, fmt.Errorf("failed to get peer node tokens: %w", err)
	}

	// 1. Combine and Deduplicate Tokens
	allTokens := append(localTokens, peerTokens...)
	uniqueTokens := distinct(allTokens)
	if len(uniqueTokens) == 0 {
		return nil, fmt.Errorf("no tokens found in cluster metadata")
	}

	// 2. Map to Big Ints and Sort
	bigIntTokens := make([]*big.Int, len(uniqueTokens))
	for i, tokenStr := range uniqueTokens {
		bigIntToken := new(big.Int)
		if _, ok := bigIntToken.SetString(tokenStr, 10); !ok {
			return nil, fmt.Errorf("invalid token string: %s", tokenStr)
		}
		bigIntTokens[i] = bigIntToken
	}
	sort.Slice(bigIntTokens, func(i, j int) bool {
		return bigIntTokens[i].Cmp(bigIntTokens[j]) < 0
	})

	// 3. Create Token Ranges
	tokenRanges := make([]TokenRange, len(bigIntTokens)+1)
	// First range
	tokenRanges[0] = TokenRange{Start: "nil", End: bigIntTokens[0].String()}
	// Middle ranges
	for i := 1; i < len(bigIntTokens); i++ {
		tokenRanges[i].Start = bigIntTokens[i-1].String()
		tokenRanges[i].End = bigIntTokens[i].String()
	}
	// Last range
	lastIndex := len(tokenRanges) - 1
	tokenRanges[lastIndex] = TokenRange{Start: bigIntTokens[lastIndex-1].String(), End: "nil"}
	return tokenRanges, nil
}

// distinct takes a slice of strings and returns a new slice containing only unique elements,
// removing any duplicates from the input slice.
func distinct(tokens []string) []string {
	seen := make(map[string]bool)
	unique := []string{}
	for _, token := range tokens {
		if _, ok := seen[token]; !ok {
			seen[token] = true
			unique = append(unique, token)
		}
	}
	return unique
}

// getLocalNodeTokens retrieves the tokens assigned to the local Cassandra node.
// It queries the system.local table to get the tokens for the current node.
func getLocalNodeTokens(session *gocql.Session) ([]string, error) {
	// system.local stores tokens as a set<text>; gocql scans it directly into a []string,
	// so no trimming or splitting is needed.
	var tokens []string
	err := session.Query("SELECT tokens FROM system.local").Scan(&tokens)
	if err != nil {
		return nil, err
	}
	return tokens, nil
}

// getPeerNodeTokens retrieves the tokens assigned to all peer nodes in the Cassandra cluster.
// It queries the system.peers_v2 table (present in modern Cassandra) to get tokens for all
// other nodes in the cluster.
func getPeerNodeTokens(session *gocql.Session) ([]string, error) {
	var allTokens []string
	iter := session.Query("SELECT tokens FROM system.peers_v2").Iter()
	var tokens []string // each row's set<text> scans into a slice
	for iter.Scan(&tokens) {
		allTokens = append(allTokens, tokens...)
	}
	if err := iter.Close(); err != nil {
		return nil, err
	}
	return allTokens, nil
}
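
// getPeerNodeTokensLegacy is a hypothetical fallback, sketched here for
// clusters older than Cassandra 4.0 where system.peers_v2 does not exist and
// peer tokens live in system.peers instead. It is not wired into
// getClusterTokenRanges above; a caller would need to fall back to it when
// the peers_v2 query fails.
func getPeerNodeTokensLegacy(session *gocql.Session) ([]string, error) {
	var allTokens []string
	iter := session.Query("SELECT tokens FROM system.peers").Iter()
	var tokens []string // system.peers also stores tokens as a set<text>
	for iter.Scan(&tokens) {
		allTokens = append(allTokens, tokens...)
	}
	if err := iter.Close(); err != nil {
		return nil, err
	}
	return allTokens, nil
}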

// getPartitionKeysFromMetadata retrieves the partition key columns for a given table from Cassandra's metadata.
// It returns the partition key column names in the correct order as defined in the table schema.
func getPartitionKeysFromMetadata(session *gocql.Session, keyspaceName, tableName string) ([]string, error) {
	// Get keyspace metadata which contains table information
	keyspaceMetadata, err := session.KeyspaceMetadata(keyspaceName)
	if err != nil {
		return nil, fmt.Errorf("failed to get keyspace metadata: %w", err)
	}

	// Get table metadata
	table, ok := keyspaceMetadata.Tables[tableName]
	if !ok {
		return nil, fmt.Errorf("table %s not found in keyspace %s", tableName, keyspaceName)
	}

	// Get partition key columns in correct order
	var partitionKeyColumns []string
	for _, column := range table.PartitionKey {
		partitionKeyColumns = append(partitionKeyColumns, column.Name)
	}
	if len(partitionKeyColumns) == 0 {
		return nil, fmt.Errorf("no partition keys found for table %s", tableName)
	}
	return partitionKeyColumns, nil
}

// countBothDatabases counts rows in both Cassandra and Spanner concurrently.
func countBothDatabases(ctx context.Context, cassSession *gocql.Session, spannerClient *spanner.Client, keyspace, tableName string, tokenRanges []TokenRange, workers int) TableCount {
	result := TableCount{TableName: tableName}
	var wg sync.WaitGroup
	wg.Add(2)

	// Count Cassandra rows in parallel
	go func() {
		defer wg.Done()
		count, err := countCassandraRows(cassSession, keyspace, tableName, tokenRanges, workers)
		result.CassandraCount = count
		result.CassandraError = err
	}()

	// Count Spanner rows in parallel
	go func() {
		defer wg.Done()
		count, err := countSpannerRows(ctx, spannerClient, tableName)
		result.SpannerCount = count
		result.SpannerError = err
	}()
	wg.Wait()
	return result
}

// countCassandraRows counts the total number of rows in a Cassandra table by getting partition keys
// and using token ranges for parallel counting.
func countCassandraRows(session *gocql.Session, keyspace, tableName string, tokenRanges []TokenRange, workers int) (int64, error) {
	// Get table metadata including partition keys
	partitionKeys, err := getPartitionKeysFromMetadata(session, keyspace, tableName)
	if err != nil {
		return 0, fmt.Errorf("error getting partition keys: %w", err)
	}

	// Count rows using token ranges
	return countTableRows(session, keyspace, tableName, partitionKeys, tokenRanges, workers)
}
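
// As a rough sketch of the call flow through the helpers below (the keyspace
// and table names here are placeholders):
//
//	ranges, _ := getClusterTokenRanges(session)
//	count, err := countCassandraRows(session, "my_keyspace", "users", ranges, 8)
//
// Each of the 8 workers repeatedly pulls a token range off a channel and runs
// one COUNT(*) restricted to that range, so no single query has to scan the
// whole table.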

// countTableRows counts the total number of rows in a table by dividing the work across multiple token ranges
// and processing them in parallel using a worker pool. It aggregates the results from all workers and handles
// any errors that occur during the counting process.
func countTableRows(session *gocql.Session, keyspace, table string, partitionKeys []string, tokenRanges []TokenRange, workers int) (int64, error) {
	// Create channels for work distribution and results
	workChan := make(chan TokenRange, len(tokenRanges))
	resultChan := make(chan int64, len(tokenRanges))
	errChan := make(chan error, workers)

	// Start worker pool
	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for tr := range workChan {
				count, err := countInRange(session, keyspace, table, partitionKeys, tr)
				if err != nil {
					errChan <- err
					return
				}
				resultChan <- count
			}
		}()
	}

	// Send token ranges to worker pool
	for _, tr := range tokenRanges {
		workChan <- tr
	}
	close(workChan)

	// Wait for all workers to finish and close result channel
	go func() {
		wg.Wait()
		close(resultChan)
		close(errChan)
	}()

	// Collect results and check for errors
	var totalCount int64
	for count := range resultChan {
		totalCount += count
	}

	// Check for errors
	for err := range errChan {
		if err != nil {
			return 0, err // Return the first error encountered
		}
	}
	return totalCount, nil
}

// countInRange counts the number of rows within a specific token range for a given table.
// It constructs a query using the TOKEN function to filter rows based on the partition key's token value
// falling within the specified range.
func countInRange(session *gocql.Session, keyspace, table string, partitionKeys []string, tr TokenRange) (int64, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
	defer cancel()

	// Build the partition key list for the TOKEN function; for a composite
	// partition key this yields e.g. "pk1, pk2".
	partitionKeyStr := strings.Join(partitionKeys, ", ")

	// Build query with dynamic token range filter
	queryBase := fmt.Sprintf("SELECT COUNT(*) FROM %s.%s", keyspace, table)
	whereClauses := []string{}
	var args []interface{}
	if tr.Start != "nil" {
		whereClauses = append(whereClauses, fmt.Sprintf("TOKEN(%s) >= ?", partitionKeyStr))
		args = append(args, tr.Start)
	}
	if tr.End != "nil" {
		whereClauses = append(whereClauses, fmt.Sprintf("TOKEN(%s) < ?", partitionKeyStr))
		args = append(args, tr.End)
	}
	query := queryBase
	if len(whereClauses) > 0 {
		query += " WHERE " + strings.Join(whereClauses, " AND ")
	}

	var count int64
	err := session.Query(query, args...).WithContext(ctx).Scan(&count)
	if err != nil {
		return 0, fmt.Errorf("range query failed: %w", err)
	}
	return count, nil
}
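
// For reference, a middle token range makes countInRange issue CQL of this
// shape (the keyspace, table, and partition key column here are illustrative):
//
//	SELECT COUNT(*) FROM my_keyspace.users WHERE TOKEN(id) >= ? AND TOKEN(id) < ?
//
// The first and last ranges drop the lower or upper bound respectively, so
// across all ranges every row's token is counted exactly once.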