cns/deviceplugin/server.go (138 lines of code) (raw):

package deviceplugin

import (
	"context"
	"fmt"
	"net"
	"time"

	"github.com/pkg/errors"
	"go.uber.org/zap"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials/insecure"
	"k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
)

// devicePrefix is combined with a zero-based index to build both the device IDs
// advertised by ListAndWatch and the env var names returned by Allocate.
const devicePrefix = "NIC-"

// deviceCounter reports how many devices should currently be advertised.
type deviceCounter interface {
	getDeviceCount() int
}

// Server implements the kubelet v1beta1 device plugin gRPC service on a unix socket.
type Server struct {
	// address is the unix socket path the server listens on.
	address string
	logger  *zap.Logger
	// deviceCounter supplies the number of devices to advertise.
	deviceCounter deviceCounter
	// shutdownCh is set by Run and is closed when Run's context ends or Run returns;
	// ListAndWatch streams exit when it is closed.
	shutdownCh <-chan struct{}
	// deviceCheckInterval is how often ListAndWatch re-reads the device count.
	deviceCheckInterval time.Duration
}

// NewServer builds a Server that will listen on the unix socket at address and
// poll deviceCounter every deviceCheckInterval while a ListAndWatch stream is open.
// deviceCheckInterval must be positive: it is passed to time.NewTicker, which
// panics on a non-positive duration.
func NewServer(logger *zap.Logger, address string, deviceCounter deviceCounter, deviceCheckInterval time.Duration) *Server {
	return &Server{
		address:             address,
		logger:              logger,
		deviceCounter:       deviceCounter,
		deviceCheckInterval: deviceCheckInterval,
	}
}

// Run starts the grpc server and blocks until an error or context is cancelled.
// Wait on Ready to know when the server is ready.
//
// It returns a wrapped error if the socket cannot be listened on or if Serve
// fails with anything other than grpc.ErrServerStopped; otherwise it returns nil.
func (s *Server) Run(ctx context.Context) error {
	grpcServer := grpc.NewServer()
	v1beta1.RegisterDevicePluginServer(grpcServer, s)

	childCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	s.shutdownCh = childCtx.Done()

	l, err := net.Listen("unix", s.address)
	if err != nil {
		return errors.Wrap(err, "error listening on socket")
	}
	defer l.Close()

	go func() {
		// Wait on childCtx rather than ctx: it is cancelled both when the caller's
		// ctx ends and by the deferred cancel when Serve returns on its own, so
		// this goroutine never outlives Run.
		<-childCtx.Done()
		grpcServer.GracefulStop()
	}()

	if err := grpcServer.Serve(l); err != nil && !errors.Is(err, grpc.ErrServerStopped) {
		return errors.Wrap(err, "error running grpc server")
	}
	return nil
}

// Ready blocks until the server is ready, i.e. until a grpc connection to the
// server's unix socket has been established, and then closes that connection.
// It returns a wrapped error if dialing fails (including ctx ending first) or if
// closing the probe connection fails.
func (s *Server) Ready(ctx context.Context) error {
	c, err := grpc.DialContext(ctx, s.address,
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithBlock(), //nolint:staticcheck // TODO: Move to grpc.NewClient method
		grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
			conn, err := (&net.Dialer{}).DialContext(ctx, "unix", addr)
			if err != nil {
				return nil, errors.Wrap(err, "failed to dial context")
			}
			return conn, nil
		}),
	)
	if err != nil {
		return errors.Wrap(err, "error dialing local grpc server")
	}
	if err := c.Close(); err != nil {
		return errors.Wrap(err, "error closing connection to local grpc server")
	}
	return nil
}

// Allocate is a dummy implementation to conform to the interface requirements.
// Allocate is called during container creation so that the Device Plugin can run
// device specific operations and instruct Kubelet of the steps to make the
// Device available in the container. We are not using this functionality
// currently.
//
// For every container request it returns one response whose Envs map holds an
// entry per requested device: the key is devicePrefix followed by the device's
// position in that request, and the value is the requested device ID.
func (s *Server) Allocate(_ context.Context, req *v1beta1.AllocateRequest) (*v1beta1.AllocateResponse, error) {
	s.logger.Info("allocate request", zap.Any("req", *req))

	resps := make([]*v1beta1.ContainerAllocateResponse, len(req.ContainerRequests))
	for i, containerReq := range req.ContainerRequests {
		envs := make(map[string]string, len(containerReq.DevicesIDs))
		for j, id := range containerReq.DevicesIDs {
			envs[fmt.Sprintf("%s%d", devicePrefix, j)] = id
		}
		resps[i] = &v1beta1.ContainerAllocateResponse{Envs: envs}
	}

	return &v1beta1.AllocateResponse{ContainerResponses: resps}, nil
}

// ListAndWatch sends the current device list right away and then, every
// deviceCheckInterval, re-reads the device count and sends a fresh list only if
// the count differs from the last one sent.
//
// It returns nil once the server is shutting down, a wrapped context error when
// the stream's context is done, or a wrapped error if a send fails.
func (s *Server) ListAndWatch(_ *v1beta1.Empty, stream v1beta1.DevicePlugin_ListAndWatchServer) error {
	// send the initial count right away
	advertisedCount := s.deviceCounter.getDeviceCount()
	if err := sendDeviceList(stream, advertisedCount); err != nil {
		return err
	}

	// every interval, check if the current count has changed from what we've
	// previously sent, and if so, send the new count
	ticker := time.NewTicker(s.deviceCheckInterval)
	defer ticker.Stop()
	for {
		select {
		case <-s.shutdownCh:
			return nil
		case <-stream.Context().Done():
			return errors.Wrap(stream.Context().Err(), "client context done")
		case <-ticker.C:
			currentCount := s.deviceCounter.getDeviceCount()
			if currentCount == advertisedCount {
				continue
			}
			advertisedCount = currentCount
			if err := sendDeviceList(stream, advertisedCount); err != nil {
				return err
			}
		}
	}
}

// sendDeviceList sends count healthy devices, with IDs devicePrefix+0 through
// devicePrefix+(count-1), on stream.
func sendDeviceList(stream v1beta1.DevicePlugin_ListAndWatchServer, count int) error {
	devices := make([]*v1beta1.Device, count)
	for i := range devices {
		devices[i] = &v1beta1.Device{
			ID:     fmt.Sprintf("%s%d", devicePrefix, i),
			Health: v1beta1.Healthy,
		}
	}
	if err := stream.Send(&v1beta1.ListAndWatchResponse{Devices: devices}); err != nil {
		return errors.Wrap(err, "error sending listAndWatch response")
	}
	return nil
}

// GetDevicePluginOptions returns an empty options message.
func (s *Server) GetDevicePluginOptions(context.Context, *v1beta1.Empty) (*v1beta1.DevicePluginOptions, error) {
	return &v1beta1.DevicePluginOptions{}, nil
}

// GetPreferredAllocation returns an empty response.
func (s *Server) GetPreferredAllocation(context.Context, *v1beta1.PreferredAllocationRequest) (*v1beta1.PreferredAllocationResponse, error) {
	return &v1beta1.PreferredAllocationResponse{}, nil
}

// PreStartContainer returns an empty response.
func (s *Server) PreStartContainer(context.Context, *v1beta1.PreStartContainerRequest) (*v1beta1.PreStartContainerResponse, error) {
	return &v1beta1.PreStartContainerResponse{}, nil
}