tools/migration/rdbms/migrationlib/version.go (50 lines of code) (raw):
// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.
package migrationlib
import (
	"database/sql"
	"errors"
	"fmt"

	"github.com/go-sql-driver/mysql"
	"github.com/pressly/goose"
)
// queryExecutor is the minimal common denominator of the *sql.DB and *sql.Tx
// method sets that is needed to read the schema version. Both types provide
// this Query method, so either one can be passed to DBVersion.
type queryExecutor interface {
	Query(q string, params ...interface{}) (*sql.Rows, error)
}
// DBVersion returns the current version of the database schema.
//
// goose exposes a GetDBVersion function, but it requires a *sql.DB. Several
// contest flows run inside a transaction instead, so DBVersion accepts any
// queryExecutor (either *sql.DB or *sql.Tx) and performs the lookup itself.
//
// Records in the goose versioning table are read in descending id order.
// For each version_id only the first record seen is considered:
//   - if its is_applied flag is set, that version is returned;
//   - otherwise the version was rolled back (down migration), and any older
//     records for the same version_id are skipped.
//
// It returns 0 and a nil error when the versioning table does not exist
// (MySQL error 1146) or when no applied version is found, i.e. the database
// is treated as unversioned. Query, scan and row-iteration errors are
// returned wrapped.
func DBVersion(db queryExecutor) (uint64, error) {
	// MySQL server error 1146 (ER_NO_SUCH_TABLE): the queried table does not exist.
	const mysqlErrNoSuchTable = 1146

	rows, err := db.Query(fmt.Sprintf("select version_id, is_applied from %s order by id desc", goose.TableName()))
	if err != nil {
		// errors.As also matches a MySQL error that an intermediate layer has
		// wrapped, which a plain type assertion would miss.
		var mysqlErr *mysql.MySQLError
		if errors.As(err, &mysqlErr) && mysqlErr.Number == mysqlErrNoSuchTable {
			// db versioning table does not exist, assume version 0
			return 0, nil
		}
		return 0, fmt.Errorf("could not retrieve db version, error %+v, %T: %w", err, err, err)
	}
	defer rows.Close()

	// Versions whose most recent record is a down migration. A set keeps the
	// membership check O(1) per row instead of a linear scan.
	rolledBack := make(map[uint64]struct{})
	for rows.Next() {
		var (
			versionID uint64
			isApplied bool
		)
		if err := rows.Scan(&versionID, &isApplied); err != nil {
			return 0, fmt.Errorf("could not scan row: %w", err)
		}
		if _, skip := rolledBack[versionID]; skip {
			continue
		}
		if isApplied {
			return versionID, nil
		}
		// The latest record for this version is a down migration: ignore
		// every older record for it.
		rolledBack[versionID] = struct{}{}
	}
	if err := rows.Err(); err != nil {
		return 0, fmt.Errorf("could not retrieve db version: %w", err)
	}
	// if we couldn't figure out the db version, assume we are working on version 0
	// i.e. the db is not versioned yet.
	return 0, nil
}