plugins/internal/plugin/device_plugin.go (260 lines of code) (raw):
//go:build windows
package plugin
import (
"context"
"fmt"
"net"
"path/filepath"
"strings"
"sync"
"time"
"go.uber.org/zap"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
"github.com/tensorworks/directx-device-plugins/plugins/internal/discovery"
"github.com/tensorworks/directx-device-plugins/plugins/internal/mount"
)
// A Kubernetes device plugin that advertises DirectX devices to the Kubelet over gRPC
type DevicePlugin struct {

	// Plugin name, which is also used as the prefix of the socket filename
	name string

	// Configuration data supplied when the plugin was created
	config *PluginConfig

	// Path of the Unix socket that the plugin's gRPC server listens on,
	// along with the watcher that reports when that socket file is deleted
	endpoint        string
	endpointDeleted *DeletionWatcher

	// Resource name sent to the Kubelet in the registration request
	resourceName string

	// Device watcher that delivers device list updates and discovery errors
	watcher *DeviceWatcher

	// Most recent device list received from the device watcher, guarded by devicesMutex
	currentDevices []*discovery.Device
	devicesMutex   sync.Mutex

	// Logger for diagnostic output
	logger *zap.SugaredLogger

	// gRPC server that services requests from the Kubelet (nil while the server is stopped)
	server *grpc.Server

	// Signals that the gRPC server should be restarted because a Kubelet restart was detected
	restart chan struct{}

	// Closed during server shutdown to stop the ListAndWatch streaming RPC
	stopListWatch chan struct{}

	// Reports errors that occur while the gRPC server is running
	Errors chan error
}
// Creates a new device plugin
//
// A DeviceWatcher is created from the supplied filter and configuration, and this function then waits
// for the watcher to report either an initial device list or an error. On success, two background
// goroutines are started: one forwards device watcher errors to the plugin's Errors channel, and one
// restarts the gRPC server and re-registers with the Kubelet whenever a restart is signalled.
//
// Returns an error if the device watcher cannot be created or if initial device discovery fails.
func NewDevicePlugin(pluginName string, pluginVersion string, resourceName string, filter discovery.DeviceFilter, config *PluginConfig, logger *zap.SugaredLogger) (*DevicePlugin, error) {

	// Attempt to create a new DeviceWatcher
	watcher, err := NewDeviceWatcher(
		pluginVersion,
		filter,
		config.IncludeIntegrated,
		config.IncludeDetachable,
		config.AdditionalMounts,
		config.AdditionalMountsWow64,
		logger,
	)
	if err != nil {
		return nil, err
	}

	// Verify that device watcher can successfully list devices
	select {
	case <-watcher.Updates:
		logger.Info("Initial device list retrieved successfully")

	case watchErr := <-watcher.Errors:
		// Report the error that the watcher actually sent; the outer `err` is always nil by this point
		watcher.Destroy()
		return nil, fmt.Errorf("failed to perform device discovery: %w", watchErr)
	}

	// Create a new device plugin instance with the device watcher
	plugin := &DevicePlugin{
		name:            pluginName,
		config:          config,
		endpoint:        "",
		endpointDeleted: nil,
		resourceName:    resourceName,
		watcher:         watcher,
		currentDevices:  []*discovery.Device{},
		devicesMutex:    sync.Mutex{},
		logger:          logger,
		server:          nil,
		restart:         make(chan struct{}, 1),
		stopListWatch:   nil,
		Errors:          make(chan error, 1),
	}

	// Forward any device watcher errors to the plugin's error channel
	go func() {
		for err := range plugin.watcher.Errors {
			plugin.Errors <- err
		}
	}()

	// Restart the plugin's gRPC server and perform plugin registration again in the event of a Kubelet restart
	go func() {
		for range plugin.restart {

			// Restart the gRPC server with a new Unix socket filename since the Kubelet will delete the old one
			if err := plugin.RestartServer(); err != nil {
				plugin.Errors <- err

				// There is no listening endpoint to advertise, so registering would only produce a second failure
				continue
			}

			// Register the device plugin with the new Kubelet instance
			if err := plugin.RegisterWithKubelet(); err != nil {
				plugin.Errors <- err
			}
		}
	}()

	return plugin, nil
}
// Starts the gRPC server for the device plugin
//
// A fresh gRPC server is created on a uniquely-named Unix socket, a deletion watcher is attached to
// that socket, and the server begins serving in a background goroutine. Serving errors and deletion
// watcher errors are reported through the plugin's Errors channel, and deletion of the socket file
// triggers a server restart via the restart channel.
//
// The plugin's fields are only updated once every fallible step has succeeded, so a failed call leaves
// the plugin in its previous (stopped) state and does not leak the listener.
//
// Returns an error if the socket cannot be listened on or the deletion watcher cannot be created.
func (p *DevicePlugin) StartServer() error {

	// Create a new gRPC server instance
	// (Note that this is necessary to support restarts, since a server instance cannot be reused after it has stopped serving)
	server := grpc.NewServer()

	// Register our service implementation with the gRPC server
	p.logger.Info("Registering the service implementation with the gRPC server")
	pluginapi.RegisterDevicePluginServer(server, p)

	// Append a timestamp to the filename for the gRPC server's Unix socket to ensure it is unique
	endpoint := filepath.Join(pluginapi.DevicePluginPathWindows, fmt.Sprintf("%s-%d.sock", p.name, time.Now().UnixMilli()))

	// Attempt to listen for connections on our Unix socket
	p.logger.Infow("Listening on endpoint", "endpoint", endpoint)
	listener, err := net.Listen("unix", endpoint)
	if err != nil {
		return err
	}

	// Create a file deletion watcher for our Unix socket
	endpointDeleted, err := WatchForDeletion(endpoint)
	if err != nil {
		// Release the listener rather than leaking it (the close error is irrelevant, since we are already failing)
		_ = listener.Close()
		return err
	}

	// Everything that can fail has succeeded, so publish the new server state
	p.server = server
	p.endpoint = endpoint
	p.endpointDeleted = endpointDeleted

	// Create the shutdown channel for stopping the ListAndWatch streaming RPC
	p.stopListWatch = make(chan struct{})

	// We detect Kubelet restarts by detecting the deletion of our socket
	// (The goroutine uses the local watcher rather than the struct field, so that it keeps observing the
	// watcher it was started for even after a restart replaces the field)
	go func() {
		for {
			select {
			case err, ok := <-endpointDeleted.Errors:
				if !ok {
					p.logger.Info("DeletionWatcher error channel closed")
					return
				}
				p.Errors <- err

			case _, ok := <-endpointDeleted.Deleted:
				if !ok {
					p.logger.Info("DeletionWatcher deletion channel closed")
					return
				}
				p.logger.Info("Endpoint deletion detected, triggering a restart of the gRPC server")
				p.restart <- struct{}{}
			}
		}
	}()

	// Start the gRPC server in a new goroutine and send any errors back through our error channel
	go func() {
		p.logger.Info("Starting the gRPC server")
		if err := server.Serve(listener); err != nil {
			p.Errors <- err
		}
	}()

	return nil
}
// Gracefully stops the gRPC server for the device plugin
//
// Signals any running ListAndWatch streaming RPC to stop, cancels the socket deletion watcher, and then
// performs a graceful shutdown of the gRPC server. Does nothing if no server is currently set.
func (p *DevicePlugin) StopServer() {

	// If StopServer() is called before StartServer() then do nothing
	if p.server == nil {
		return
	}

	// Stop the ListAndWatch streaming RPC if it is running
	// (The channel is nil if StartServer() failed before creating it, and closing a nil channel panics)
	if p.stopListWatch != nil {
		close(p.stopListWatch)
	}

	// Stop watching our Unix socket for deletion events
	// (The watcher is nil if StartServer() failed before creating it)
	if p.endpointDeleted != nil {
		p.endpointDeleted.Cancel()
	}

	// Attempt to perform a graceful shutdown of the server (this will delete the Unix socket)
	p.logger.Info("Gracefully stopping the gRPC server")
	p.server.GracefulStop()
	p.server = nil
}
// Stops the device plugin's gRPC server (if one is running) and then starts it again,
// which generates a new Unix socket filename
//
// Returns any error reported while starting the new server.
func (p *DevicePlugin) RestartServer() error {
	p.StopServer()
	if err := p.StartServer(); err != nil {
		return err
	}
	return nil
}
// Destroys our underlying resources
//
// Destroys the device watcher and closes the plugin's restart and Errors channels.
// This does not call StopServer().
//
// NOTE(review): goroutines started by NewDevicePlugin() and StartServer() send on the restart and
// Errors channels, and a send on a closed channel panics. Presumably callers are expected to call
// StopServer() and let those goroutines finish before calling this - confirm.
func (p *DevicePlugin) Destroy() {
// Destroy the device watcher first, before closing the channels below
p.watcher.Destroy()
// Closing the restart channel ends the `for range` restart loop started in NewDevicePlugin()
close(p.restart)
// Closing the error channel signals to receivers that no further errors will be reported
close(p.Errors)
}
// Registers the device plugin with the Kubelet
//
// Connects to the Kubelet's gRPC service on its Windows socket path and sends a registration request
// containing the plugin's API version, the filename of the plugin's current socket, and the plugin's
// resource name. Connecting and registering are each limited to 60 seconds.
//
// Returns an error if the connection or the registration request fails. The underlying error is
// wrapped, so callers can inspect it with errors.Is() and errors.As().
func (p *DevicePlugin) RegisterWithKubelet() error {

	// Set a 60 second timeout when attempting to connect to the Kubelet
	ctxConnect, cancelConnect := context.WithTimeout(context.Background(), time.Minute)
	defer cancelConnect()

	// Create a dialler that treats the Kubelet's endpoint as a Unix socket rather than a TCP address
	dialler := grpc.WithContextDialer(func(ctx context.Context, address string) (net.Conn, error) {
		return (&net.Dialer{}).DialContext(ctx, "unix", address)
	})

	// Attempt to connect to the Kubelet's gRPC service using the socket path for Windows
	// (WithBlock makes the dial wait until the connection is established or the timeout expires)
	p.logger.Infow("Connecting to the Kubelet", "endpoint", pluginapi.KubeletSocketWindows)
	conn, err := grpc.DialContext(
		ctxConnect,
		pluginapi.KubeletSocketWindows,
		grpc.WithBlock(),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		dialler,
	)
	if err != nil {
		return fmt.Errorf("failed to connect to the Kubelet's gRPC service: %w", err)
	}
	defer conn.Close()

	// Prepare a registration request
	// (Only the filename of our socket is sent, not its full path)
	request := &pluginapi.RegisterRequest{
		Version:      pluginapi.Version,
		Endpoint:     filepath.Base(p.endpoint),
		ResourceName: p.resourceName,
	}

	// Set a 60 second timeout when attempting to register with the Kubelet
	ctxRegister, cancelRegister := context.WithTimeout(context.Background(), time.Minute)
	defer cancelRegister()

	// Create a registration client and attempt to send our registration request
	p.logger.Infow("Sending registration request to the Kubelet", "request", request)
	client := pluginapi.NewRegistrationClient(conn)
	if _, err := client.Register(ctxRegister, request); err != nil {
		return fmt.Errorf("failed to register the device plugin with the Kubelet: %w", err)
	}

	p.logger.Info("Successfully registered the device plugin with the Kubelet")
	return nil
}
// Handles the GetDevicePluginOptions RPC, reporting which optional RPCs the plugin supports
func (p *DevicePlugin) GetDevicePluginOptions(ctx context.Context, request *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) {
	p.logger.Info("GetDevicePluginOptions RPC invoked")

	// Neither GetPreferredAllocation nor PreStartContainer is needed, so tell the Kubelet not to call them
	options := &pluginapi.DevicePluginOptions{}
	options.GetPreferredAllocationAvailable = false
	options.PreStartRequired = false
	return options, nil
}
// Handles the ListAndWatch streaming RPC, sending the device list to the Kubelet each time it changes
//
// A device list refresh is forced when the RPC starts, so that an initial list is sent. Each device is
// advertised once per multitenancy slot, with the slot index appended to the device ID after a backslash.
//
// Returns nil when server shutdown is requested, when the stream's context is done, or when the device
// watcher's update channel is closed. Returns the send error if a device list cannot be sent.
func (p *DevicePlugin) ListAndWatch(request *pluginapi.Empty, stream pluginapi.DevicePlugin_ListAndWatchServer) error {

	// Force a device list refresh to ensure we have an initial list for the Kubelet
	p.logger.Info("ListAndWatch streaming RPC started, refreshing the device list")
	p.watcher.ForceRefresh()

	// Continue sending updates as our device list changes or until shutdown is requested
	for {
		select {
		case <-p.stopListWatch:
			p.logger.Info("Shutdown requested, stopping ListAndWatch streaming RPC")
			return nil

		case <-stream.Context().Done():
			p.logger.Info("Kubelet disconnect detected, stopping ListAndWatch streaming RPC")
			return nil

		case devices, ok := <-p.watcher.Updates:

			// A closed channel yields nil lists forever, which would make this loop spin sending empty updates
			if !ok {
				p.logger.Info("Device watcher update channel closed, stopping ListAndWatch streaming RPC")
				return nil
			}
			p.logger.Infow("Received new device list", "devices", devices)

			// Store the device list
			p.devicesMutex.Lock()
			p.currentDevices = devices
			p.devicesMutex.Unlock()

			// Convert the device discovery devices to Kubernetes device plugin API devices
			kubeletDevices := make([]*pluginapi.Device, 0, len(devices)*int(p.config.Multitenancy))
			for _, device := range devices {

				// Advertise each device multiple times, as per our multitenancy setting
				for i := uint32(0); i < p.config.Multitenancy; i += 1 {
					kubeletDevices = append(kubeletDevices, &pluginapi.Device{
						ID:     fmt.Sprintf("%s\\%d", device.ID, i),
						Health: pluginapi.Healthy,
					})
				}
			}

			// Send the device list to the Kubelet, ending the RPC if the stream can no longer be written to
			p.logger.Infow("Sending device list to Kubelet", "devices", kubeletDevices)
			if err := stream.Send(&pluginapi.ListAndWatchResponse{Devices: kubeletDevices}); err != nil {
				p.logger.Errorw("Failed to send device list to Kubelet", "error", err)
				return err
			}
		}
	}
}
// Handles the GetPreferredAllocation RPC by always responding with an Unimplemented status
func (p *DevicePlugin) GetPreferredAllocation(context.Context, *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) {

	// GetDevicePluginOptions reports this RPC as unavailable, so it should never be called
	unimplemented := status.Error(codes.Unimplemented, "GetPreferredAllocation is not implemented")
	return nil, unimplemented
}
// Looks up a device in the most recent device list using a device ID that includes a multitenancy suffix
//
// The suffix (everything from the final backslash onwards) is removed before searching.
// Returns an error if the ID contains no backslash or if no device in the current list matches.
func (p *DevicePlugin) GetDeviceForID(deviceID string) (*discovery.Device, error) {

	// Locate the final backslash, which separates the device ID from the multitenancy suffix
	separator := strings.LastIndex(deviceID, "\\")
	if separator < 0 {
		return nil, fmt.Errorf("malformed device ID \"%s\"", deviceID)
	}
	baseID := deviceID[:separator]

	// Hold the device list mutex for the duration of the search
	p.devicesMutex.Lock()
	defer p.devicesMutex.Unlock()

	// Return the first device whose ID matches
	for i := range p.currentDevices {
		if candidate := p.currentDevices[i]; candidate.ID == baseID {
			return candidate, nil
		}
	}

	return nil, fmt.Errorf("could not find device with ID \"%s\"", baseID)
}
// Handles the Allocate RPC, producing the device specs and file mounts for each container's requested devices
//
// Returns an error, and no response, if any requested device ID cannot be resolved to a device.
func (p *DevicePlugin) Allocate(ctx context.Context, request *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) {
	p.logger.Infow("Allocate RPC invoked, processing allocation request", "request", request)

	// Build one container response per container request
	response := &pluginapi.AllocateResponse{}
	for _, containerReq := range request.ContainerRequests {

		// Resolve every requested device ID, failing the whole request if any of them is unknown
		requested := make([]*discovery.Device, 0, len(containerReq.DevicesIDs))
		for _, id := range containerReq.DevicesIDs {
			device, err := p.GetDeviceForID(id)
			if err != nil {
				return nil, err
			}
			requested = append(requested, device)
		}

		// Generate the device specs and runtime file mounts for the resolved devices
		containerResp := &pluginapi.ContainerAllocateResponse{
			Devices: mount.SpecsForDevices(requested),
			Mounts:  mount.MountsForDevices(requested),
		}
		response.ContainerResponses = append(response.ContainerResponses, containerResp)
	}

	p.logger.Infow("Sending allocation response", "response", response)
	return response, nil
}
// Handles the PreStartContainer RPC by always responding with an Unimplemented status
func (p *DevicePlugin) PreStartContainer(context.Context, *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) {

	// GetDevicePluginOptions reports that no pre-start step is required, so this RPC should never be called
	unimplemented := status.Error(codes.Unimplemented, "PreStartContainer is not implemented")
	return nil, unimplemented
}