plugins/internal/plugin/plugin_configuration.go (88 lines of code) (raw):
//go:build windows
package plugin
import (
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"github.com/spf13/viper"
"github.com/tensorworks/directx-device-plugins/plugins/internal/discovery"
"github.com/tensorworks/directx-device-plugins/plugins/internal/mount"
"go.uber.org/zap"
"golang.org/x/exp/maps"
)
// PluginConfig holds the configuration options that are available to a device plugin
type PluginConfig struct {

	// How many containers may access each device at the same time (a value of 1 grants exclusive access)
	Multitenancy uint32

	// Whether integrated devices (i.e. integrated GPUs) are advertised
	IncludeIntegrated bool

	// Whether detachable devices (e.g. external GPUs) are advertised
	IncludeDetachable bool

	// Extra runtime files to mount into System32, keyed by device vendor
	AdditionalMounts map[string][]*discovery.RuntimeFile

	// Extra runtime files to mount into SysWOW64, keyed by device vendor
	AdditionalMountsWow64 map[string][]*discovery.RuntimeFile
}
// appendMounts merges a default set of mounts into the supplied mounts, converting all vendor names to lower case to ensure consistency.
//
// Vendor names are treated case-insensitively: entries whose keys differ only by case are combined under a single lower-case key.
// For each vendor, the supplied mounts are listed before the defaults. Every vendor that appears in either map is present in the
// result with a non-nil (possibly empty) list. Neither input map is modified, and nil maps are treated as empty.
//
// If a single input map contains several keys that differ only by case, the relative order of their files is unspecified,
// since it follows Go's map iteration order.
func appendMounts(mounts map[string][]*discovery.RuntimeFile, defaults map[string][]*discovery.RuntimeFile) map[string][]*discovery.RuntimeFile {
	appended := make(map[string][]*discovery.RuntimeFile, len(mounts)+len(defaults))

	// Register every vendor up front so that each one maps to a non-nil list, even if it has no files
	for _, vendor := range append(maps.Keys(mounts), maps.Keys(defaults)...) {
		appended[strings.ToLower(vendor)] = []*discovery.RuntimeFile{}
	}

	// Process the supplied mounts before the defaults so that they come first in each vendor's list
	for _, source := range []map[string][]*discovery.RuntimeFile{mounts, defaults} {
		for vendor, files := range source {
			// Normalise the key as we read it, rather than looking up the lower-case name in the source map,
			// which would silently drop the files of any vendor whose key is not already lower case
			key := strings.ToLower(vendor)
			appended[key] = append(appended[key], files...)
		}
	}

	return appended
}
// LoadConfig loads the configuration data for the named plugin from the runtime environment.
//
// The multitenancy, includeIntegrated and includeDetachable options can be supplied through the environment variables
// <PLUGINNAME>_DEVICE_PLUGIN_MULTITENANCY, <PLUGINNAME>_DEVICE_PLUGIN_INCLUDE_INTEGRATED and
// <PLUGINNAME>_DEVICE_PLUGIN_INCLUDE_DETACHABLE, where <PLUGINNAME> is the upper-cased plugin name.
//
// A configuration file is also read if one is available. If <PLUGINNAME>_DEVICE_PLUGIN_CONFIG_FILE is set then it must hold
// the absolute path of an existing file, otherwise an error is returned. If it is not set then a YAML file named after the
// plugin is searched for in the current working directory and in \etc\directx-device-plugins, and its absence is not an error.
//
// The returned configuration always has a multitenancy value of at least 1, and its additional mounts include the default
// mounts for each vendor. On failure the returned configuration is nil and the error wraps the underlying cause.
func LoadConfig(pluginName string, logger *zap.SugaredLogger) (*PluginConfig, error) {

	// Set our default configuration values
	v := viper.New()
	v.SetDefault("multitenancy", 0)
	v.SetDefault("includeIntegrated", false)
	v.SetDefault("includeDetachable", false)
	v.SetDefault("additionalMounts", make(map[string][]*discovery.RuntimeFile))
	v.SetDefault("additionalMountsWow64", make(map[string][]*discovery.RuntimeFile))

	// The names of our environment variables reflect the plugin name
	envPrefix := strings.ToUpper(pluginName) + "_DEVICE_PLUGIN_"
	envBindings := [][2]string{
		{"multitenancy", "MULTITENANCY"},
		{"includeIntegrated", "INCLUDE_INTEGRATED"},
		{"includeDetachable", "INCLUDE_DETACHABLE"},
	}
	for _, binding := range envBindings {
		if err := v.BindEnv(binding[0], envPrefix+binding[1]); err != nil {
			return nil, fmt.Errorf("binding environment variable %s%s: %w", envPrefix, binding[1], err)
		}
	}

	// Check if a config file path was explicitly specified through an environment variable
	if configPath, configPathExists := os.LookupEnv(envPrefix + "CONFIG_FILE"); configPathExists {

		// Verify that the specified value is an absolute path
		if !filepath.IsAbs(configPath) {
			return nil, errors.New("configuration file path must be an absolute path")
		}

		// Verify that the specified file exists, and report any other failure to inspect it (e.g. access denied)
		// here rather than letting it resurface later as a less specific read error
		_, err := os.Stat(configPath)
		if errors.Is(err, fs.ErrNotExist) {
			return nil, fmt.Errorf("specified configuration file does not exist: %s", configPath)
		}
		if err != nil {
			return nil, fmt.Errorf("checking configuration file %q: %w", configPath, err)
		}

		// Use the specified path
		v.SetConfigFile(configPath)

	} else {

		// The default name of our YAML configuration file reflects the plugin name
		v.SetConfigName(pluginName)
		v.SetConfigType("yaml")

		// We search for the configuration file in both our global config directory and the current working directory
		v.AddConfigPath(".")
		v.AddConfigPath("\\etc\\directx-device-plugins")
	}

	// Attempt to parse our YAML configuration file if it exists
	if err := v.ReadInConfig(); err != nil {

		// A missing file is expected when searching the default locations, so only other failures are fatal
		var notFound viper.ConfigFileNotFoundError
		if !errors.As(err, &notFound) {
			return nil, fmt.Errorf("reading configuration file: %w", err)
		}
		logger.Infow("Configuration file not found, using configuration values from environment variables")
	}

	// Load the parsed configuration values into our struct
	c := &PluginConfig{}
	if err := v.Unmarshal(c); err != nil {
		return nil, fmt.Errorf("parsing configuration values: %w", err)
	}

	// Enforce a minimum value of 1 for multitenancy
	if c.Multitenancy == 0 {
		c.Multitenancy = 1
	}

	// Append our default mounts to any user-supplied values
	c.AdditionalMounts = appendMounts(c.AdditionalMounts, mount.DefaultMounts)
	c.AdditionalMountsWow64 = appendMounts(c.AdditionalMountsWow64, mount.DefaultMountsWow64)

	// Log the parsed configuration values
	logger.Infow("Parsed configuration data", "config", c)
	return c, nil
}