plugins/internal/plugin/device_watcher.go (126 lines of code) (raw):

//go:build windows package plugin import ( "context" "fmt" "strings" "time" "github.com/tensorworks/directx-device-plugins/plugins/internal/discovery" "go.uber.org/zap" ) // Watches for device updates type DeviceWatcher struct { // Our interface to the underlying DeviceDiscovery object from the DirectX device discovery library deviceDiscovery *discovery.DeviceDiscovery // The filter used to control which devices are reported deviceFilter discovery.DeviceFilter // Whether to include integrated GPUs when reporting devices includeIntegrated bool // Whether to include detachable devices when reporting devices includeDetachable bool // The list of additional runtime files for each device vendor that will be added to each device's list for System32 additionalRuntimeFiles map[string][]*discovery.RuntimeFile // The list of additional runtime files for each device vendor that will be added to each device's list for SysWOW64 additionalRuntimeFilesWow64 map[string][]*discovery.RuntimeFile // The logger used to log diagnostic information logger *zap.SugaredLogger // The channel used to request a forced refresh of the device list refresh chan struct{} // The channel used to stop the device discovery goroutine shutdown chan struct{} // The channel used to report errors Errors chan error // The channel used to report device updates Updates chan []*discovery.Device } func NewDeviceWatcher( expectedVersion string, deviceFilter discovery.DeviceFilter, includeIntegrated bool, includeDetachable bool, additionalRuntimeFiles map[string][]*discovery.RuntimeFile, additionalRuntimeFilesWow64 map[string][]*discovery.RuntimeFile, logger *zap.SugaredLogger, ) (*DeviceWatcher, error) { // Attempt to load the DirectX device discovery library if err := discovery.LoadDiscoveryLibrary(); err != nil { return nil, err } // Verify that the version of the device discovery library matches our expected version libraryVersion := discovery.GetDiscoveryLibraryVersion() if libraryVersion != expectedVersion { return nil, fmt.Errorf( "device discovery library version mismatch (found %s, expected %s)", libraryVersion, expectedVersion, ) } // Enable verbose logging for the device discovery library discovery.EnableDiscoveryLogging() // Create a new DeviceDiscovery object deviceDiscovery, err := discovery.NewDeviceDiscovery() if err != nil { return nil, err } // Create the DeviceWatcher watcher := &DeviceWatcher{ deviceDiscovery: deviceDiscovery, deviceFilter: deviceFilter, includeIntegrated: includeIntegrated, includeDetachable: includeDetachable, additionalRuntimeFiles: additionalRuntimeFiles, additionalRuntimeFilesWow64: additionalRuntimeFilesWow64, logger: logger, refresh: make(chan struct{}, 1), shutdown: make(chan struct{}), Errors: make(chan error, 1), Updates: make(chan []*discovery.Device, 1), } // Start the watcher goroutine go watcher.watchDevices() return watcher, nil } // Stops our goroutine and destroys the underlying DeviceDiscovery object func (d *DeviceWatcher) Destroy() { close(d.shutdown) close(d.refresh) } // Forces a refresh of the device list, irrespective of whether the current list is stale func (d *DeviceWatcher) ForceRefresh() { d.refresh <- struct{}{} } // Merges any additional runtime files into the list for a device func (d *DeviceWatcher) mergeRuntimeFiles(device *discovery.Device) { // Determine whether we have any additional runtime files for the device vendor files, haveFiles := d.additionalRuntimeFiles[strings.ToLower(device.Vendor)] filesWow64, haveFilesWow64 := d.additionalRuntimeFilesWow64[strings.ToLower(device.Vendor)] // Merge any additions for System32 if haveFiles { ignored := device.AppendRuntimeFiles(files) for _, file := range ignored { d.logger.Infow("Ignoring additional 64-bit runtime file because it clashes with an existing filename", "file", file) } } // Merge any additions for SysWOW64 if haveFilesWow64 { ignored := device.AppendRuntimeFilesWow64(filesWow64) for _, file := range ignored { d.logger.Infow("Ignoring additional 32-bit runtime file because it clashes with an existing filename", "file", file) } } } // Refreshes the list of devices and reports the new list func (d *DeviceWatcher) refreshDevices() error { // Refresh the list of devices if err := d.deviceDiscovery.DiscoverDevices(d.deviceFilter, d.includeIntegrated, d.includeDetachable); err != nil { return err } // Process any additional runtime files for each device for _, device := range d.deviceDiscovery.Devices { d.mergeRuntimeFiles(device) } // Report the new device list d.Updates <- d.deviceDiscovery.Devices return nil } // The main device watch loop func (d *DeviceWatcher) watchDevices() { // Destroy the underlying DeviceDiscovery object when the loop completes defer d.deviceDiscovery.Destroy() // Use a context for waiting between polling operations rather than sleeping, so we remain responsive to shutdown and refresh events sleep, cancelSleep := context.WithTimeout(context.Background(), time.Second*0) defer cancelSleep() // Continue sending device updates until shutdown is requested: forceRefresh := false for { select { case <-d.shutdown: return case <-d.refresh: forceRefresh = true cancelSleep() case <-sleep.Done(): // Poll for device list changes refresh, err := d.deviceDiscovery.IsRefreshRequired() if err != nil { d.Errors <- err return } // Retrieve the updated device list if one is available or if a forced refresh has been requested if refresh || forceRefresh { if err := d.refreshDevices(); err != nil { d.Errors <- err return } } // Wait 10 seconds before polling again forceRefresh = false sleep, cancelSleep = context.WithTimeout(context.Background(), time.Second*10) defer cancelSleep() } } }