github.com/containerd/nerdctl/v2@v2.0.0-beta.5.0.20240520001846-b5758f54fa28/pkg/cmd/container/run_gpus.go (about) 1 /* 2 Copyright The containerd Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package container 18 19 import ( 20 "encoding/csv" 21 "errors" 22 "fmt" 23 "strconv" 24 "strings" 25 26 "github.com/containerd/containerd/contrib/nvidia" 27 "github.com/containerd/containerd/oci" 28 "github.com/containerd/nerdctl/v2/pkg/rootlessutil" 29 ) 30 31 // GPUReq is a request for GPUs. 32 type GPUReq struct { 33 Count int 34 DeviceIDs []string 35 Capabilities []string 36 } 37 38 func parseGPUOpts(value []string) (res []oci.SpecOpts, _ error) { 39 for _, gpu := range value { 40 gpuOpt, err := parseGPUOpt(gpu) 41 if err != nil { 42 return nil, err 43 } 44 res = append(res, gpuOpt) 45 } 46 return res, nil 47 } 48 49 func parseGPUOpt(value string) (oci.SpecOpts, error) { 50 req, err := ParseGPUOptCSV(value) 51 if err != nil { 52 return nil, err 53 } 54 55 var gpuOpts []nvidia.Opts 56 57 if len(req.DeviceIDs) > 0 { 58 gpuOpts = append(gpuOpts, nvidia.WithDeviceUUIDs(req.DeviceIDs...)) 59 } else if req.Count > 0 { 60 var devices []int 61 for i := 0; i < req.Count; i++ { 62 devices = append(devices, i) 63 } 64 gpuOpts = append(gpuOpts, nvidia.WithDevices(devices...)) 65 } else if req.Count < 0 { 66 gpuOpts = append(gpuOpts, nvidia.WithAllDevices) 67 } 68 69 str2cap := make(map[string]nvidia.Capability) 70 for _, c := range nvidia.AllCaps() { 71 str2cap[string(c)] = c 72 } 73 var nvidiaCaps []nvidia.Capability 74 for _, c := range req.Capabilities { 75 if cp, isNvidiaCap := str2cap[c]; isNvidiaCap { 76 nvidiaCaps = append(nvidiaCaps, cp) 77 } 78 } 79 if len(nvidiaCaps) != 0 { 80 gpuOpts = append(gpuOpts, nvidia.WithCapabilities(nvidiaCaps...)) 81 } else { 82 // Add "utility", "compute" capability if unset. 83 // Please see also: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#driver-capabilities 84 gpuOpts = append(gpuOpts, nvidia.WithCapabilities(nvidia.Utility, nvidia.Compute)) 85 } 86 87 if rootlessutil.IsRootless() { 88 // "--no-cgroups" option is needed to nvidia-container-cli in rootless environment 89 // Please see also: https://github.com/moby/moby/issues/38729#issuecomment-463493866 90 gpuOpts = append(gpuOpts, nvidia.WithNoCgroups) 91 } 92 93 return nvidia.WithGPUs(gpuOpts...), nil 94 } 95 96 // ParseGPUOptCSV parses a GPU option from CSV. 97 func ParseGPUOptCSV(value string) (*GPUReq, error) { 98 csvReader := csv.NewReader(strings.NewReader(value)) 99 fields, err := csvReader.Read() 100 if err != nil { 101 return nil, err 102 } 103 104 var ( 105 req GPUReq 106 seen = map[string]struct{}{} 107 ) 108 for _, field := range fields { 109 parts := strings.SplitN(field, "=", 2) 110 key := parts[0] 111 if _, ok := seen[key]; ok { 112 return nil, fmt.Errorf("gpu request key '%s' can be specified only once", key) 113 } 114 seen[key] = struct{}{} 115 116 if len(parts) == 1 { 117 seen["count"] = struct{}{} 118 req.Count, err = parseCount(key) 119 if err != nil { 120 return nil, err 121 } 122 continue 123 } 124 125 value := parts[1] 126 switch key { 127 case "driver": 128 if value != "nvidia" { 129 return nil, fmt.Errorf("invalid driver %q: \"nvidia\" is only supported", value) 130 } 131 case "count": 132 req.Count, err = parseCount(value) 133 if err != nil { 134 return nil, err 135 } 136 case "device": 137 req.DeviceIDs = strings.Split(value, ",") 138 case "capabilities": 139 req.Capabilities = strings.Split(value, ",") 140 case "options": 141 // This option is allowed but not used for gpus. 142 // Please see also: https://github.com/moby/moby/pull/38828 143 default: 144 return nil, fmt.Errorf("unexpected key '%s' in '%s'", key, field) 145 } 146 } 147 148 if req.Count != 0 && len(req.DeviceIDs) > 0 { 149 return nil, errors.New("cannot set both Count and DeviceIDs on device request") 150 } 151 if _, ok := seen["count"]; !ok && len(req.DeviceIDs) == 0 { 152 req.Count = 1 153 } 154 155 return &req, nil 156 } 157 158 func parseCount(s string) (int, error) { 159 if s == "all" { 160 return -1, nil 161 } 162 i, err := strconv.Atoi(s) 163 if err != nil { 164 return i, fmt.Errorf("count must be an integer: %w", err) 165 } 166 return i, nil 167 }