image/resources/knfsd-fsidd/main.go (194 lines of code) (raw):
/*
Copyright 2022 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"context"
"errors"
"os"
"os/signal"
"strconv"
"syscall"
"time"
"github.com/GoogleCloudPlatform/knfsd-cache-utils/image/resources/knfsd-fsidd/internal/metrics"
"github.com/GoogleCloudPlatform/knfsd-cache-utils/image/resources/knfsd-fsidd/log"
"github.com/coreos/go-systemd/v22/daemon"
"github.com/jackc/pgx/v4"
"github.com/spf13/pflag"
)
// FSIDProvider maps paths to numeric filesystem IDs (FSIDs) and back.
//
// run chooses either FSIDSource (backed by the database connection) or an
// FSIDCache wrapping that source, depending on the Cache setting.
type FSIDProvider interface {
	// GetFSID returns the FSID already assigned to path.
	// NOTE(review): callers treat pgx.ErrNoRows (or IsNotFound) as "no FSID
	// assigned yet"; presumably implementations return that error — confirm.
	GetFSID(ctx context.Context, path string) (int32, error)

	// GetPath returns the path that fsid was assigned to (the reverse of
	// GetFSID).
	GetPath(ctx context.Context, fsid int32) (string, error)

	// AllocateFSID assigns a new FSID to path. It may fail with a 23505
	// unique_violation if another process allocated an FSID for the same
	// path first; callers retry the whole lookup-or-allocate step so the
	// other process's FSID is found on the next attempt.
	AllocateFSID(ctx context.Context, path string) (int32, error)
}
// main loads and validates the configuration, then runs the fsidd service
// until it receives SIGINT or SIGTERM.
//
// Configuration sources are layered, each overriding the previous one:
//  1. flag defaults
//  2. the config file read by readDefaultConfig
//  3. environment variables (readEnv)
//  4. command line arguments
//
// Exit codes: 0 after --help, 2 for configuration or argument errors, and 1
// if the service fails while running.
func main() {
	cfg := new(Config)
	flags := pflag.NewFlagSet(os.Args[0], pflag.ContinueOnError)

	// Register flags before reading the config file: pflag writes each
	// flag's default into its target when the flag is defined, which would
	// otherwise overwrite the values loaded from the file.
	flags.StringVar(&cfg.SocketPath, "socket", defaultSocketPath, "")
	flags.StringVar(&cfg.Database.URL, "database-url", "", "")
	flags.StringVar(&cfg.Database.Instance, "database-instance", "", "")
	flags.StringVar(&cfg.Database.TableName, "table-name", "", "")
	flags.BoolVar(&cfg.Database.IAMAuth, "iam-auth", false, "")
	flags.BoolVar(&cfg.Database.PrivateIP, "private-ip", false, "")
	flags.BoolVar(&cfg.Debug, "debug", false, "")
	flags.BoolVar(&cfg.Cache, "cache", true, "")

	if err := readDefaultConfig(cfg); err != nil {
		log.Error.Printf("could not read config: %s", err)
		os.Exit(2)
	}

	// Environment variables take precedence over the config file.
	if err := readEnv(cfg); err != nil {
		printConfigError(err)
		os.Exit(2)
	}

	// Command line arguments are parsed last so they override everything.
	switch err := flags.Parse(os.Args[1:]); {
	case errors.Is(err, pflag.ErrHelp):
		os.Exit(0)
	case err != nil:
		log.Error.Print(err)
		os.Exit(2)
	}

	if cfg.Debug {
		log.EnableDebug()
	}

	if err := cfg.Validate(); err != nil {
		printConfigError(err)
		os.Exit(2)
	}

	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	if err := run(ctx, cfg); err != nil {
		log.Error.Print(err)
		os.Exit(1)
	}
}
// run connects to the FSID database, registers the fsidd request handlers on
// the configured socket, and serves requests until ctx is cancelled or the
// server fails.
//
// When ctx is cancelled, systemd is notified that the service is stopping,
// the server is given up to 30 seconds to shut down gracefully, and the socket
// is closed; run waits for that sequence before returning. On return, metrics
// are flushed with a 10 second deadline.
//
// A nil error is returned when the server stopped because it was closed
// (ErrServerClosed).
func run(ctx context.Context, cfg *Config) error {
	m := metrics.Start(ctx, cfg.Metrics)
	defer func() {
		// ctx may already be cancelled here, so use a fresh context to give
		// the metrics exporter a chance to flush.
		deadline, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		if err := m.Shutdown(deadline); err != nil {
			log.Warn.Printf("metrics did not shutdown gracefully: %s", err)
		}
	}()

	db, err := connect(ctx, cfg.Database)
	if err != nil {
		return err
	}
	defer db.Close()

	source := FSIDSource{
		db:        db,
		tableName: cfg.Database.TableName,
	}
	if cfg.Database.CreateTable {
		if err := source.CreateTable(ctx); err != nil {
			return err
		}
	}

	var f FSIDProvider = source
	if cfg.Cache {
		f = &FSIDCache{source: source}
	}

	s, err := resolveSocket(cfg.SocketPath)
	if err != nil {
		return err
	}
	defer s.Close()

	s.Handle("get_fsidnum", getFSIDNumHandler(f))
	s.Handle("get_or_create_fsidnum", getOrCreateFSIDNumHandler(f))
	s.Handle("get_path", getPathHandler(f))
	s.Handle("version", func(ctx context.Context, arg string) (string, error) {
		metrics.Request(ctx, "version", "ok", 0, 0)
		return "1", nil
	})

	// shutdownDone is closed once the graceful shutdown sequence finishes.
	shutdownDone := make(chan struct{})
	go func() {
		defer close(shutdownDone)
		<-ctx.Done()
		if _, err := daemon.SdNotify(false, daemon.SdNotifyStopping); err != nil {
			log.Error.Print(err)
		}
		// ctx is already cancelled here; a deadline derived from it would be
		// expired immediately and cut the graceful shutdown short. Use a fresh
		// context so in-flight requests get the full 30 seconds.
		deadline, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		s.Shutdown(deadline)
		s.Close()
	}()

	if _, err := daemon.SdNotify(false, daemon.SdNotifyReady); err != nil {
		return err
	}
	log.Info.Print("service ready")

	err = s.Serve()
	if ctx.Err() != nil {
		// Serve may return as soon as shutdown begins; wait for the shutdown
		// goroutine so in-flight requests are not cut off by process exit.
		<-shutdownDone
	}
	if errors.Is(err, ErrServerClosed) {
		return nil
	}
	return err
}

// getFSIDNumHandler returns the "get_fsidnum" handler, which looks up the FSID
// already assigned to path. An empty path is rejected with ErrInvalidArgument;
// a path with no FSID yields an empty result and a nil error.
func getFSIDNumHandler(f FSIDProvider) func(ctx context.Context, path string) (string, error) {
	return func(ctx context.Context, path string) (string, error) {
		rec := metrics.StartRequest("get_fsidnum")
		if path == "" {
			rec.End(ctx, "error")
			return "", ErrInvalidArgument
		}

		var fsid int32
		err := withRetry(ctx, func() error {
			op := rec.StartOperation()
			var err error
			fsid, err = f.GetFSID(ctx, path)
			op.End(ctx, SQLMetricResult(err))
			return err
		})
		rec.End(ctx, SQLMetricResult(err))

		switch {
		case err == nil:
			return strconv.FormatInt(int64(fsid), 10), nil
		case IsNotFound(err):
			return "", nil
		default:
			return "", err
		}
	}
}

// getOrCreateFSIDNumHandler returns the "get_or_create_fsidnum" handler, which
// returns the FSID for path, allocating a new one if none exists. An empty
// path is rejected with ErrInvalidArgument; on error the result is empty.
func getOrCreateFSIDNumHandler(f FSIDProvider) func(ctx context.Context, path string) (string, error) {
	return func(ctx context.Context, path string) (string, error) {
		rec := metrics.StartRequest("get_or_create_fsidnum")
		if path == "" {
			rec.End(ctx, "error")
			return "", ErrInvalidArgument
		}

		var fsid int32
		// err is scoped to this request; it must not be shared with other
		// requests via a variable captured from an enclosing function.
		err := withRetry(ctx, func() error {
			op := rec.StartOperation()
			var err error
			fsid, err = f.GetFSID(ctx, path)
			if errors.Is(err, pgx.ErrNoRows) {
				// FSID not found for path, so try and allocate one.
				// This might fail with a 23505 unique_violation if the path has
				// already been allocated an FSID by a different process.
				// withRetry will then retry this whole block and will find the
				// FSID allocated by the other process.
				fsid, err = f.AllocateFSID(ctx, path)
			}
			op.End(ctx, SQLMetricResult(err))
			return err
		})
		rec.End(ctx, SQLMetricResult(err))

		if err != nil {
			return "", err
		}
		return strconv.FormatInt(int64(fsid), 10), nil
	}
}

// getPathHandler returns the "get_path" handler, which maps a decimal FSID
// back to its path. Arguments that are not a base-10 int32, or are < 1, are
// rejected with ErrInvalidArgument; on error the result is empty.
func getPathHandler(f FSIDProvider) func(ctx context.Context, arg string) (string, error) {
	return func(ctx context.Context, arg string) (string, error) {
		rec := metrics.StartRequest("get_path")
		fsid, err := strconv.ParseInt(arg, 10, 32)
		if err != nil || fsid < 1 {
			rec.End(ctx, "error")
			return "", ErrInvalidArgument
		}

		var path string
		err = withRetry(ctx, func() error {
			op := rec.StartOperation()
			var err error
			path, err = f.GetPath(ctx, int32(fsid))
			op.End(ctx, SQLMetricResult(err))
			return err
		})
		rec.End(ctx, SQLMetricResult(err))

		if err != nil {
			return "", err
		}
		return path, nil
	}
}