registry/datastore/testutil/testutil.go

package testutil

import (
	"context"
	"fmt"
	"io"
	"os"
	"path/filepath"
	"strconv"
	"strings"
	"testing"
	"time"

	"github.com/docker/distribution/configuration"
	"github.com/docker/distribution/registry/datastore"
	"github.com/redis/go-redis/v9"
	"github.com/sirupsen/logrus"
	"github.com/stretchr/testify/require"
)

// table represents a table in the test database.
type table string

// trigger represents a trigger in the test database.
type trigger struct {
	name  string
	table table
}

const (
	NamespacesTable              table = "top_level_namespaces"
	RepositoriesTable            table = "repositories"
	MediaTypesTable              table = "media_types"
	ManifestsTable               table = "manifests"
	ManifestReferencesTable      table = "manifest_references"
	BlobsTable                   table = "blobs"
	RepositoryBlobsTable         table = "repository_blobs"
	LayersTable                  table = "layers"
	TagsTable                    table = "tags"
	GCBlobReviewQueueTable       table = "gc_blob_review_queue"
	GCBlobsConfigurationsTable   table = "gc_blobs_configurations"
	GCBlobsLayersTable           table = "gc_blobs_layers"
	GCManifestReviewQueueTable   table = "gc_manifest_review_queue"
	GCTmpBlobsManifestsTable     table = "gc_tmp_blobs_manifests"
	GCReviewAfterDefaultsTable   table = "gc_review_after_defaults"
	BackgroundMigrationTable     table = "batched_background_migrations"
	BackgroundMigrationJobsTable table = "batched_background_migration_jobs"
)

// AllTables represents all tables in the test database.
var (
	AllTables = []table{
		NamespacesTable,
		RepositoriesTable,
		ManifestsTable,
		ManifestReferencesTable,
		BlobsTable,
		RepositoryBlobsTable,
		LayersTable,
		TagsTable,
		GCBlobReviewQueueTable,
		GCBlobsConfigurationsTable,
		GCBlobsLayersTable,
		GCManifestReviewQueueTable,
		GCTmpBlobsManifestsTable,
	}

	GCTrackBlobUploadsTrigger = trigger{
		name:  "gc_track_blob_uploads_trigger",
		table: BlobsTable,
	}
	GCTrackConfigurationBlobsTrigger = trigger{
		name:  "gc_track_configuration_blobs_trigger",
		table: ManifestsTable,
	}
	GCTrackLayerBlobsTrigger = trigger{
		name:  "gc_track_layer_blobs_trigger",
		table: LayersTable,
	}
	GCTrackManifestUploadsTrigger = trigger{
		name:  "gc_track_manifest_uploads_trigger",
		table: ManifestsTable,
	}
	GCTrackDeletedManifestsTrigger = trigger{
		name:  "gc_track_deleted_manifests_trigger",
		table: ManifestsTable,
	}
	GCTrackDeletedLayersTrigger = trigger{
		name:  "gc_track_deleted_layers_trigger",
		table: LayersTable,
	}
	GCTrackDeletedManifestListsTrigger = trigger{
		name:  "gc_track_deleted_manifest_lists_trigger",
		table: ManifestReferencesTable,
	}
	GCTrackSwitchedTagsTrigger = trigger{
		name:  "gc_track_switched_tags_trigger",
		table: TagsTable,
	}
	GCTrackDeletedTagsTrigger = trigger{
		name:  "gc_track_deleted_tags_trigger",
		table: TagsTable,
	}
)

// truncate truncates t in the test database.
func (t table) truncate(db *datastore.DB) error {
	if _, err := db.Exec(fmt.Sprintf("TRUNCATE %s RESTART IDENTITY CASCADE", t)); err != nil {
		return fmt.Errorf("truncating table %q: %w", t, err)
	}
	return nil
}

// seedFileName generates the expected seed filename based on the convention `<table name>.sql`.
func (t table) seedFileName() string {
	return fmt.Sprintf("%s.sql", t)
}
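
// For illustration, the fixture naming convention resolves as follows (a
// sketch; the `testdata/fixtures` prefix is the one used by ReloadFixtures
// further down in this file):
//
//	RepositoriesTable.seedFileName()  // "repositories.sql"
//	GCBlobsLayersTable.seedFileName() // "gc_blobs_layers.sql"
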
// DumpAsJSON dumps the table contents in JSON format using the PostgreSQL `json_agg` function. `bytea` columns are
// automatically decoded for easy visualization/comparison. The output from each table is sorted for consistency.
// When sorting, we use the `(top_level_namespace_id, id)` columns, the `digest` column or the `id` column, in this
// order of preference, to keep the output deterministic.
func (t table) DumpAsJSON(ctx context.Context, db datastore.Queryer) ([]byte, error) {
	var tmpl string
	switch t {
	case ManifestsTable:
		tmpl = `SELECT json_agg(t) FROM (
			SELECT
				id,
				top_level_namespace_id,
				repository_id,
				created_at,
				total_size,
				schema_version,
				encode(digest, 'hex') AS digest,
				convert_from(payload, 'UTF8')::json AS payload,
				media_type_id,
				configuration_media_type_id,
				convert_from(configuration_payload, 'UTF8')::json AS configuration_payload,
				encode(configuration_blob_digest, 'hex') AS configuration_blob_digest,
				non_conformant,
				non_distributable_layers
			FROM %s
			ORDER BY (top_level_namespace_id, id)) t`
	case NamespacesTable:
		tmpl = "SELECT json_agg(t) FROM (SELECT * FROM %s ORDER BY id) t"
	case RepositoriesTable, ManifestReferencesTable, RepositoryBlobsTable, LayersTable, TagsTable, GCBlobsConfigurationsTable:
		tmpl = "SELECT json_agg(t) FROM (SELECT * FROM %s ORDER BY (top_level_namespace_id, id)) t"
	case GCManifestReviewQueueTable:
		tmpl = "SELECT json_agg(t) FROM (SELECT * FROM %s ORDER BY (top_level_namespace_id, repository_id)) t"
	case GCBlobsLayersTable:
		tmpl = "SELECT json_agg(t) FROM (SELECT * FROM %s ORDER BY id) t"
	default:
		tmpl = "SELECT json_agg(t) FROM (SELECT * FROM %s ORDER BY digest) t"
	}

	var dump []byte
	row := db.QueryRowContext(ctx, fmt.Sprintf(tmpl, t))
	if err := row.Scan(&dump); err != nil {
		return nil, err
	}

	return dump, nil
}

// Disable disables a trigger in the test database. It returns a function that can be deferred to re-enable the
// trigger.
func (t trigger) Disable(db *datastore.DB) (func() error, error) {
	_, err := db.Exec(fmt.Sprintf("ALTER TABLE %s DISABLE TRIGGER %s", t.table, t.name))
	return func() error {
		_, err := db.Exec(fmt.Sprintf("ALTER TABLE %s ENABLE TRIGGER %s", t.table, t.name))
		return err
	}, err
}

// NewDSNFromEnv generates a new DSN for the test database based on environment variable configurations.
func NewDSNFromEnv() (*datastore.DSN, error) {
	port, err := strconv.Atoi(os.Getenv("REGISTRY_DATABASE_PORT"))
	if err != nil {
		return nil, fmt.Errorf("parsing DSN port: %w", err)
	}

	dsn := &datastore.DSN{
		Host:        os.Getenv("REGISTRY_DATABASE_HOST"),
		Port:        port,
		User:        os.Getenv("REGISTRY_DATABASE_USER"),
		Password:    os.Getenv("REGISTRY_DATABASE_PASSWORD"),
		DBName:      "registry_test",
		SSLMode:     os.Getenv("REGISTRY_DATABASE_SSLMODE"),
		SSLCert:     os.Getenv("REGISTRY_DATABASE_SSLCERT"),
		SSLKey:      os.Getenv("REGISTRY_DATABASE_SSLKEY"),
		SSLRootCert: os.Getenv("REGISTRY_DATABASE_SSLROOTCERT"),
	}

	return dsn, nil
}

// NewDSNFromConfig generates a new DSN for the test database based on configuration options.
func NewDSNFromConfig(config configuration.Database) (*datastore.DSN, error) {
	dsn := &datastore.DSN{
		Host:        config.Host,
		Port:        config.Port,
		User:        config.User,
		Password:    config.Password,
		DBName:      "registry_test",
		SSLMode:     config.SSLMode,
		SSLCert:     config.SSLCert,
		SSLKey:      config.SSLKey,
		SSLRootCert: config.SSLRootCert,
	}

	return dsn, nil
}
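
// Usage sketch for trigger management (hypothetical test body; `db` and `tb`
// stand for a *datastore.DB and a testing.TB provided by the caller):
//
//	enable, err := GCTrackBlobUploadsTrigger.Disable(db)
//	require.NoError(tb, err)
//	defer func() { require.NoError(tb, enable()) }()
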
func newDB(dsn *datastore.DSN, logLevel logrus.Level, logOut io.Writer, opts []datastore.Option) (datastore.LoadBalancer, error) {
	log := logrus.New()
	log.SetLevel(logLevel)
	log.SetOutput(logOut)

	// The registry application defaults to using the simple protocol for connecting to PostgreSQL
	// instead of prepared statements. There are notable differences in how some queries are built
	// and resolved depending on the chosen execution mode. For details, see
	// https://github.com/jackc/pgx/issues/2157. To ensure consistency and avoid false positives in
	// tests, we configure the test database to also use the simple protocol.
	opts = append(opts, datastore.WithLogger(logrus.NewEntry(log)), datastore.WithPreparedStatements(false))

	db, err := datastore.NewDBLoadBalancer(context.Background(), dsn, opts...)
	if err != nil {
		return nil, fmt.Errorf("opening database connection: %w", err)
	}

	return db, nil
}

// NewDBFromEnv generates a new datastore.DB and opens the underlying connection based on environment variable
// settings.
func NewDBFromEnv() (*datastore.DB, error) {
	dsn, err := NewDSNFromEnv()
	if err != nil {
		return nil, err
	}

	logLevel, err := logrus.ParseLevel(os.Getenv("REGISTRY_LOG_LEVEL"))
	if err != nil {
		logLevel = logrus.InfoLevel
	}

	var logOut io.Writer
	switch os.Getenv("REGISTRY_LOG_OUTPUT") {
	case "stdout":
		logOut = os.Stdout
	case "stderr":
		logOut = os.Stderr
	case "discard":
		logOut = io.Discard
	default:
		logOut = os.Stdout
	}

	var dbOpts []datastore.Option
	poolConfig := datastore.PoolConfig{}
	tmp := os.Getenv("REGISTRY_DATABASE_POOL_MAXOPEN")
	if tmp != "" {
		poolMaxOpen, err := strconv.Atoi(tmp)
		if err != nil {
			return nil, fmt.Errorf("invalid REGISTRY_DATABASE_POOL_MAXOPEN: %w", err)
		}
		poolConfig.MaxOpen = poolMaxOpen
		dbOpts = append(dbOpts, datastore.WithPoolConfig(&poolConfig))
	}

	dlb, err := newDB(dsn, logLevel, logOut, dbOpts)
	if err != nil {
		return nil, err
	}

	return dlb.Primary(), nil
}

// NewDBFromConfig generates a new datastore.LoadBalancer and opens the underlying connections based on configuration
// settings.
func NewDBFromConfig(config *configuration.Configuration) (datastore.LoadBalancer, error) {
	dsn, err := NewDSNFromConfig(config.Database)
	if err != nil {
		return nil, err
	}

	logLevel, err := logrus.ParseLevel(config.Log.Level.String())
	if err != nil {
		logLevel = logrus.InfoLevel
	}

	var logOut io.Writer
	switch config.Log.Output {
	case configuration.LogOutputStdout:
		logOut = configuration.LogOutputStdout.Descriptor()
	case configuration.LogOutputStderr:
		logOut = configuration.LogOutputStderr.Descriptor()
	case configuration.LogOutputDiscard:
		logOut = io.Discard
	default:
		logOut = configuration.LogOutputStdout.Descriptor()
	}

	var dbOpts []datastore.Option
	poolConfig := datastore.PoolConfig{MaxOpen: config.Database.Pool.MaxOpen}
	dbOpts = append(dbOpts, datastore.WithPoolConfig(&poolConfig))

	if config.Database.LoadBalancing.Enabled {
		// Service discovery takes precedence over fixed hosts.
		if config.Database.LoadBalancing.Record != "" {
			nameserver := config.Database.LoadBalancing.Nameserver
			port := config.Database.LoadBalancing.Port
			record := config.Database.LoadBalancing.Record

			resolver := datastore.NewDNSResolver(nameserver, port, record)
			dbOpts = append(dbOpts, datastore.WithServiceDiscovery(resolver))
		} else if len(config.Database.LoadBalancing.Hosts) > 0 {
			hosts := config.Database.LoadBalancing.Hosts
			dbOpts = append(dbOpts, datastore.WithFixedHosts(hosts))
		}
	}

	return newDB(dsn, logLevel, logOut, dbOpts)
}

// TruncateTables truncates a set of tables in the test database.
func TruncateTables(db *datastore.DB, tables ...table) error {
	for _, table := range tables {
		if err := table.truncate(db); err != nil {
			return fmt.Errorf("truncating tables: %w", err)
		}
	}
	return nil
}

// TruncateAllTables truncates all tables in the test database.
func TruncateAllTables(db *datastore.DB) error {
	return TruncateTables(db, AllTables...)
}
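
// A minimal setup sketch, assuming the REGISTRY_DATABASE_* environment
// variables point at a reachable PostgreSQL server that hosts the
// `registry_test` database (`tb` is a testing.TB provided by the caller):
//
//	db, err := NewDBFromEnv()
//	if err != nil {
//		tb.Fatal(err)
//	}
//	if err := TruncateAllTables(db); err != nil {
//		tb.Fatal(err)
//	}
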
// ReloadFixtures truncates a given set of tables and then injects the related fixtures.
func ReloadFixtures(tb testing.TB, db *datastore.DB, basePath string, tables ...table) {
	tb.Helper()

	require.NoError(tb, TruncateTables(db, tables...))

	for _, table := range tables {
		path := filepath.Join(basePath, "testdata", "fixtures", table.seedFileName())

		// nolint: gosec // this is just a testutil
		query, err := os.ReadFile(path)
		require.NoErrorf(tb, err, "error reading fixture")

		_, err = db.Exec(string(query))
		require.NoErrorf(tb, err, "error loading fixture")
	}
}

// ParseTimestamp parses a timestamp into a time.Time, matching a given location.
func ParseTimestamp(tb testing.TB, timestamp string, location *time.Location) time.Time {
	tb.Helper()

	t, err := time.Parse("2006-01-02 15:04:05.000000", timestamp)
	require.NoError(tb, err)

	return t.In(location)
}

func createGoldenFile(tb testing.TB, path string) {
	tb.Helper()

	if _, err := os.Stat(path); os.IsNotExist(err) {
		tb.Log("creating .golden file")
		// nolint: gosec // this is just a testutil
		f, err := os.Create(path)
		require.NoError(tb, err, "error creating .golden file")
		require.NoError(tb, f.Close())
	}
}

func updateGoldenFile(tb testing.TB, path string, content []byte) {
	tb.Helper()

	tb.Log("updating .golden file")
	// nolint: gosec // this is just a testutil
	err := os.WriteFile(path, content, 0o644)
	require.NoError(tb, err, "error updating .golden file")
}

func readGoldenFile(tb testing.TB, path string) []byte {
	tb.Helper()

	// nolint: gosec // this is just a testutil
	content, err := os.ReadFile(path)
	require.NoError(tb, err, "error reading .golden file")

	return content
}

// CompareWithGoldenFile compares an actual value with the content of a .golden file. If requested, a missing golden
// file is automatically created and an outdated golden file is automatically updated to match the actual content.
func CompareWithGoldenFile(tb testing.TB, path string, actual []byte, create, update bool) {
	tb.Helper()

	if create {
		createGoldenFile(tb, path)
	}
	if update {
		updateGoldenFile(tb, path, actual)
	}

	expected := readGoldenFile(tb, path)
	require.Equal(tb, string(expected), string(actual), "does not match .golden file")
}

type RedisClient struct {
	redis redis.UniversalClient
}

// FlushCache removes all cached data from the cache.
func (r *RedisClient) FlushCache() error {
	if err := r.redis.FlushAll(context.Background()).Err(); err != nil {
		return fmt.Errorf("flushing redis cache: %w", err)
	}
	return nil
}

// NewRedisClientFromConfig generates a new Redis cache client based on configuration settings.
func NewRedisClientFromConfig(config *configuration.Configuration) (*RedisClient, error) {
	opts := &redis.UniversalOptions{
		Addrs:    strings.Split(config.Redis.Cache.Addr, ","),
		DB:       config.Redis.Cache.DB,
		Password: config.Redis.Cache.Password,
	}
	client := redis.NewUniversalClient(opts)

	// Ensure the client is correctly configured and the server is reachable. We use a new local context here with a
	// tight timeout to avoid blocking the application start for too long.
	pingCtx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()
	if cmd := client.Ping(pingCtx); cmd.Err() != nil {
		return nil, cmd.Err()
	}

	return &RedisClient{client}, nil
}
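
// Golden-file workflow sketch (hypothetical test; the `create` and `update`
// booleans would typically come from test flags defined by the caller, they
// are not flags of this package, and the golden path below is illustrative):
//
//	dump, err := ManifestsTable.DumpAsJSON(ctx, db)
//	require.NoError(tb, err)
//	CompareWithGoldenFile(tb, "testdata/golden/manifests.golden", dump, create, update)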