pkg/gpu/nvidia/gpusharing/gpusharing.go (41 lines of code) (raw):
// Copyright 2021 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gpusharing
import (
"errors"
"fmt"
"regexp"
)
// GPUSharingStrategy names the GPU sharing mode that ValidateRequest enforces.
type GPUSharingStrategy string

// Recognized values for GPUSharingStrategy.
const (
	// Undefined is the zero value; ValidateRequest imposes no sharing-specific
	// limits when this strategy is selected.
	Undefined GPUSharingStrategy = ""
	// TimeSharing allows at most one virtual device per request.
	TimeSharing GPUSharingStrategy = "time-sharing"
	// MPS allows several virtual devices per request only on single-device nodes.
	MPS GPUSharingStrategy = "mps"
)

// SharingStrategy is the package-wide strategy consulted by ValidateRequest.
// It starts out as Undefined; it is not assigned anywhere in this file, so
// presumably callers set it during configuration (verify against callers).
var SharingStrategy = Undefined
// ValidateRequest reports whether a device allocation request is acceptable
// under the current SharingStrategy.
//
// A request counts as a sharing request only when it asks for more than one
// device and its first device ID is a virtual device ID (see
// IsVirtualDeviceID). Only that first ID is inspected. Every other request is
// accepted.
//
// For a sharing request:
//   - time-sharing: at most one virtual device may be requested per request.
//   - mps: several virtual devices may be requested only when deviceCount is 1;
//     when deviceCount > 1, at most one virtual device may be requested.
//
// Any other strategy, including Undefined, imposes no limit.
//
// deviceCount is presumably the number of physical devices on the node. Each
// MIG partition is meant to count as one physical device; confirm this with
// the callers.
func ValidateRequest(requestDevicesIDs []string, deviceCount int) error {
	// Single-device requests and non-virtual requests need no further checks.
	// Short-circuiting on the length also guards the index into an empty slice.
	if len(requestDevicesIDs) <= 1 || !IsVirtualDeviceID(requestDevicesIDs[0]) {
		return nil
	}

	switch {
	case SharingStrategy == TimeSharing:
		return errors.New("invalid request for sharing GPU (time-sharing), at most 1 nvidia.com/gpu can be requested on GPU nodes")
	case SharingStrategy == MPS && deviceCount > 1:
		return errors.New("invalid request for sharing GPU (MPS), at most 1 nvidia.com/gpu can be requested on multi-GPU nodes")
	default:
		return nil
	}
}
// vgpuSuffixRegex matches the trailing "/vgpu<N>" component of a virtual
// device ID. It is compiled once at package initialization instead of on
// every VirtualToPhysicalDeviceID call.
var vgpuSuffixRegex = regexp.MustCompile("/vgpu([0-9]+)$")

// VirtualToPhysicalDeviceID converts a virtual device ID into the ID of its
// underlying physical device by removing the trailing "/vgpu<N>" component.
// For example, "nvidia0/vgpu0" becomes "nvidia0", and the MIG form
// "nvidia0/gi0/vgpu0" becomes "nvidia0/gi0".
//
// It returns an error if virtualDeviceID is not recognized by
// IsVirtualDeviceID.
func VirtualToPhysicalDeviceID(virtualDeviceID string) (string, error) {
	if !IsVirtualDeviceID(virtualDeviceID) {
		return "", fmt.Errorf("virtual device ID (%s) is not valid", virtualDeviceID)
	}
	// The pattern is anchored with $, so it matches at most once. Element 0 of
	// the split is everything before the "/vgpu<N>" suffix.
	return vgpuSuffixRegex.Split(virtualDeviceID, -1)[0], nil
}
// IsVirtualDeviceID reports whether deviceID names a virtual GPU device. Two
// forms are accepted: the default-mode form "nvidia<N>/vgpu<M>" and the MIG
// form "nvidia<N>/gi<G>/vgpu<M>".
func IsVirtualDeviceID(virtualDeviceID string) bool {
	if isVirtualDeviceIDForDefaultMode(virtualDeviceID) {
		return true
	}
	return isVirtualDeviceIDForMIGMode(virtualDeviceID)
}
// defaultModeVirtualIDRegex matches default-mode virtual device IDs such as
// "nvidia0/vgpu0". It is compiled once at package initialization rather than
// on every call.
var defaultModeVirtualIDRegex = regexp.MustCompile(`nvidia([0-9]+)\/vgpu([0-9]+)$`)

// isVirtualDeviceIDForDefaultMode reports whether virtualDeviceID ends with the
// default-mode virtual device form "nvidia<N>/vgpu<M>". For example,
// "nvidia0/vgpu0" has the underlying physical device ID "nvidia0".
func isVirtualDeviceIDForDefaultMode(virtualDeviceID string) bool {
	return defaultModeVirtualIDRegex.MatchString(virtualDeviceID)
}
// migModeVirtualIDRegex matches MIG virtual device IDs such as
// "nvidia0/gi0/vgpu0". It is compiled once at package initialization rather
// than on every call.
var migModeVirtualIDRegex = regexp.MustCompile(`nvidia([0-9]+)\/gi([0-9]+)\/vgpu([0-9]+)$`)

// isVirtualDeviceIDForMIGMode reports whether virtualDeviceID ends with the MIG
// virtual device form "nvidia<N>/gi<G>/vgpu<M>". For example,
// "nvidia0/gi0/vgpu0" has the underlying physical device ID "nvidia0/gi0".
func isVirtualDeviceIDForMIGMode(virtualDeviceID string) bool {
	return migModeVirtualIDRegex.MatchString(virtualDeviceID)
}