github.com/itscaro/cli@v0.0.0-20190705081621-c9db0fe93829/opts/gpus.go (about) 1 package opts 2 3 import ( 4 "encoding/csv" 5 "fmt" 6 "strconv" 7 "strings" 8 9 "github.com/docker/docker/api/types/container" 10 "github.com/pkg/errors" 11 ) 12 13 // GpuOpts is a Value type for parsing mounts 14 type GpuOpts struct { 15 values []container.DeviceRequest 16 } 17 18 func parseCount(s string) (int, error) { 19 if s == "all" { 20 return -1, nil 21 } 22 i, err := strconv.Atoi(s) 23 return i, errors.Wrap(err, "count must be an integer") 24 } 25 26 // Set a new mount value 27 // nolint: gocyclo 28 func (o *GpuOpts) Set(value string) error { 29 csvReader := csv.NewReader(strings.NewReader(value)) 30 fields, err := csvReader.Read() 31 if err != nil { 32 return err 33 } 34 35 req := container.DeviceRequest{} 36 37 seen := map[string]struct{}{} 38 // Set writable as the default 39 for _, field := range fields { 40 parts := strings.SplitN(field, "=", 2) 41 key := parts[0] 42 if _, ok := seen[key]; ok { 43 return fmt.Errorf("gpu request key '%s' can be specified only once", key) 44 } 45 seen[key] = struct{}{} 46 47 if len(parts) == 1 { 48 seen["count"] = struct{}{} 49 req.Count, err = parseCount(key) 50 if err != nil { 51 return err 52 } 53 continue 54 } 55 56 value := parts[1] 57 switch key { 58 case "driver": 59 req.Driver = value 60 case "count": 61 req.Count, err = parseCount(value) 62 if err != nil { 63 return err 64 } 65 case "device": 66 req.DeviceIDs = strings.Split(value, ",") 67 case "capabilities": 68 req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")} 69 case "options": 70 r := csv.NewReader(strings.NewReader(value)) 71 optFields, err := r.Read() 72 if err != nil { 73 return errors.Wrap(err, "failed to read gpu options") 74 } 75 req.Options = ConvertKVStringsToMap(optFields) 76 default: 77 return fmt.Errorf("unexpected key '%s' in '%s'", key, field) 78 } 79 } 80 81 if _, ok := seen["count"]; !ok && req.DeviceIDs == nil { 82 req.Count = 1 83 } 84 if req.Options == nil { 85 req.Options = make(map[string]string) 86 } 87 if req.Capabilities == nil { 88 req.Capabilities = [][]string{{"gpu"}} 89 } 90 91 o.values = append(o.values, req) 92 return nil 93 } 94 95 // Type returns the type of this option 96 func (o *GpuOpts) Type() string { 97 return "gpu-request" 98 } 99 100 // String returns a string repr of this option 101 func (o *GpuOpts) String() string { 102 gpus := []string{} 103 for _, gpu := range o.values { 104 gpus = append(gpus, fmt.Sprintf("%v", gpu)) 105 } 106 return strings.Join(gpus, ", ") 107 } 108 109 // Value returns the mounts 110 func (o *GpuOpts) Value() []container.DeviceRequest { 111 return o.values 112 }