github.com/itscaro/cli@v0.0.0-20190705081621-c9db0fe93829/opts/gpus.go (about)

     1  package opts
     2  
     3  import (
     4  	"encoding/csv"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  
     9  	"github.com/docker/docker/api/types/container"
    10  	"github.com/pkg/errors"
    11  )
    12  
    13  // GpuOpts is a Value type for parsing mounts
    14  type GpuOpts struct {
    15  	values []container.DeviceRequest
    16  }
    17  
    18  func parseCount(s string) (int, error) {
    19  	if s == "all" {
    20  		return -1, nil
    21  	}
    22  	i, err := strconv.Atoi(s)
    23  	return i, errors.Wrap(err, "count must be an integer")
    24  }
    25  
    26  // Set a new mount value
    27  // nolint: gocyclo
    28  func (o *GpuOpts) Set(value string) error {
    29  	csvReader := csv.NewReader(strings.NewReader(value))
    30  	fields, err := csvReader.Read()
    31  	if err != nil {
    32  		return err
    33  	}
    34  
    35  	req := container.DeviceRequest{}
    36  
    37  	seen := map[string]struct{}{}
    38  	// Set writable as the default
    39  	for _, field := range fields {
    40  		parts := strings.SplitN(field, "=", 2)
    41  		key := parts[0]
    42  		if _, ok := seen[key]; ok {
    43  			return fmt.Errorf("gpu request key '%s' can be specified only once", key)
    44  		}
    45  		seen[key] = struct{}{}
    46  
    47  		if len(parts) == 1 {
    48  			seen["count"] = struct{}{}
    49  			req.Count, err = parseCount(key)
    50  			if err != nil {
    51  				return err
    52  			}
    53  			continue
    54  		}
    55  
    56  		value := parts[1]
    57  		switch key {
    58  		case "driver":
    59  			req.Driver = value
    60  		case "count":
    61  			req.Count, err = parseCount(value)
    62  			if err != nil {
    63  				return err
    64  			}
    65  		case "device":
    66  			req.DeviceIDs = strings.Split(value, ",")
    67  		case "capabilities":
    68  			req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")}
    69  		case "options":
    70  			r := csv.NewReader(strings.NewReader(value))
    71  			optFields, err := r.Read()
    72  			if err != nil {
    73  				return errors.Wrap(err, "failed to read gpu options")
    74  			}
    75  			req.Options = ConvertKVStringsToMap(optFields)
    76  		default:
    77  			return fmt.Errorf("unexpected key '%s' in '%s'", key, field)
    78  		}
    79  	}
    80  
    81  	if _, ok := seen["count"]; !ok && req.DeviceIDs == nil {
    82  		req.Count = 1
    83  	}
    84  	if req.Options == nil {
    85  		req.Options = make(map[string]string)
    86  	}
    87  	if req.Capabilities == nil {
    88  		req.Capabilities = [][]string{{"gpu"}}
    89  	}
    90  
    91  	o.values = append(o.values, req)
    92  	return nil
    93  }
    94  
    95  // Type returns the type of this option
    96  func (o *GpuOpts) Type() string {
    97  	return "gpu-request"
    98  }
    99  
   100  // String returns a string repr of this option
   101  func (o *GpuOpts) String() string {
   102  	gpus := []string{}
   103  	for _, gpu := range o.values {
   104  		gpus = append(gpus, fmt.Sprintf("%v", gpu))
   105  	}
   106  	return strings.Join(gpus, ", ")
   107  }
   108  
   109  // Value returns the mounts
   110  func (o *GpuOpts) Value() []container.DeviceRequest {
   111  	return o.values
   112  }