image/resources/knfsd-fsidd/sql.go (160 lines of code) (raw):
/*
Copyright 2022 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"context"
_ "embed"
"errors"
"fmt"
"net"
"strings"
"text/template"
"time"
"cloud.google.com/go/cloudsqlconn"
"github.com/GoogleCloudPlatform/knfsd-cache-utils/image/resources/knfsd-fsidd/internal/metrics"
"github.com/GoogleCloudPlatform/knfsd-cache-utils/image/resources/knfsd-fsidd/log"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
)
//go:embed schema.sql
var tableSchema string
type DB interface {
BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error
Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error)
QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row
Close()
}
type DBWrapper struct {
dialer *cloudsqlconn.Dialer
db DB
}
func (w *DBWrapper) Close() {
if w.db != nil {
w.db.Close()
}
if w.dialer != nil {
w.dialer.Close()
}
}
func (w *DBWrapper) BeginTxFunc(ctx context.Context, txOptions pgx.TxOptions, f func(pgx.Tx) error) error {
return w.db.BeginTxFunc(ctx, txOptions, f)
}
func (w *DBWrapper) Exec(ctx context.Context, sql string, arguments ...interface{}) (pgconn.CommandTag, error) {
return w.db.Exec(ctx, sql, arguments...)
}
func (w *DBWrapper) QueryRow(ctx context.Context, sql string, args ...interface{}) pgx.Row {
return w.db.QueryRow(ctx, sql, args...)
}
func connect(ctx context.Context, config DatabaseConfig) (DB, error) {
pgConfig, err := pgxpool.ParseConfig(config.URL)
if err != nil {
return nil, err
}
dialer, err := newDialer(ctx, config)
if err != nil {
return nil, err
}
pgConfig.ConnConfig.DialFunc = func(ctx context.Context, network, addr string) (net.Conn, error) {
// ignore the host (addr) requested by pgx and instead use the cloud SQL instance
return dialer.Dial(ctx, config.Instance)
}
log.Debug.Print("Creating pgxpool")
db, err := pgxpool.ConnectConfig(ctx, pgConfig)
if err != nil {
dialer.Close()
return nil, err
}
return &DBWrapper{dialer, db}, err
}
func newDialer(ctx context.Context, config DatabaseConfig) (*cloudsqlconn.Dialer, error) {
var dialOptions []cloudsqlconn.DialOption
var options []cloudsqlconn.Option
if config.IAMAuth {
options = append(options, cloudsqlconn.WithIAMAuthN())
}
if config.PrivateIP {
dialOptions = append(dialOptions, cloudsqlconn.WithPrivateIP())
} else {
dialOptions = append(dialOptions, cloudsqlconn.WithPublicIP())
}
options = append(options, cloudsqlconn.WithDefaultDialOptions(dialOptions...))
log.Debug.Print("Creating Cloud SQL dialer")
dialer, err := cloudsqlconn.NewDialer(ctx, options...)
if err != nil {
return nil, err
}
log.Debug.Print("Warming up Cloud SQL dialer")
err = dialer.Warmup(ctx, config.Instance)
if err != nil {
dialer.Close()
return nil, err
}
return dialer, nil
}
type FSIDSource struct {
db DB
tableName string
}
func (s FSIDSource) CreateTable(ctx context.Context) error {
log.Debug.Printf("creating table \"%s\"", s.tableName)
t, err := template.New("schema").Parse(tableSchema)
if err != nil {
return err
}
w := &strings.Builder{}
err = t.Execute(w, s.tableName)
if err != nil {
return err
}
sql := w.String()
return withRetry(ctx, func() error {
_, err = s.db.Exec(ctx, sql)
return err
})
}
func (s FSIDSource) GetFSID(ctx context.Context, path string) (int32, error) {
var fsid int32
start := time.Now()
sql := fmt.Sprintf("SELECT fsid FROM \"%s\" WHERE path = $1", s.tableName)
row := s.db.QueryRow(ctx, sql, path)
err := row.Scan(&fsid)
metrics.SQLOperation(ctx, "get_fsid", SQLMetricResult(err), time.Since(start))
return fsid, err
}
func (s FSIDSource) AllocateFSID(ctx context.Context, path string) (int32, error) {
var fsid int32
start := time.Now()
sql := fmt.Sprintf("INSERT INTO \"%s\" (path) VALUES ($1) RETURNING fsid", s.tableName)
row := s.db.QueryRow(ctx, sql, path)
err := row.Scan(&fsid)
metrics.SQLOperation(ctx, "allocate_fsid", SQLMetricResult(err), time.Since(start))
return fsid, err
}
func (s FSIDSource) GetPath(ctx context.Context, fsid int32) (string, error) {
var path string
start := time.Now()
sql := fmt.Sprintf("SELECT path FROM \"%s\" WHERE fsid = $1", s.tableName)
row := s.db.QueryRow(ctx, sql, fsid)
err := row.Scan(&path)
metrics.SQLOperation(ctx, "get_path", SQLMetricResult(err), time.Since(start))
return path, err
}
func IsConflict(err error) bool {
var pgerr *pgconn.PgError
if errors.As(err, &pgerr) {
// unique constraint violation
return pgerr.Code == "23505"
} else {
return false
}
}
func IsNotFound(err error) bool {
return errors.Is(err, pgx.ErrNoRows)
}
func SQLMetricResult(err error) string {
if err == nil {
return "ok"
} else if IsNotFound(err) {
return "not_found"
} else if IsConflict(err) {
return "conflict"
} else {
return "error"
}
}