func main()

in sources/cassandra/validations/count/count.go [24:134]


func main() {
	// All work lives in runCountValidation, which returns an exit status.
	// os.Exit does not run deferred functions, so calling it only here, after
	// runCountValidation has returned, guarantees the deferred Close calls on
	// the Cassandra session and Spanner client have already executed.
	os.Exit(runCountValidation())
}

// runCountValidation parses command-line flags, connects to Cassandra and
// Spanner, determines which tables to check, and prints each table's count
// as reported by countBothDatabases for both databases.
//
// Tables come from -table (a comma-separated list) when it is set. Otherwise
// they are the tables present in both the Cassandra keyspace and the Spanner
// database, in the order returned by getCassandraTables.
//
// It returns 0 on success, including when no common tables exist. It returns
// 1 in three cases: invalid flags, a connection or discovery failure, or any
// per-table count reporting an error.
func runCountValidation() int {
	// Define command line arguments
	host := flag.String("host", "localhost", "Cassandra host")
	port := flag.Int("port", 9042, "Cassandra port")
	keyspace := flag.String("keyspace", "", "Keyspace name")
	username := flag.String("username", "", "Cassandra username")
	password := flag.String("password", "", "Cassandra password")
	table := flag.String("table", "", "Table name (empty for all tables, or comma-separated list of tables)")
	workers := flag.Int("workers", 8, "Number of parallel workers")
	spannerURI := flag.String("spanner-uri", "", "Spanner database URI (projects/PROJECT_ID/instances/INSTANCE_ID/databases/DATABASE_ID)")
	flag.Parse()

	if *keyspace == "" {
		fmt.Fprintln(os.Stderr, "keyspace name is required")
		return 1
	}
	if *spannerURI == "" {
		fmt.Fprintln(os.Stderr, "spanner-uri is required")
		return 1
	}
	if *workers < 1 {
		fmt.Fprintln(os.Stderr, "workers must be at least 1")
		return 1
	}

	// Parse an explicit -table list before opening any connection, so input
	// such as "," or " " fails fast instead of issuing queries for "".
	var tables []string
	if *table != "" {
		tables = parseTableList(*table)
		if len(tables) == 0 {
			fmt.Fprintf(os.Stderr, "no table names found in -table %q\n", *table)
			return 1
		}
	}

	// Connect to Cassandra
	cassSession, err := connectToCassandra(*host, *port, *keyspace, *username, *password)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to connect to Cassandra: %v\n", err)
		return 1
	}
	defer cassSession.Close()

	// Connect to Spanner
	ctx := context.Background()
	spannerClient, err := connectToSpanner(ctx, *spannerURI)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to connect to Spanner: %v\n", err)
		return 1
	}
	defer spannerClient.Close()

	if *table == "" {
		cassandraTables, err := getCassandraTables(cassSession, *keyspace)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error fetching Cassandra tables: %v\n", err)
			return 1
		}
		fmt.Printf("Found %d tables in Cassandra: %v\n", len(cassandraTables), cassandraTables)

		spannerTables, err := getSpannerTables(ctx, spannerClient)
		if err != nil {
			fmt.Fprintf(os.Stderr, "Error fetching Spanner tables: %v\n", err)
			return 1
		}
		fmt.Printf("Found %d tables in Spanner: %v\n", len(spannerTables), spannerTables)

		tables = intersectTables(cassandraTables, spannerTables)
		if len(tables) == 0 {
			fmt.Println("No common tables found between Cassandra and Spanner.")
			return 0
		}
	}

	fmt.Printf("Getting counts for tables: %v\n", tables)

	// Get token ranges for the cluster
	// TODO: Consider generating custom token ranges based on size estimates instead of relying on cassandra partitions.
	tokenRanges, err := getClusterTokenRanges(cassSession)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Error fetching token ranges: %v\n", err)
		return 1
	}
	fmt.Printf("Found %d token ranges across the cluster\n", len(tokenRanges))

	// Per-table errors are reported inline with the results rather than
	// aborting. They are remembered so the process still exits non-zero.
	failed := false

	// TODO: Consider parallelizing across tables.
	for _, tableName := range tables {
		fmt.Printf("\nTable: %s\n", tableName)
		fmt.Printf("----------------------------------------\n")

		result := countBothDatabases(ctx, cassSession, spannerClient, *keyspace, tableName, tokenRanges, *workers)

		// Print Cassandra results
		if result.CassandraError != nil {
			failed = true
			fmt.Printf("  Cassandra count: ERROR - %v\n", result.CassandraError)
		} else {
			fmt.Printf("  Cassandra count: %d\n", result.CassandraCount)
		}

		// Print Spanner results
		if result.SpannerError != nil {
			failed = true
			fmt.Printf("  Spanner count:   ERROR - %v\n", result.SpannerError)
		} else {
			fmt.Printf("  Spanner count:   %d\n", result.SpannerCount)
		}
		fmt.Printf("----------------------------------------\n")
	}

	if failed {
		return 1
	}
	return 0
}

// parseTableList splits a comma-separated -table value into trimmed,
// non-empty table names, in input order. For example, "a, b,,c " yields
// [a b c]. It returns nil if no non-empty names remain.
func parseTableList(s string) []string {
	var names []string
	for _, name := range strings.Split(s, ",") {
		if name = strings.TrimSpace(name); name != "" {
			names = append(names, name)
		}
	}
	return names
}

// intersectTables returns the entries of cassandraTables that also appear in
// spannerTables, preserving cassandraTables' order. Names are compared
// exactly (case-sensitive).
func intersectTables(cassandraTables, spannerTables []string) []string {
	inSpanner := make(map[string]struct{}, len(spannerTables))
	for _, t := range spannerTables {
		inSpanner[t] = struct{}{}
	}

	var common []string
	for _, t := range cassandraTables {
		if _, ok := inSpanner[t]; ok {
			common = append(common, t)
		}
	}
	return common
}