pkg/gpu/nvidia/beta_plugin.go (108 lines of code) (raw):

// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package nvidia import ( "fmt" "net" "time" "github.com/golang/glog" "golang.org/x/net/context" "google.golang.org/grpc" pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1" "github.com/GoogleCloudPlatform/container-engine-accelerators/pkg/gpu/nvidia/gpusharing" ) type pluginServiceV1Beta1 struct { ngm *nvidiaGPUManager } func (s *pluginServiceV1Beta1) GetDevicePluginOptions(ctx context.Context, e *pluginapi.Empty) (*pluginapi.DevicePluginOptions, error) { return &pluginapi.DevicePluginOptions{}, nil } func (s *pluginServiceV1Beta1) ListAndWatch(emtpy *pluginapi.Empty, stream pluginapi.DevicePlugin_ListAndWatchServer) error { glog.Infoln("device-plugin: ListAndWatch start") if err := s.sendDevices(stream); err != nil { return err } for { select { case d := <-s.ngm.Health: glog.Infof("device-plugin: %s device marked as %s", d.ID, d.Health) s.ngm.SetDeviceHealth(d.ID, d.Health, d.Topology) if err := s.sendDevices(stream); err != nil { return err } } } } func (s *pluginServiceV1Beta1) Allocate(ctx context.Context, requests *pluginapi.AllocateRequest) (*pluginapi.AllocateResponse, error) { resps := new(pluginapi.AllocateResponse) for _, rqt := range requests.ContainerRequests { // Validate if the request is for shared GPUs and check if the request meets the GPU sharing conditions. if err := gpusharing.ValidateRequest(rqt.DevicesIDs, len(s.ngm.ListPhysicalDevices())); err != nil { return nil, err } resp := new(pluginapi.ContainerAllocateResponse) // Add all requested devices to Allocate Response for _, id := range rqt.DevicesIDs { devices, err := s.ngm.DeviceSpec(id) if err != nil { return nil, err } for i := range devices { resp.Devices = append(resp.Devices, &devices[i]) } } // Add all default devices to Allocate Response for _, d := range s.ngm.defaultDevices { resp.Devices = append(resp.Devices, &pluginapi.DeviceSpec{ HostPath: d, ContainerPath: d, Permissions: "mrw", }) } for i := range s.ngm.mountPaths { resp.Mounts = append(resp.Mounts, &s.ngm.mountPaths[i]) } resp.Envs = s.ngm.Envs(len(rqt.DevicesIDs)) resps.ContainerResponses = append(resps.ContainerResponses, resp) } return resps, nil } func (s *pluginServiceV1Beta1) PreStartContainer(ctx context.Context, r *pluginapi.PreStartContainerRequest) (*pluginapi.PreStartContainerResponse, error) { glog.Errorf("device-plugin: PreStart should NOT be called for GKE nvidia GPU device plugin\n") return &pluginapi.PreStartContainerResponse{}, nil } func (s *pluginServiceV1Beta1) GetPreferredAllocation(context.Context, *pluginapi.PreferredAllocationRequest) (*pluginapi.PreferredAllocationResponse, error) { glog.Errorf("device-plugin: GetPreferredAllocation should NOT be called for GKE nvidia GPU device plugin\n") return &pluginapi.PreferredAllocationResponse{}, nil } func (s *pluginServiceV1Beta1) RegisterService() { pluginapi.RegisterDevicePluginServer(s.ngm.grpcServer, s) } // TODO: remove this function once we move to probe based registration. func RegisterWithV1Beta1Kubelet(kubeletEndpoint, pluginEndpoint, resourceName string) error { conn, err := grpc.Dial(kubeletEndpoint, grpc.WithInsecure(), grpc.WithDialer(func(addr string, timeout time.Duration) (net.Conn, error) { return net.DialTimeout("unix", addr, timeout) })) if err != nil { return fmt.Errorf("device-plugin: cannot connect to kubelet service: %v", err) } defer conn.Close() client := pluginapi.NewRegistrationClient(conn) request := &pluginapi.RegisterRequest{ Version: pluginapi.Version, Endpoint: pluginEndpoint, ResourceName: resourceName, } if _, err = client.Register(context.Background(), request); err != nil { return fmt.Errorf("device-plugin: cannot register to kubelet service: %v", err) } return nil } func (s *pluginServiceV1Beta1) sendDevices(stream pluginapi.DevicePlugin_ListAndWatchServer) error { resp := new(pluginapi.ListAndWatchResponse) for _, dev := range s.ngm.ListDevices() { resp.Devices = append(resp.Devices, &pluginapi.Device{ID: dev.ID, Health: dev.Health, Topology: dev.Topology}) } glog.Infof("ListAndWatch: send devices %v\n", resp) if err := stream.Send(resp); err != nil { glog.Errorf("device-plugin: cannot update device states: %v\n", err) s.ngm.grpcServer.Stop() return err } return nil }