plugins/targetlocker/dblocker/dblocker.go (327 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. package dblocker import ( "database/sql" "fmt" "strings" "time" "github.com/benbjohnson/clock" // this blank import registers the mysql driver _ "github.com/go-sql-driver/mysql" "github.com/facebookincubator/contest/pkg/target" "github.com/facebookincubator/contest/pkg/types" "github.com/facebookincubator/contest/pkg/xcontext" ) // Name is the plugin name. var Name = "DBLocker" const DefaultMaxBatchSize = 100 // used for functions that can operate with and without transactions type db interface { Query(query string, args ...interface{}) (*sql.Rows, error) } // dblock represents parts of lock in the database, basically // a row from SELECT target_id, job_ID, expires_at type dblock struct { targetID string jobID int64 createdAt time.Time expiresAt time.Time } // String pretty-prints dblocks for logging and errors func (d dblock) String() string { return fmt.Sprintf( "target: %s job: %d created: %s expires: %s", d.targetID, d.jobID, d.createdAt, d.expiresAt, ) } // targetIDList is a helper to convert contest targets to // a list of primary IDs used in the database func targetIDList(targets []*target.Target) []string { res := make([]string, 0, len(targets)) for _, target := range targets { res = append(res, target.ID) } return res } // listQueryString is a helper to create a (?, ?, ?) string // with as many ? as requested. // This can safely be concatenated into SQL queries as it // can never contain input data, it only repeates a static // string a given number of times. func listQueryString(length uint) string { switch length { case 0: return "()" case 1: return "(?)" default: return "(" + strings.Repeat("?, ", int(length)-1) + "?)" } } // DBLocker implements a simple target locker based on a relational database. // The current implementation only supports MySQL officially. 
// All functions in DBLocker are safe for concurrent use by multiple goroutines. type DBLocker struct { driverName string db *sql.DB maxBatchSize int // clock is used for measuring time clock clock.Clock } // queryLocks returns a map of ID -> dblock for a given list of targets func (d *DBLocker) queryLocks(tx db, targets []string) (map[string]dblock, error) { q := "SELECT target_id, job_id, created_at, expires_at FROM locks WHERE target_id IN " + listQueryString(uint(len(targets))) // convert targets to a list of interface{} queryList := make([]interface{}, 0, len(targets)) for _, targetID := range targets { queryList = append(queryList, targetID) } rows, err := tx.Query(q, queryList...) if err != nil { return nil, fmt.Errorf("unable to read existing locks: %w", err) } defer rows.Close() row := dblock{} locks := make(map[string]dblock) for rows.Next() { if err := rows.Scan(&row.targetID, &row.jobID, &row.createdAt, &row.expiresAt); err != nil { return nil, fmt.Errorf("unexpected read from database: %w", err) } locks[row.targetID] = row } if err := rows.Err(); err != nil { return nil, fmt.Errorf("unexpected error iterating db read results: %w", err) } return locks, nil } // handleLock does the real locking, it assumes the jobID is valid func (d *DBLocker) handleLock(ctx xcontext.Context, jobID int64, targets []string, limit uint, timeout time.Duration, requireLocked bool, allowConflicts bool) ([]string, error) { if len(targets) == 0 { return nil, nil } // everything operates on this frozen time now := d.clock.Now() expiresAt := now.Add(timeout) tx, err := d.db.Begin() if err != nil { return nil, fmt.Errorf("unable to start database transaction: %w", err) } defer func() { // this always fails if tx.Commit() was called before, ignore error _ = tx.Rollback() }() locks, err := d.queryLocks(tx, targets) if err != nil { return nil, err } // go through existing locks, they are either held by something else and valid // (abort or skip, depending on allowConflicts setting), 
// not valid anymore (update), held by something else and not valid (expired), // held by us or not held at all (insert) var toInsert, missing []string var toDelete, conflicts []dblock for _, targetID := range targets { lock, ok := locks[targetID] switch { case !ok: // nonexistent lock if requireLocked { missing = append(missing, targetID) } toInsert = append(toInsert, targetID) case lock.jobID == jobID: // our lock, possibly expired toDelete = append(toDelete, lock) toInsert = append(toInsert, targetID) case lock.expiresAt.Before(now): // other job's expired lock if !requireLocked { toDelete = append(toDelete, lock) toInsert = append(toInsert, targetID) } else { conflicts = append(conflicts, lock) } default: conflicts = append(conflicts, lock) } if uint(len(toInsert)) >= limit { break } } if (len(conflicts) > 0 && !allowConflicts) || len(missing) > 0 { return nil, fmt.Errorf("unable to lock targets %v for owner %d, have %d conflicting locks (%v), %d missing locks (%v)", targets, jobID, len(conflicts), conflicts, len(missing), missing) } // First, drop all the locks that we intend to extend or take over. // Use strict matching so that if another instance races ahead of us, row will not be deleted and subsequent insert will fail. { var stmt []string var args []interface{} for i, lock := range toDelete { if len(stmt) == 0 { stmt = append(stmt, "DELETE FROM locks WHERE (target_id = ? AND job_id = ? AND expires_at = ?)") } else { stmt = append(stmt, " OR (target_id = ? AND job_id = ? AND expires_at = ?)") } args = append(args, lock.targetID, lock.jobID, lock.expiresAt) if len(stmt) < d.maxBatchSize && i < len(toDelete)-1 { continue } if _, err := tx.Exec(strings.Join(stmt, ""), args...); err != nil { return nil, fmt.Errorf("insert statement failed: %w", err) } stmt = nil args = nil } } // Now insert new entries for all the targets we are locking. 
{ var stmt []string var args []interface{} for i, targetID := range toInsert { createdAt := now // If we are updating our own lock, carry over the creation timestamp. if lock, ok := locks[targetID]; ok && lock.jobID == jobID { createdAt = lock.createdAt } if len(stmt) == 0 { // this can race with other transactions, acceptable for TryLock if allowConflicts { stmt = append(stmt, "INSERT IGNORE INTO locks (target_id, job_id, created_at, expires_at, valid) VALUES (?, ?, ?, ?, ?)") } else { stmt = append(stmt, "INSERT INTO locks (target_id, job_id, created_at, expires_at, valid) VALUES (?, ?, ?, ?, ?)") } } else { stmt = append(stmt, ", (?, ?, ?, ?, ?)") } args = append(args, targetID, jobID, createdAt, expiresAt, true) if len(stmt) < d.maxBatchSize && i < len(toInsert)-1 { continue } if _, err := tx.Exec(strings.Join(stmt, ""), args...); err != nil { return nil, fmt.Errorf("insert statement failed: %w", err) } stmt = nil args = nil } } // Main transaction done txErr := tx.Commit() // Done except for TryLock if txErr != nil || !allowConflicts || len(toInsert) == 0 { return toInsert, txErr } // TryLock uses INSERT IGNORE, read inserted rows back to see which ones made it var actualInserts []string { actualLocks, err := d.queryLocks(d.db, toInsert) if err != nil { return nil, err } for _, targetID := range toInsert { lock, ok := actualLocks[targetID] // only care about locks that we own now if ok && lock.jobID == jobID { actualInserts = append(actualInserts, targetID) } } } return actualInserts, nil } // handleUnlock does the real unlocking, it assumes the jobID is valid func (d *DBLocker) handleUnlock(ctx xcontext.Context, jobID int64, targets []string) error { if len(targets) == 0 { return nil } tx, err := d.db.Begin() if err != nil { return fmt.Errorf("unable to start database transaction: %w", err) } defer func() { // this always fails if tx.Commit() was called before, ignore error _ = tx.Rollback() }() // check lock states: must own locks for all targets (expired is 
ok too). locks, err := d.queryLocks(tx, targets) if err != nil { return err } for _, t := range targets { l, ok := locks[t] if !ok { return fmt.Errorf("target %q is not locked", t) } if l.jobID != jobID { return fmt.Errorf("unlock request: target %q is locked by %q, not by %q", t, l.jobID, jobID) } } // drop non-conflicting locks del := "DELETE FROM locks WHERE job_id = ? AND target_id IN " + listQueryString(uint(len(targets))) queryList := make([]interface{}, 0, len(targets)+1) queryList = append(queryList, jobID) for _, targetID := range targets { queryList = append(queryList, targetID) } if _, err := tx.Exec(del, queryList...); err != nil { return fmt.Errorf("unable to unlock targets %v, owner %d: %w", targets, jobID, err) } return tx.Commit() } func validateTargets(targets []*target.Target) error { for _, target := range targets { if target.ID == "" { return fmt.Errorf("target list cannot contain empty target ID. Full list: %v", targets) } } return nil } // Lock locks the given targets. // See target.Locker for API details func (d *DBLocker) Lock(ctx xcontext.Context, jobID types.JobID, duration time.Duration, targets []*target.Target) error { if jobID == 0 { return fmt.Errorf("invalid lock request, jobID cannot be zero (targets: %v)", targets) } if err := validateTargets(targets); err != nil { return fmt.Errorf("invalid lock request: %w", err) } _, err := d.handleLock(ctx, int64(jobID), targetIDList(targets), uint(len(targets)), duration, false /* requireLocked */, false /* allowConflicts */) ctx.Debugf("Lock %d targets for %s: %v", len(targets), duration, err) return err } // TryLock attempts to locks the given targets. 
// See target.Locker for API details func (d *DBLocker) TryLock(ctx xcontext.Context, jobID types.JobID, duration time.Duration, targets []*target.Target, limit uint) ([]string, error) { if jobID == 0 { return nil, fmt.Errorf("invalid tryLock request, jobID cannot be zero (targets: %v)", targets) } if err := validateTargets(targets); err != nil { return nil, fmt.Errorf("invalid tryLock request: %w", err) } if limit == 0 { return nil, nil } res, err := d.handleLock(ctx, int64(jobID), targetIDList(targets), limit, duration, false /* requireLocked */, true /* allowConflicts */) ctx.Debugf("TryLock %d targets for %s: %d %v", len(targets), duration, len(res), err) return res, err } // Unlock unlocks the given targets. // See target.Locker for API details func (d *DBLocker) Unlock(ctx xcontext.Context, jobID types.JobID, targets []*target.Target) error { if jobID == 0 { return fmt.Errorf("invalid unlock request, jobID cannot be zero (targets: %v)", targets) } if err := validateTargets(targets); err != nil { return fmt.Errorf("invalid unlock request: %w", err) } err := d.handleUnlock(ctx, int64(jobID), targetIDList(targets)) ctx.Debugf("Unlock %d targets: %v", len(targets), err) return err } // RefreshLocks refreshes the locks on the given targets. // See target.Locker for API details func (d *DBLocker) RefreshLocks(ctx xcontext.Context, jobID types.JobID, duration time.Duration, targets []*target.Target) error { if jobID == 0 { return fmt.Errorf("invalid refresh request, jobID cannot be zero (targets: %v)", targets) } if err := validateTargets(targets); err != nil { return fmt.Errorf("invalid refresh request: %w", err) } _, err := d.handleLock(ctx, int64(jobID), targetIDList(targets), uint(len(targets)), duration, true /* requireLocked */, false /* allowConflicts */) ctx.Debugf("RefreshLocks on %d targets for %s: %v", len(targets), duration, err) return err } // Close closes the DB connection and releases resources. 
func (d *DBLocker) Close() error {
	return d.db.Close()
}

// ResetAllLocks resets the database and clears all locks, regardless of who owns them.
// This is primarily for testing and should not be used in prod, which
// is why it is not exposed by target.Locker.
func (d *DBLocker) ResetAllLocks(ctx xcontext.Context) error {
	ctx.Warnf("DELETING ALL LOCKS")
	_, err := d.db.Exec("TRUNCATE TABLE locks")
	return err
}

// Opt is a function type that sets parameters on the DBLocker object
type Opt func(dblocker *DBLocker)

// WithDriverName option allows using a mysql-compatible driver (e.g. a wrapper around mysql
// or a syntax-compatible variant).
func WithDriverName(name string) Opt {
	return func(d *DBLocker) {
		d.driverName = name
	}
}

// WithMaxBatchSize option sets maximum batch size for statements.
func WithMaxBatchSize(value int) Opt {
	return func(d *DBLocker) {
		d.maxBatchSize = value
	}
}

// WithClock option sets clock used for timestamps.
func WithClock(value clock.Clock) Opt {
	return func(d *DBLocker) {
		d.clock = value
	}
}

// New initializes and returns a new DBLocker target locker.
//
// dbURI is passed to sql.Open as the data source name. Options are applied in
// order on top of the defaults (DefaultMaxBatchSize, a real clock, and the
// "mysql" driver). The database is pinged before returning, so an error is
// reported if it cannot be reached.
func New(dbURI string, opts ...Opt) (*DBLocker, error) {
	res := &DBLocker{
		maxBatchSize: DefaultMaxBatchSize,
		clock:        clock.New(),
	}

	// lowercase loop variable so the Opt type is not shadowed
	for _, opt := range opts {
		opt(res)
	}

	driverName := "mysql"
	if res.driverName != "" {
		driverName = res.driverName
	}
	db, err := sql.Open(driverName, dbURI)
	if err != nil {
		return nil, fmt.Errorf("could not initialize database: %w", err)
	}
	if err := db.Ping(); err != nil {
		// The handle was opened successfully and is not returned to the
		// caller, so release it here. The ping failure is the error worth
		// reporting; a close error would only obscure it.
		_ = db.Close()
		return nil, fmt.Errorf("unable to contact database: %w", err)
	}
	res.db = db

	return res, nil
}