pkg/agent/datamodel/gpu_components.go (94 lines of code) (raw):
package datamodel
import (
"encoding/json"
"fmt"
"strings"
"github.com/Azure/agentbaker/parts"
)
// Nvidia470CudaDriverVersion is the pinned CUDA driver bundle used for GPUs
// that require the legacy 470-series driver.
const Nvidia470CudaDriverVersion = "cuda-470.82.01"

//nolint:gochecknoglobals
var (
	// NvidiaCudaDriverVersion is the driver version parsed from the aks-gpu-cuda image tag.
	NvidiaCudaDriverVersion string
	// NvidiaGridDriverVersion is the driver version parsed from the aks-gpu-grid image tag.
	NvidiaGridDriverVersion string
	// AKSGPUCudaVersionSuffix is the image-tag suffix parsed from the aks-gpu-cuda image tag.
	AKSGPUCudaVersionSuffix string
	// AKSGPUGridVersionSuffix is the image-tag suffix parsed from the aks-gpu-grid image tag.
	AKSGPUGridVersionSuffix string
)
// componentsConfig is the subset of components.json this package cares about:
// the list of GPU driver container images.
type componentsConfig struct {
	GPUContainerImages []gpuContainerImage `json:"GPUContainerImages"`
}

// gpuContainerImage describes a single GPU driver image entry.
type gpuContainerImage struct {
	DownloadURL string     `json:"downloadURL"`
	GPUVersion  gpuVersion `json:"gpuVersion"`
}

// gpuVersion carries the renovate tracking tag and the latest
// "<driver>-<suffix>" version string for an image.
type gpuVersion struct {
	RenovateTag   string `json:"renovateTag"`
	LatestVersion string `json:"latestVersion"`
}
// LoadConfig reads the embedded common/components.json, extracts the GPU
// driver container image entries, and populates the package-level
// Nvidia*DriverVersion / AKSGPU*VersionSuffix globals. Entries whose
// LatestVersion is not exactly "<version>-<suffix>" are skipped.
func LoadConfig() error {
	// Read the embedded components.json file
	data, err := parts.Templates.ReadFile("common/components.json")
	if err != nil {
		return fmt.Errorf("failed to read components.json: %w", err)
	}
	var config componentsConfig
	if err := json.Unmarshal(data, &config); err != nil {
		return fmt.Errorf("failed to unmarshal components.json: %w", err)
	}
	const driverIndex = 0
	const suffixIndex = 1
	const expectedLength = 2
	for _, image := range config.GPUContainerImages {
		// NOTE: the local must not be named "parts" — that would shadow the
		// imported parts package used above.
		fields := strings.Split(image.GPUVersion.LatestVersion, "-")
		if len(fields) != expectedLength {
			continue
		}
		version, suffix := fields[driverIndex], fields[suffixIndex]
		switch {
		case strings.Contains(image.DownloadURL, "aks-gpu-cuda"):
			NvidiaCudaDriverVersion = version
			AKSGPUCudaVersionSuffix = suffix
		case strings.Contains(image.DownloadURL, "aks-gpu-grid"):
			NvidiaGridDriverVersion = version
			AKSGPUGridVersionSuffix = suffix
		}
	}
	return nil
}
//nolint:gochecknoinits
// init eagerly loads the GPU component configuration; a malformed or missing
// embedded components.json is unrecoverable, so it panics.
func init() {
	err := LoadConfig()
	if err == nil {
		return
	}
	panic(fmt.Sprintf("Failed to load configuration: %v", err))
}
/* ConvergedGPUDriverSizes : these sizes use a "converged" driver to support both cuda/grid workloads.
   how do you figure this out? ask HPC or find out by trial and error.
   installing vanilla cuda drivers will fail to install with opaque errors.
   nvidia-bug-report.sh may be helpful, but usually it tells you the pci card id is incompatible.
   That sends me to HPC folks.
   see https://github.com/Azure/azhpc-extensions/blob/daaefd78df6f27012caf30f3b54c3bd6dc437652/NvidiaGPU/resources.json
*/
//nolint:gochecknoglobals
var ConvergedGPUDriverSizes = map[string]bool{
	// NC A10 v4 family.
	"standard_nc8ads_a10_v4":   true,
	"standard_nc16ads_a10_v4":  true,
	"standard_nc32ads_a10_v4":  true,
	// NV A10 v5 family.
	"standard_nv6ads_a10_v5":   true,
	"standard_nv12ads_a10_v5":  true,
	"standard_nv18ads_a10_v5":  true,
	"standard_nv36ads_a10_v5":  true,
	"standard_nv36adms_a10_v5": true,
	"standard_nv72ads_a10_v5":  true,
}
//nolint:gochecknoglobals
// FabricManagerGPUSizes maps VM sizes to whether nvidia-fabricmanager should
// run on them. Entries explicitly set to false are sizes with NVLink-class
// GPUs where fabricmanager is known NOT to start.
var FabricManagerGPUSizes = map[string]bool{
	// A100
	"standard_nd96asr_v4":        true,
	"standard_nd96amsr_a100_v4":  true,
	"standard_nd112asr_a100_v4":  true,
	"standard_nd112amsr_a100_v4": true,
	"standard_nd120asr_a100_v4":  true,
	"standard_nd120amsr_a100_v4": true,
	// TODO(ace): one of these is probably dupe...
	// confirm with HPC/SKU owners.
	"standard_nd96ams_a100_v4": true,
	"standard_nd96ams_v4":      true,
	// H100.
	"standard_nd46s_h100_v5":    true,
	"standard_nd48s_h100_v5":    true,
	"standard_nd50s_h100_v5":    true,
	"standard_nd92is_h100_v5":   true,
	"standard_nd92isr_h100_v5":  true,
	"standard_nd96is_h100_v5":   true,
	"standard_nd96isr_h100_v5":  true,
	"standard_nd100is_h100_v5":  true,
	"standard_nd100isr_h100_v5": true,
	// H200
	"standard_nd96is_h200_v5":   true,
	"standard_nd96isr_h200_v5":  true,
	"standard_nd96isrf_h200_v5": true,
	// A100 oddballs: NCads_v4 will fail to start fabricmanager.
	"standard_nc24ads_a100_v4": false,
	"standard_nc48ads_a100_v4": false,
	"standard_nc96ads_a100_v4": false,
}