sources/cassandra/validations/validation/validation.go (365 lines of code) (raw):
package main
import (
"flag"
"fmt"
"log"
"math/rand"
"os"
"strings"
"sync"
"time"
"github.com/gocql/gocql"
spanner "github.com/googleapis/go-spanner-cassandra/cassandra/gocql"
)
var (
sourceHost = flag.String("source-host", "localhost", "Source Cassandra host")
sourcePort = flag.Int("source-port", 9042, "Source Cassandra port")
sourceUsername = flag.String("source-username", "", "Source Cassandra username")
sourcePassword = flag.String("source-password", "", "Source Cassandra password")
spannerURI = flag.String("spanner-uri", "", "Spanner database URI (projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID)")
keyspace = flag.String("keyspace", "", "Keyspace to validate")
table = flag.String("table", "", "Table to validate (mandatory)")
batchSize = flag.Int("batch-size", 100000, "Specifies how many rows to read and validate at a time. For sampling mode, it reads the specified number of rows during each sample.")
workers = flag.Int("workers", 1, "Number of parallel workers for validation")
samplingMode = flag.Bool("sampling-mode", false, "Validate a sample of rows instead of full matching")
numSampleRows = flag.Int("num-sample-rows", 0, "Number of rows to sample for validation (0 for indefinite sampling)")
totalRowsProcessed int
totalMismatchesFound int
totalMissingFound int
totalErrorsDuringMatching int
mu sync.Mutex
)
/*
This tool validates data consistency between source and target Cassandra clusters by comparing rows.
It supports two modes of operation:
1. Full validation mode (default): Validates all rows in the specified table
2. Sampling mode: Validates a random sample of rows, recommended for large datasets
Sample usage:
go run validation.go \
--source-host localhost --source-port 9042 --source-username user1 --source-password pass1 \
--spanner-uri projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID \
--keyspace my_keyspace --table my_table \
--batch-size 100000 --workers 4
For sampling mode (recommended for large datasets):
go run validation.go \
--source-host localhost --target-host remote-host \
--keyspace my_keyspace --table my_table \
--sampling-mode --num-sample-rows 1000000
*/
func main() {
flag.Parse()
if *table == "" {
log.Fatal("The --table flag is mandatory.")
}
if *keyspace == "" {
log.Fatal("The --keyspace flag is mandatory.")
}
if !*samplingMode && *numSampleRows != 0 {
log.Fatal("The --sample-rows flag can only be specified when --sampling-mode is true.")
}
sourceCluster := gocql.NewCluster(*sourceHost)
sourceCluster.Port = *sourcePort
sourceCluster.Keyspace = *keyspace
if *sourceUsername != "" {
sourceCluster.Authenticator = gocql.PasswordAuthenticator{Username: *sourceUsername, Password: *sourcePassword}
}
sourceSession, err := sourceCluster.CreateSession()
if err != nil {
log.Fatalf("Error creating source session: %v", err)
}
defer sourceSession.Close()
opts := &spanner.Options{
DatabaseUri: *spannerURI,
}
targetCluster := spanner.NewCluster(opts)
// Important to close the adapter's resources
defer spanner.CloseCluster(targetCluster)
targetSession, err := targetCluster.CreateSession()
if err != nil {
log.Fatalf("Error creating target session: %v", err)
}
defer targetSession.Close()
// TODO: Verify table exists on target once system query is fixed.
err = verifyTableExists(sourceSession, *keyspace, *table)
if err != nil {
log.Fatalf("Source table '%s.%s' does not exist: %v", *keyspace, *table, err)
os.Exit(1)
}
pkColumns, partitionKeyCount, err := getPrimaryKeyColumns(sourceSession)
if err != nil {
log.Fatalf("Error getting primary key columns: %v", err)
}
log.Printf("Primary Key Columns: %v", pkColumns)
allColumns, err := getAllColumns(sourceSession)
if err != nil {
log.Fatalf("Error getting all columns: %v", err)
}
log.Printf("All Columns: %v", allColumns)
allColumnsStr := strings.Join(allColumns, ", ")
targetQuery := fmt.Sprintf("SELECT %s FROM %s.%s WHERE ", allColumnsStr, *keyspace, *table)
var whereClauses []string
for _, pkColumn := range pkColumns {
whereClauses = append(whereClauses, fmt.Sprintf("%s = ?", pkColumn))
}
targetQuery += strings.Join(whereClauses, " AND ")
log.Println("Target Query string: ", targetQuery)
if !*samplingMode {
validateEntireDataset(sourceSession, targetSession, targetQuery, pkColumns)
} else {
validateViaSampling(sourceSession, targetSession, targetQuery, pkColumns, partitionKeyCount)
}
log.Println("Validation complete.")
log.Printf("Total rows processed: %d\n", totalRowsProcessed)
log.Printf("Total errors during matching: %d\n", totalErrorsDuringMatching)
log.Printf("Total mismatches found: %d\n", totalMismatchesFound)
log.Printf("Total missing found: %d\n", totalMissingFound)
}
// MismatchDetail holds detail about a single mismatched row
type MismatchDetail struct {
Key map[string]interface{}
SourceRow map[string]interface{}
TargetRow map[string]interface{}
MissingInTarget bool
}
func verifyTableExists(session *gocql.Session, keyspace string, table string) error {
query := "SELECT table_name FROM system_schema.tables WHERE keyspace_name = ? AND table_name = ?"
var tableName string
if err := session.Query(query, keyspace, table).Scan(&tableName); err != nil {
if err == gocql.ErrNotFound {
return fmt.Errorf("table %s does not exist in keyspace %s", table, keyspace)
}
return fmt.Errorf("error checking if table exists: %v", err)
}
return nil
}
// getPrimaryKeyColumns retrieves both partition key and clustering columns for a table.
// It queries system tables to get column information and returns them in the correct order.
// Returns:
// - []string: ordered list of primary key columns (partition keys followed by clustering keys)
// - int: number of partition key columns
// - error: any error encountered during the process
func getPrimaryKeyColumns(session *gocql.Session) ([]string, int, error) {
var pkColumns []string
// Query to get partition key columns
partitionKeyQuery := `
SELECT column_name, position
FROM system_schema.columns
WHERE keyspace_name = ? AND table_name = ? AND kind = 'partition_key'
ALLOW FILTERING
`
iter := session.Query(partitionKeyQuery, *keyspace, *table).Iter()
partitionKeys := make(map[int]string)
var columnName string
var position int
// Scan with position for later sorting
for iter.Scan(&columnName, &position) {
partitionKeys[position] = columnName
}
if err := iter.Close(); err != nil {
return nil, 0, err
}
// Order partition keys by position
for i := 0; i < len(partitionKeys); i++ {
if key, ok := partitionKeys[i]; ok {
pkColumns = append(pkColumns, key)
}
}
// Query to get clustering columns
clusteringKeyQuery := `
SELECT column_name, position
FROM system_schema.columns
WHERE keyspace_name = ? AND table_name = ? AND kind = 'clustering'
ALLOW FILTERING
`
iter = session.Query(clusteringKeyQuery, *keyspace, *table).Iter()
clusteringKeys := make(map[int]string)
// Scan with position for later sorting
for iter.Scan(&columnName, &position) {
clusteringKeys[position] = columnName
}
if err := iter.Close(); err != nil {
return nil, 0, err
}
// Order clustering keys by position
for i := 0; i < len(clusteringKeys); i++ {
if key, ok := clusteringKeys[i]; ok {
pkColumns = append(pkColumns, key)
}
}
if len(pkColumns) == 0 {
return nil, 0, fmt.Errorf("no primary key columns found, please verify keyspace name '%s' and table name '%s' are correct", *keyspace, *table)
}
return pkColumns, len(partitionKeys), nil
}
// getAllColumns retrieves all column names for the specified table.
// This includes both primary key and non-primary key columns.
// Returns:
// - []string: list of all column names in the table
// - error: any error encountered during the query
func getAllColumns(session *gocql.Session) ([]string, error) {
var columns []string
query := `
SELECT column_name
FROM system_schema.columns
WHERE keyspace_name = ? AND table_name = ?
`
iter := session.Query(query, *keyspace, *table).Iter()
var columnName string
for iter.Scan(&columnName) {
columns = append(columns, columnName)
}
if err := iter.Close(); err != nil {
return nil, err
}
return columns, nil
}
// validateViaSampling performs data validation using random sampling of rows.
// It uses token-based sampling to randomly select rows from the source cluster
// and compares them with corresponding rows in the target cluster.
// This mode is recommended for large datasets where full validation is impractical.
func validateViaSampling(sourceSession *gocql.Session, targetSession *gocql.Session, targetQuery string, pkColumns []string, partitionKeyCount int) {
// Build partition key columns string for token function, which does not expect the entire pk but only the partition key.
partitionKeyStr := strings.Join(pkColumns[:partitionKeyCount], ", ")
// Build sampling query with token
samplingQuery := fmt.Sprintf("SELECT * FROM %s.%s WHERE token(%s) > ? LIMIT %d",
*keyspace, *table, partitionKeyStr, *batchSize)
log.Println("Sampling Query string: ", samplingQuery)
for {
// Generate random token value using time seed.
r := rand.New(rand.NewSource(time.Now().UnixNano()))
randToken := r.Int63()
// Get batch of random rows
iter := sourceSession.Query(samplingQuery, randToken).Iter()
row := make(map[string]interface{})
rows := make([]map[string]interface{}, 0)
for iter.MapScan(row) {
rows = append(rows, row)
row = make(map[string]interface{})
}
if err := iter.Close(); err != nil {
log.Fatalf("Error iterating through rows: %v", err)
}
if len(rows) == 0 {
fmt.Println("No rows found...")
continue
}
processBatch(rows, targetQuery, targetSession, pkColumns)
// If numSampleRows is set and we've processed enough rows, stop
if *numSampleRows != 0 && *&totalRowsProcessed >= *numSampleRows {
break
}
}
}
// validateEntireDataset performs a full validation of all rows in the source table
// against the target table. It reads rows in batches to manage memory usage
// and processes them in parallel using multiple workers.
func validateEntireDataset(sourceSession *gocql.Session, targetSession *gocql.Session, targetQuery string, pkColumns []string) {
sourceQuery := fmt.Sprintf("SELECT * FROM %s.%s", *keyspace, *table)
iter := sourceSession.Query(sourceQuery).Iter()
row := make(map[string]interface{})
rows := make([]map[string]interface{}, 0)
for iter.MapScan(row) {
rows = append(rows, row)
row = make(map[string]interface{})
if len(rows) == *batchSize {
processBatch(rows, targetQuery, targetSession, pkColumns)
rows = make([]map[string]interface{}, 0)
}
}
if len(rows) > 0 {
processBatch(rows, targetQuery, targetSession, pkColumns)
}
if err := iter.Close(); err != nil {
log.Fatalf("Error iterating through rows: %v", err)
}
}
// processBatch handles the validation of a batch of rows and updates global statistics.
// It coordinates the parallel validation of rows and aggregates the results.
// Parameters:
// - rows: batch of rows to validate
// - targetQuery: prepared query string for fetching rows from target
// - targetSession: connection to target cluster
// - pkColumns: list of primary key columns used for row lookup
func processBatch(rows []map[string]interface{}, targetQuery string, targetSession *gocql.Session, pkColumns []string) {
errors, mismatches, missing := validateRows(rows, targetQuery, targetSession, pkColumns)
totalRowsProcessed += len(rows)
totalMismatchesFound += mismatches
totalMissingFound += missing
totalErrorsDuringMatching += errors
log.Printf("Processed: %d more rows, found %d errors, %d missing, %d mismatches\n",
len(rows), errors, missing, mismatches)
log.Printf("Total rows processed: %d, Total errors: %d, Total missing: %d, Total mismatches: %d\n",
totalRowsProcessed, totalErrorsDuringMatching, totalMissingFound, totalMismatchesFound)
}
// validateRows performs parallel validation of source rows against target cluster.
// It distributes the work across multiple goroutines for better performance.
// Returns counts of errors, mismatches, and missing rows encountered during validation.
// Parameters:
// - sourceRows: rows from source cluster to validate
// - targetQuery: prepared query string for fetching rows from target
// - targetSession: connection to target cluster
// - pkColumns: list of primary key columns used for row lookup
func validateRows(sourceRows []map[string]interface{}, targetQuery string, targetSession *gocql.Session, pkColumns []string) (int, int, int) {
errors := 0
mismatches := 0
missing := 0
var wg sync.WaitGroup
rowChan := make(chan map[string]interface{}, len(sourceRows))
for _, row := range sourceRows {
rowChan <- row
}
close(rowChan)
for i := 0; i < *workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for sourceRow := range rowChan {
var queryArgs []interface{}
for _, col := range pkColumns {
queryArgs = append(queryArgs, sourceRow[col])
}
targetRow := make(map[string]interface{})
// TODO: Consider batched reads from Spanner instead of point reads per row.
if err := targetSession.Query(targetQuery, queryArgs...).MapScan(targetRow); err != nil {
if err == gocql.ErrNotFound {
mu.Lock()
missing++
fmt.Printf("MISSING: row in target for source row: %+v\n", sourceRow)
mu.Unlock()
} else {
fmt.Printf("ERROR: got error %v while reading target row for source: %+v\n", err, sourceRow)
mu.Lock()
errors++
mu.Unlock()
}
continue
}
var mismatchDetails MismatchDetail
diffFound, mismatchDetails := diffRows(sourceRow, targetRow)
if !diffFound {
continue
}
if mismatchDetails.MissingInTarget {
mu.Lock()
missing++
fmt.Printf("MISSING: row in target for source row: %+v\n", sourceRow)
mu.Unlock()
} else {
mu.Lock()
mismatches++
fmt.Printf("MISMATCH: found for row: %+v\n", mismatchDetails)
mu.Unlock()
}
}
}()
}
wg.Wait()
return errors, mismatches, missing
}
// diffRows compares a row from the source cluster with its corresponding row in the target cluster.
// It checks for both missing rows and value mismatches in all columns.
// Returns:
// - bool: true if any difference is found
// - MismatchDetail: details of the mismatch if found
func diffRows(sourceRow, targetRow map[string]interface{}) (bool, MismatchDetail) {
// If target row is nil, it's missing in target
if targetRow == nil || len(targetRow) == 0 {
return true, MismatchDetail{
SourceRow: sourceRow,
TargetRow: nil,
MissingInTarget: true,
}
}
// Compare all column values
mismatch := false
for col, sourceVal := range sourceRow {
targetVal, exists := targetRow[col]
if !exists || !compareValues(sourceVal, targetVal) {
mismatch = true
break
}
}
if mismatch {
return true, MismatchDetail{
SourceRow: sourceRow,
TargetRow: targetRow,
MissingInTarget: false,
}
}
return false, MismatchDetail{}
}
// compareValues performs deep comparison of two values that may be of different types.
// It handles special cases for various data types including:
// - nil values
// - byte arrays
// - maps
// - slices
// For other types, it falls back to string representation comparison.
// Returns true if the values are equal, false otherwise.
func compareValues(v1, v2 interface{}) bool {
// Handle nil values
if v1 == nil && v2 == nil {
return true
}
if v1 == nil || v2 == nil {
return false
}
// Compare different types of values
switch val1 := v1.(type) {
case []byte:
if val2, ok := v2.([]byte); ok {
if len(val1) != len(val2) {
return false
}
for i := range val1 {
if val1[i] != val2[i] {
return false
}
}
return true
}
case map[string]interface{}:
if val2, ok := v2.(map[string]interface{}); ok {
if len(val1) != len(val2) {
return false
}
for k, v := range val1 {
if !compareValues(v, val2[k]) {
return false
}
}
return true
}
case []interface{}:
if val2, ok := v2.([]interface{}); ok {
if len(val1) != len(val2) {
return false
}
for i := range val1 {
if !compareValues(val1[i], val2[i]) {
return false
}
}
return true
}
}
// For other types, use simple equality
return fmt.Sprintf("%v", v1) == fmt.Sprintf("%v", v2)
}