cmd/proxy/main.go (203 lines of code) (raw):
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
package main
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"syscall"
"time"
"github.com/alexflint/go-arg"
"github.com/azure/peerd/pkg/containerd"
pcontext "github.com/azure/peerd/pkg/context"
"github.com/azure/peerd/pkg/discovery/content/provider"
"github.com/azure/peerd/pkg/discovery/routing"
"github.com/azure/peerd/pkg/files/store"
"github.com/azure/peerd/pkg/handlers"
"github.com/azure/peerd/pkg/k8s"
"github.com/azure/peerd/pkg/k8s/events"
"github.com/azure/peerd/pkg/metrics"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/rs/zerolog"
"github.com/spf13/afero"
"golang.org/x/sync/errgroup"
)
// main is the process entry point. It parses the command-line arguments,
// applies the requested zerolog level and an RFC3339Nano timestamp format,
// builds a stdout logger tagged with the node name and version, attaches the
// logger and metrics to a context, and hands control to run.
//
// Any failure terminates the process with exit status 1.
func main() {
	args := &Arguments{}
	arg.MustParse(args)

	// The structured logger does not exist yet, so a bad level is reported
	// directly on stderr.
	level, err := zerolog.ParseLevel(args.LogLevel)
	if err != nil {
		fmt.Fprintf(os.Stderr, "invalid log level: %s\n", args.LogLevel)
		os.Exit(1)
	}
	zerolog.SetGlobalLevel(level)
	zerolog.TimeFieldFormat = time.RFC3339Nano

	logger := zerolog.New(os.Stdout).
		With().
		Timestamp().
		Str("self", pcontext.NodeName).
		Str("version", version).
		Logger()

	// fail logs cause at error level under msg and exits with status 1.
	fail := func(cause error, msg string) {
		logger.Error().Err(cause).Msg(msg)
		os.Exit(1)
	}

	ctx, err := metrics.WithContext(logger.WithContext(context.Background()), pcontext.NodeName, "peerd")
	if err != nil {
		fail(err, "failed to initialize metrics")
	}

	if runErr := run(ctx, args); runErr != nil {
		fail(runErr, "server error")
	}

	logger.Info().Msg("server shutdown")
}
// run dispatches to the subcommand selected in args.
//
// The context passed on to the subcommand is cancelled when the process
// receives SIGTERM. The version flag takes precedence over the server
// subcommand; if neither is set an "unknown subcommand" error is returned.
func run(ctx context.Context, args *Arguments) error {
	sigCtx, stop := signal.NotifyContext(ctx, syscall.SIGTERM)
	defer stop()

	if args.Version {
		// The logger already carries the version field, so an otherwise empty
		// message is enough to print it.
		zerolog.Ctx(sigCtx).Info().Msg("version")
		return nil
	}

	if args.Server != nil {
		return serverCommand(sigCtx, args.Server)
	}

	return errors.New("unknown subcommand")
}
// serverCommand runs the server subcommand and blocks until ctx is cancelled
// or one of the components it starts fails.
//
// Setup happens in this order: the Kubernetes client and events recorder, the
// router, the optional mirror configuration, the containerd store (which is
// verified), and the files store. It then starts, inside a single errgroup,
// the content provider plus three HTTP listeners: HTTPS on args.HttpsAddr,
// HTTP on args.HttpAddr and Prometheus metrics on args.PromAddr.
//
// Every listener is paired with a watcher that calls Shutdown (30s timeout)
// once the group context is cancelled, so that g.Wait can return both on
// SIGTERM and when any single member fails.
//
// The result is named so that the deferred closure can observe the final
// error and report failure through the events recorder.
func serverCommand(ctx context.Context, args *ServerCmd) (err error) {
	l := zerolog.Ctx(ctx)

	store.PrefetchWorkers = args.PrefetchWorkers

	// Only the port of the HTTPS address is handed to the router.
	_, httpsPort, err := net.SplitHostPort(args.HttpsAddr)
	if err != nil {
		return err
	}

	clientset, err := k8s.NewKubernetesInterface(pcontext.KubeConfigPath, pcontext.NodeName)
	if err != nil {
		return err
	}

	ctx, err = events.WithContext(ctx, clientset)
	if err != nil {
		return err
	}

	eventsRecorder := events.FromContext(ctx)
	defer func() {
		// err is the named result, so this sees whatever value is returned.
		if err != nil {
			eventsRecorder.Failed()
		}
	}()
	eventsRecorder.Initializing()

	r, err := routing.NewRouter(ctx, clientset, args.RouterAddr, httpsPort)
	if err != nil {
		return err
	}

	err = addMirrorConfiguration(ctx, args)
	if err != nil {
		return err
	}

	containerdStore, err := containerd.NewDefaultStore(args.Hosts)
	if err != nil {
		return err
	}

	err = containerdStore.Verify(ctx)
	if err != nil {
		return err
	}

	filesStore, err := store.NewFilesStore(ctx, r, store.DefaultFileCachePath)
	if err != nil {
		return err
	}

	// From here on ctx is the group context: it is cancelled when the parent
	// is cancelled or when any group member returns a non-nil error.
	g, ctx := errgroup.WithContext(ctx)

	g.Go(func() error {
		provider.Provide(ctx, r, containerdStore, filesStore.Subscribe())
		return nil
	})

	handler, err := handlers.Handler(ctx, r, containerdStore, filesStore)
	if err != nil {
		return err
	}

	// serve runs listen in the group and pairs it with a watcher that shuts
	// srv down once the group context is cancelled. Without the watcher a
	// listener would block forever and g.Wait would never return.
	serve := func(srv *http.Server, listen func() error) {
		g.Go(func() error {
			// Declared with := on purpose: assigning to the function's named
			// result from a goroutine would race with the return path.
			if err := listen(); err != nil && !errors.Is(err, http.ErrServerClosed) {
				return err
			}
			return nil
		})

		g.Go(func() error {
			<-ctx.Done()
			// The group context is already cancelled here, so the shutdown
			// deadline needs a fresh context.
			shutdownCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
			defer cancel()
			return srv.Shutdown(shutdownCtx)
		})
	}

	httpsSrv := &http.Server{
		Addr:      args.HttpsAddr,
		Handler:   handler,
		TLSConfig: r.Net().DefaultTLSConfig(),
	}
	// Empty cert/key paths: the certificates are expected to come from TLSConfig.
	serve(httpsSrv, func() error { return httpsSrv.ListenAndServeTLS("", "") })

	httpSrv := &http.Server{
		Addr:    args.HttpAddr,
		Handler: handler,
	}
	serve(httpSrv, httpSrv.ListenAndServe)

	// The metrics endpoint stays on http.DefaultServeMux (a nil Handler selects
	// it), so the routes exposed on PromAddr are unchanged. Using an explicit
	// http.Server instead of http.ListenAndServe is what makes it stoppable.
	http.Handle("/metrics/prometheus", promhttp.Handler())
	promSrv := &http.Server{Addr: args.PromAddr}
	serve(promSrv, promSrv.ListenAndServe)

	l.Info().Str("https", args.HttpsAddr).Str("http", args.HttpAddr).Str("router", args.RouterAddr).Str("prom", args.PromAddr).Msg("server start")

	return g.Wait()
}
// addMirrorConfiguration writes the containerd hosts configuration when
// args.AddMirrorConfiguration is set; otherwise it does nothing.
//
// Hosts default to https://mcr.microsoft.com and mirrors to
// https://localhost:30001. A non-empty args.Hosts or args.Mirrors replaces the
// respective default entirely. The configuration is written through the OS
// filesystem under args.ContainerdHostsConfigPath.
//
// It returns the first URL parsing error, or the error reported by
// containerd.AddHostsConfiguration.
func addMirrorConfiguration(ctx context.Context, args *ServerCmd) error {
	if !args.AddMirrorConfiguration {
		return nil
	}

	logger := zerolog.Ctx(ctx)
	osFs := afero.NewOsFs()

	// resolve parses the supplied values when there are any and otherwise
	// falls back to a single built-in URL.
	resolve := func(supplied []string, fallback string) ([]url.URL, error) {
		if len(supplied) > 0 {
			return toUrls(supplied)
		}
		// The fallbacks are constant, well-formed URLs, so the parse error is
		// deliberately ignored.
		u, _ := url.Parse(fallback)
		return []url.URL{*u}, nil
	}

	hosts, err := resolve(args.Hosts, "https://mcr.microsoft.com")
	if err != nil {
		return err
	}

	logger.Info().Msgf("mirrors args: %v, hosts: %v", args.Hosts, hosts)

	mirrors, err := resolve(args.Mirrors, "https://localhost:30001")
	if err != nil {
		return err
	}

	return containerd.AddHostsConfiguration(ctx, osFs, args.ContainerdHostsConfigPath, hosts, mirrors, false)
}
// toUrls parses every entry of hosts with url.Parse and returns the results in
// the same order.
//
// Parsing stops at the first invalid entry, in which case a nil slice and that
// error are returned. An empty or nil input yields a nil slice and no error.
func toUrls(hosts []string) ([]url.URL, error) {
	if len(hosts) == 0 {
		return nil, nil
	}

	parsed := make([]url.URL, 0, len(hosts))
	for i := 0; i < len(hosts); i++ {
		u, parseErr := url.Parse(hosts[i])
		if parseErr != nil {
			return nil, parseErr
		}
		parsed = append(parsed, *u)
	}

	return parsed, nil
}