github.com/khulnasoft/cli@v0.0.0-20240402070845-01bcad7beefa/opts/gpus.go (about) 1 package opts 2 3 import ( 4 "encoding/csv" 5 "fmt" 6 "strconv" 7 "strings" 8 9 "github.com/docker/docker/api/types/container" 10 "github.com/pkg/errors" 11 ) 12 13 // GpuOpts is a Value type for parsing mounts 14 type GpuOpts struct { 15 values []container.DeviceRequest 16 } 17 18 func parseCount(s string) (int, error) { 19 if s == "all" { 20 return -1, nil 21 } 22 i, err := strconv.Atoi(s) 23 return i, errors.Wrap(err, "count must be an integer") 24 } 25 26 // Set a new mount value 27 // 28 //nolint:gocyclo 29 func (o *GpuOpts) Set(value string) error { 30 csvReader := csv.NewReader(strings.NewReader(value)) 31 fields, err := csvReader.Read() 32 if err != nil { 33 return err 34 } 35 36 req := container.DeviceRequest{} 37 38 seen := map[string]struct{}{} 39 // Set writable as the default 40 for _, field := range fields { 41 key, val, withValue := strings.Cut(field, "=") 42 if _, ok := seen[key]; ok { 43 return fmt.Errorf("gpu request key '%s' can be specified only once", key) 44 } 45 seen[key] = struct{}{} 46 47 if !withValue { 48 seen["count"] = struct{}{} 49 req.Count, err = parseCount(key) 50 if err != nil { 51 return err 52 } 53 continue 54 } 55 56 switch key { 57 case "driver": 58 req.Driver = val 59 case "count": 60 req.Count, err = parseCount(val) 61 if err != nil { 62 return err 63 } 64 case "device": 65 req.DeviceIDs = strings.Split(val, ",") 66 case "capabilities": 67 req.Capabilities = [][]string{append(strings.Split(val, ","), "gpu")} 68 case "options": 69 r := csv.NewReader(strings.NewReader(val)) 70 optFields, err := r.Read() 71 if err != nil { 72 return errors.Wrap(err, "failed to read gpu options") 73 } 74 req.Options = ConvertKVStringsToMap(optFields) 75 default: 76 return fmt.Errorf("unexpected key '%s' in '%s'", key, field) 77 } 78 } 79 80 if _, ok := seen["count"]; !ok && req.DeviceIDs == nil { 81 req.Count = 1 82 } 83 if req.Options == nil { 84 req.Options = make(map[string]string) 85 } 86 if req.Capabilities == nil { 87 req.Capabilities = [][]string{{"gpu"}} 88 } 89 90 o.values = append(o.values, req) 91 return nil 92 } 93 94 // Type returns the type of this option 95 func (o *GpuOpts) Type() string { 96 return "gpu-request" 97 } 98 99 // String returns a string repr of this option 100 func (o *GpuOpts) String() string { 101 gpus := []string{} 102 for _, gpu := range o.values { 103 gpus = append(gpus, fmt.Sprintf("%v", gpu)) 104 } 105 return strings.Join(gpus, ", ") 106 } 107 108 // Value returns the mounts 109 func (o *GpuOpts) Value() []container.DeviceRequest { 110 return o.values 111 }