func()

in pkg/gpu/nvidia/mig/mig.go [109:242]


// Start discovers the GPU partitions exposed under
// <procDirectory>/driver/nvidia/capabilities and records them on d.
//
// For every capabilities entry whose name matches "gpu<N>", the entries of its
// "mig" subdirectory whose names match "gi<M>" are treated as GPU instances.
// For each GPU instance, the "DeviceFileMinor" value is read from the GPU
// instance "access" file and from the "ci0/access" file (only compute instance
// 0 is considered). The corresponding device nodes
//
//	<devDirectory>/nvidia<N>
//	<devDirectory>/nvidia-caps/nvidia-cap<minor>
//
// must exist. Each discovered partition is stored under the ID
// "nvidia<N>/gi<M>" in d.gpuPartitionSpecs (device specs) and d.gpuPartitions
// (device with health and topology).
//
// An empty partitionSize means partitioning is not requested; Start returns
// nil without touching any state.
//
// Start returns an error when partitionSize is not a key of
// gpuPartitionSizeMaxCount, when any directory, access file or device node
// cannot be read or is malformed, when a GPU does not have exactly the number
// of partitions expected for partitionSize, or when the number of GPUs seen in
// the capabilities directory differs from d.discoverNumGPUs(). State recorded
// before the error occurred is not rolled back.
func (d *DeviceManager) Start(partitionSize string) error {
	if partitionSize == "" {
		return nil
	}

	maxPartitionCount, ok := gpuPartitionSizeMaxCount[partitionSize]
	if !ok {
		return fmt.Errorf("%s is not a valid GPU partition size", partitionSize)
	}

	d.gpuPartitionSpecs = make(map[string][]pluginapi.DeviceSpec)
	// Writing to a nil map panics; make sure the partition map is usable even
	// if the DeviceManager was constructed without it. An existing map is kept
	// as is.
	if d.gpuPartitions == nil {
		d.gpuPartitions = make(map[string]pluginapi.Device)
	}

	nvidiaCapDir := path.Join(d.procDirectory, "driver/nvidia/capabilities")
	capFiles, err := ioutil.ReadDir(nvidiaCapDir)
	if err != nil {
		return fmt.Errorf("failed to read capabilities directory (%s): %v", nvidiaCapDir, err)
	}

	gpuFileRegexp := regexp.MustCompile("gpu([0-9]+)")
	giFileRegexp := regexp.MustCompile("gi([0-9]+)")
	deviceRegexp := regexp.MustCompile("DeviceFileMinor: ([0-9]+)")

	numPartitionedGPUs := 0

	for _, capFile := range capFiles {
		gpuMatch := gpuFileRegexp.FindStringSubmatch(capFile.Name())
		if len(gpuMatch) != 2 {
			// Not a gpu, continue to next file
			continue
		}

		gpuID := gpuMatch[1]
		numPartitionedGPUs++

		giBasePath := path.Join(nvidiaCapDir, capFile.Name(), "mig")
		giFiles, err := ioutil.ReadDir(giBasePath)
		if err != nil {
			return fmt.Errorf("failed to read GPU instance capabilities dir (%s): %v", giBasePath, err)
		}

		numPartitions := 0
		for _, giFile := range giFiles {
			if !giFileRegexp.MatchString(giFile.Name()) {
				continue
			}

			numPartitions++

			gpuInstanceID := "nvidia" + gpuID + "/" + giFile.Name()

			// Minor number of the GPU instance capability device.
			giAccessFile := path.Join(giBasePath, giFile.Name(), "access")
			content, err := ioutil.ReadFile(giAccessFile)
			if err != nil {
				return fmt.Errorf("failed to read GPU Instance access file (%s): %v", giAccessFile, err)
			}

			giMatch := deviceRegexp.FindStringSubmatch(string(content))
			if len(giMatch) != 2 {
				// err is nil here (the read succeeded), so report the
				// offending contents instead.
				return fmt.Errorf("unexpected contents in GPU instance access file(%s): %q", giAccessFile, string(content))
			}
			giMinorDevice, err := strconv.Atoi(giMatch[1])
			if err != nil {
				return fmt.Errorf("failed to find minor device from GPU instance access file(%s): %v", giAccessFile, err)
			}

			// Minor number of the compute instance capability device; only
			// compute instance 0 of each GPU instance is used.
			ciAccessFile := path.Join(giBasePath, giFile.Name(), "ci0", "access")
			content, err = ioutil.ReadFile(ciAccessFile)
			if err != nil {
				return fmt.Errorf("unable to read Compute Instance access file (%s): %v", ciAccessFile, err)
			}

			ciMatch := deviceRegexp.FindStringSubmatch(string(content))
			if len(ciMatch) != 2 {
				// As above: err is nil at this point, so show the contents.
				return fmt.Errorf("unexpected contents in compute instance access file(%s): %q", ciAccessFile, string(content))
			}
			ciMinorDevice, err := strconv.Atoi(ciMatch[1])
			if err != nil {
				return fmt.Errorf("failed to find minor device from compute instance access file(%s): %v", ciAccessFile, err)
			}

			// All three device nodes must exist before the partition is
			// recorded.
			gpuDevice := path.Join(d.devDirectory, "nvidia"+gpuID)
			if _, err := os.Stat(gpuDevice); err != nil {
				return fmt.Errorf("GPU device (%s) not fount: %v", gpuDevice, err)
			}

			giDevice := path.Join(d.devDirectory, "nvidia-caps", "nvidia-cap"+strconv.Itoa(giMinorDevice))
			if _, err := os.Stat(giDevice); err != nil {
				return fmt.Errorf("GPU instance device (%s) not fount: %v", giDevice, err)
			}

			ciDevice := path.Join(d.devDirectory, "nvidia-caps", "nvidia-cap"+strconv.Itoa(ciMinorDevice))
			if _, err := os.Stat(ciDevice); err != nil {
				return fmt.Errorf("Compute instance device (%s) not fount: %v", ciDevice, err)
			}

			glog.Infof("Discovered GPU partition: %s", gpuInstanceID)
			d.gpuPartitionSpecs[gpuInstanceID] = []pluginapi.DeviceSpec{
				{
					ContainerPath: gpuDevice,
					HostPath:      gpuDevice,
					Permissions:   "mrw",
				},
				{
					ContainerPath: giDevice,
					HostPath:      giDevice,
					Permissions:   "mrw",
				},
				{
					ContainerPath: ciDevice,
					HostPath:      ciDevice,
					Permissions:   "mrw",
				},
			}

			// A topology lookup failure is logged but does not abort
			// discovery; the partition is registered with whatever
			// d.topology returned alongside the error.
			topologyInfo, err := d.topology(gpuID)
			if err != nil {
				// gpuID is a string, so it must be formatted with %s.
				glog.Errorf("unable to get topology for device with index %s: %v", gpuID, err)
			}
			d.gpuPartitions[gpuInstanceID] = pluginapi.Device{ID: gpuInstanceID, Health: pluginapi.Healthy, Topology: topologyInfo}
		}

		if numPartitions != maxPartitionCount {
			return fmt.Errorf("Number of partitions (%d) for GPU %s does not match expected partition count (%d)", numPartitions, gpuID, maxPartitionCount)
		}
	}

	// Every GPU reported by discoverNumGPUs must have appeared in the
	// capabilities directory.
	numGPUs, err := d.discoverNumGPUs()
	if err != nil {
		return err
	}
	if numPartitionedGPUs != numGPUs {
		return fmt.Errorf("Not all GPUs are partitioned as expected. Total number of GPUs: %d, number of partitioned GPUs: %d", numGPUs, numPartitionedGPUs)
	}

	return nil
}