cns/deviceplugin/plugin.go (153 lines of code) (raw):
package deviceplugin
import (
"context"
"fmt"
"net"
"os"
"path"
"path/filepath"
"strings"
"sync"
"time"
"github.com/Azure/azure-container-networking/crd/multitenancy/api/v1alpha1"
"github.com/pkg/errors"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)
type Plugin struct {
Logger *zap.Logger
ResourceName string
SocketWatcher *SocketWatcher
Socket string
deviceCountMutex sync.Mutex
deviceCount int
deviceType v1alpha1.DeviceType
kubeletSocket string
deviceCheckInterval time.Duration
devicePluginDirectory string
}
func NewPlugin(l *zap.Logger, resourceName string, socketWatcher *SocketWatcher, pluginDir string,
initialDeviceCount int, deviceType v1alpha1.DeviceType, kubeletSocket string, deviceCheckInterval time.Duration,
) *Plugin {
return &Plugin{
Logger: l.With(zap.String("resourceName", resourceName)),
ResourceName: resourceName,
SocketWatcher: socketWatcher,
Socket: getSocketName(pluginDir, deviceType),
deviceCount: initialDeviceCount,
deviceType: deviceType,
kubeletSocket: kubeletSocket,
deviceCheckInterval: deviceCheckInterval,
devicePluginDirectory: pluginDir,
}
}
// Run runs the plugin until the context is cancelled, restarting the server as needed
func (p *Plugin) Run(ctx context.Context) {
defer p.mustCleanUp()
for {
select {
case <-ctx.Done():
return
default:
p.Logger.Info("starting device plugin for resource", zap.String("resource", p.ResourceName))
if err := p.run(ctx); err != nil {
p.Logger.Error("device plugin for resource exited", zap.String("resource", p.ResourceName), zap.Error(err))
}
}
}
}
// Here we start the gRPC server and wait for it to be ready
// Once the server is ready, device plugin registers with the Kubelet
// so that it can start serving the kubelet requests
func (p *Plugin) run(ctx context.Context) error {
childCtx, cancel := context.WithCancel(ctx)
defer cancel()
s := NewServer(p.Logger, p.Socket, p, p.deviceCheckInterval)
// Run starts the grpc server and blocks until an error or context is cancelled
runErrChan := make(chan error, 2) //nolint:gomnd // disabled in favor of readability
go func(errChan chan error) {
if err := s.Run(childCtx); err != nil {
errChan <- err
}
}(runErrChan)
// Wait till the server is ready before registering with kubelet
// This call is not blocking and returns as soon as the server is ready
readyErrChan := make(chan error, 2) //nolint:gomnd // disabled in favor of readability
go func(errChan chan error) {
errChan <- s.Ready(childCtx)
}(readyErrChan)
select {
case err := <-runErrChan:
return errors.Wrap(err, "error starting grpc server")
case err := <-readyErrChan:
if err != nil {
return errors.Wrap(err, "error waiting on grpc server to be ready")
}
case <-ctx.Done():
return nil
}
p.Logger.Info("registering with kubelet")
// register with kubelet
if err := p.registerWithKubelet(childCtx); err != nil {
return errors.Wrap(err, "failed to register with kubelet")
}
// run until the socket goes away or the context is cancelled
<-p.SocketWatcher.WatchSocket(childCtx, p.Socket)
return nil
}
func (p *Plugin) registerWithKubelet(ctx context.Context) error {
conn, err := grpc.Dial(p.kubeletSocket, grpc.WithTransportCredentials(insecure.NewCredentials()), //nolint:staticcheck // TODO: Move to grpc.NewClient method
grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
d := &net.Dialer{}
conn, err := d.DialContext(ctx, "unix", addr)
if err != nil {
return nil, errors.Wrap(err, "failed to dial context")
}
return conn, nil
}))
if err != nil {
return errors.Wrap(err, "error connecting to kubelet")
}
defer conn.Close()
client := v1beta1.NewRegistrationClient(conn)
request := &v1beta1.RegisterRequest{
Version: v1beta1.Version,
Endpoint: filepath.Base(p.Socket),
ResourceName: p.ResourceName,
}
if _, err = client.Register(ctx, request); err != nil {
return errors.Wrap(err, "error sending request to register with kubelet")
}
return nil
}
func (p *Plugin) mustCleanUp() {
p.Logger.Info("cleaning up device plugin")
if err := os.Remove(p.Socket); err != nil && !os.IsNotExist(err) {
p.Logger.Panic("failed to remove socket", zap.Error(err))
}
}
func (p *Plugin) CleanOldState() error {
entries, err := os.ReadDir(p.devicePluginDirectory)
if err != nil {
return errors.Wrap(err, "error listing existing device plugin sockets")
}
for _, entry := range entries {
if strings.HasPrefix(entry.Name(), path.Base(getSocketPrefix(p.devicePluginDirectory, p.deviceType))) {
// try to delete it
f := path.Join(p.devicePluginDirectory, entry.Name())
if err := os.Remove(f); err != nil {
return errors.Wrapf(err, "error removing old socket %q", f)
}
}
}
return nil
}
func (p *Plugin) UpdateDeviceCount(count int) {
p.deviceCountMutex.Lock()
p.deviceCount = count
p.deviceCountMutex.Unlock()
}
func (p *Plugin) getDeviceCount() int {
p.deviceCountMutex.Lock()
defer p.deviceCountMutex.Unlock()
return p.deviceCount
}
// getSocketPrefix returns a fully qualified path prefix for a given device type. For example, if the device plugin directory is
// /home/foo and the device type is acn.azure.com/vnet-nic, this function returns /home/foo/acn.azure.com_vnet-nic
func getSocketPrefix(devicePluginDirectory string, deviceType v1alpha1.DeviceType) string {
sanitizedDeviceName := strings.ReplaceAll(string(deviceType), "/", "_")
return path.Join(devicePluginDirectory, sanitizedDeviceName)
}
func getSocketName(devicePluginDirectory string, deviceType v1alpha1.DeviceType) string {
return fmt.Sprintf("%s-%d.sock", getSocketPrefix(devicePluginDirectory, deviceType), time.Now().Unix())
}