github.com/containerd/nerdctl/v2@v2.0.0-beta.5.0.20240520001846-b5758f54fa28/pkg/cmd/container/run_gpus.go (about)

     1  /*
     2     Copyright The containerd Authors.
     3  
     4     Licensed under the Apache License, Version 2.0 (the "License");
     5     you may not use this file except in compliance with the License.
     6     You may obtain a copy of the License at
     7  
     8         http://www.apache.org/licenses/LICENSE-2.0
     9  
    10     Unless required by applicable law or agreed to in writing, software
    11     distributed under the License is distributed on an "AS IS" BASIS,
    12     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13     See the License for the specific language governing permissions and
    14     limitations under the License.
    15  */
    16  
    17  package container
    18  
    19  import (
    20  	"encoding/csv"
    21  	"errors"
    22  	"fmt"
    23  	"strconv"
    24  	"strings"
    25  
    26  	"github.com/containerd/containerd/contrib/nvidia"
    27  	"github.com/containerd/containerd/oci"
    28  	"github.com/containerd/nerdctl/v2/pkg/rootlessutil"
    29  )
    30  
    31  // GPUReq is a request for GPUs.
    32  type GPUReq struct {
    33  	Count        int
    34  	DeviceIDs    []string
    35  	Capabilities []string
    36  }
    37  
    38  func parseGPUOpts(value []string) (res []oci.SpecOpts, _ error) {
    39  	for _, gpu := range value {
    40  		gpuOpt, err := parseGPUOpt(gpu)
    41  		if err != nil {
    42  			return nil, err
    43  		}
    44  		res = append(res, gpuOpt)
    45  	}
    46  	return res, nil
    47  }
    48  
    49  func parseGPUOpt(value string) (oci.SpecOpts, error) {
    50  	req, err := ParseGPUOptCSV(value)
    51  	if err != nil {
    52  		return nil, err
    53  	}
    54  
    55  	var gpuOpts []nvidia.Opts
    56  
    57  	if len(req.DeviceIDs) > 0 {
    58  		gpuOpts = append(gpuOpts, nvidia.WithDeviceUUIDs(req.DeviceIDs...))
    59  	} else if req.Count > 0 {
    60  		var devices []int
    61  		for i := 0; i < req.Count; i++ {
    62  			devices = append(devices, i)
    63  		}
    64  		gpuOpts = append(gpuOpts, nvidia.WithDevices(devices...))
    65  	} else if req.Count < 0 {
    66  		gpuOpts = append(gpuOpts, nvidia.WithAllDevices)
    67  	}
    68  
    69  	str2cap := make(map[string]nvidia.Capability)
    70  	for _, c := range nvidia.AllCaps() {
    71  		str2cap[string(c)] = c
    72  	}
    73  	var nvidiaCaps []nvidia.Capability
    74  	for _, c := range req.Capabilities {
    75  		if cp, isNvidiaCap := str2cap[c]; isNvidiaCap {
    76  			nvidiaCaps = append(nvidiaCaps, cp)
    77  		}
    78  	}
    79  	if len(nvidiaCaps) != 0 {
    80  		gpuOpts = append(gpuOpts, nvidia.WithCapabilities(nvidiaCaps...))
    81  	} else {
    82  		// Add "utility", "compute" capability if unset.
    83  		// Please see also: https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/user-guide.html#driver-capabilities
    84  		gpuOpts = append(gpuOpts, nvidia.WithCapabilities(nvidia.Utility, nvidia.Compute))
    85  	}
    86  
    87  	if rootlessutil.IsRootless() {
    88  		// "--no-cgroups" option is needed to nvidia-container-cli in rootless environment
    89  		// Please see also: https://github.com/moby/moby/issues/38729#issuecomment-463493866
    90  		gpuOpts = append(gpuOpts, nvidia.WithNoCgroups)
    91  	}
    92  
    93  	return nvidia.WithGPUs(gpuOpts...), nil
    94  }
    95  
    96  // ParseGPUOptCSV parses a GPU option from CSV.
    97  func ParseGPUOptCSV(value string) (*GPUReq, error) {
    98  	csvReader := csv.NewReader(strings.NewReader(value))
    99  	fields, err := csvReader.Read()
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	var (
   105  		req  GPUReq
   106  		seen = map[string]struct{}{}
   107  	)
   108  	for _, field := range fields {
   109  		parts := strings.SplitN(field, "=", 2)
   110  		key := parts[0]
   111  		if _, ok := seen[key]; ok {
   112  			return nil, fmt.Errorf("gpu request key '%s' can be specified only once", key)
   113  		}
   114  		seen[key] = struct{}{}
   115  
   116  		if len(parts) == 1 {
   117  			seen["count"] = struct{}{}
   118  			req.Count, err = parseCount(key)
   119  			if err != nil {
   120  				return nil, err
   121  			}
   122  			continue
   123  		}
   124  
   125  		value := parts[1]
   126  		switch key {
   127  		case "driver":
   128  			if value != "nvidia" {
   129  				return nil, fmt.Errorf("invalid driver %q: \"nvidia\" is only supported", value)
   130  			}
   131  		case "count":
   132  			req.Count, err = parseCount(value)
   133  			if err != nil {
   134  				return nil, err
   135  			}
   136  		case "device":
   137  			req.DeviceIDs = strings.Split(value, ",")
   138  		case "capabilities":
   139  			req.Capabilities = strings.Split(value, ",")
   140  		case "options":
   141  			// This option is allowed but not used for gpus.
   142  			// Please see also: https://github.com/moby/moby/pull/38828
   143  		default:
   144  			return nil, fmt.Errorf("unexpected key '%s' in '%s'", key, field)
   145  		}
   146  	}
   147  
   148  	if req.Count != 0 && len(req.DeviceIDs) > 0 {
   149  		return nil, errors.New("cannot set both Count and DeviceIDs on device request")
   150  	}
   151  	if _, ok := seen["count"]; !ok && len(req.DeviceIDs) == 0 {
   152  		req.Count = 1
   153  	}
   154  
   155  	return &req, nil
   156  }
   157  
   158  func parseCount(s string) (int, error) {
   159  	if s == "all" {
   160  		return -1, nil
   161  	}
   162  	i, err := strconv.Atoi(s)
   163  	if err != nil {
   164  		return i, fmt.Errorf("count must be an integer: %w", err)
   165  	}
   166  	return i, nil
   167  }