github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/cli/opts/gpus.go (about)

     1  package opts
     2  
     3  import (
     4  	"encoding/csv"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  
     9  	"github.com/docker/docker/api/types/container"
    10  	"github.com/pkg/errors"
    11  )
    12  
    13  // GpuOpts is a Value type for parsing mounts
    14  type GpuOpts struct {
    15  	values []container.DeviceRequest
    16  }
    17  
    18  func parseCount(s string) (int, error) {
    19  	if s == "all" {
    20  		return -1, nil
    21  	}
    22  	i, err := strconv.Atoi(s)
    23  	return i, errors.Wrap(err, "count must be an integer")
    24  }
    25  
    26  // Set a new mount value
    27  //
    28  //nolint:gocyclo
    29  func (o *GpuOpts) Set(value string) error {
    30  	csvReader := csv.NewReader(strings.NewReader(value))
    31  	fields, err := csvReader.Read()
    32  	if err != nil {
    33  		return err
    34  	}
    35  
    36  	req := container.DeviceRequest{}
    37  
    38  	seen := map[string]struct{}{}
    39  	// Set writable as the default
    40  	for _, field := range fields {
    41  		parts := strings.SplitN(field, "=", 2)
    42  		key := parts[0]
    43  		if _, ok := seen[key]; ok {
    44  			return fmt.Errorf("gpu request key '%s' can be specified only once", key)
    45  		}
    46  		seen[key] = struct{}{}
    47  
    48  		if len(parts) == 1 {
    49  			seen["count"] = struct{}{}
    50  			req.Count, err = parseCount(key)
    51  			if err != nil {
    52  				return err
    53  			}
    54  			continue
    55  		}
    56  
    57  		value := parts[1]
    58  		switch key {
    59  		case "driver":
    60  			req.Driver = value
    61  		case "count":
    62  			req.Count, err = parseCount(value)
    63  			if err != nil {
    64  				return err
    65  			}
    66  		case "device":
    67  			req.DeviceIDs = strings.Split(value, ",")
    68  		case "capabilities":
    69  			req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")}
    70  		case "options":
    71  			r := csv.NewReader(strings.NewReader(value))
    72  			optFields, err := r.Read()
    73  			if err != nil {
    74  				return errors.Wrap(err, "failed to read gpu options")
    75  			}
    76  			req.Options = ConvertKVStringsToMap(optFields)
    77  		default:
    78  			return fmt.Errorf("unexpected key '%s' in '%s'", key, field)
    79  		}
    80  	}
    81  
    82  	if _, ok := seen["count"]; !ok && req.DeviceIDs == nil {
    83  		req.Count = 1
    84  	}
    85  	if req.Options == nil {
    86  		req.Options = make(map[string]string)
    87  	}
    88  	if req.Capabilities == nil {
    89  		req.Capabilities = [][]string{{"gpu"}}
    90  	}
    91  
    92  	o.values = append(o.values, req)
    93  	return nil
    94  }
    95  
    96  // Type returns the type of this option
    97  func (o *GpuOpts) Type() string {
    98  	return "gpu-request"
    99  }
   100  
   101  // String returns a string repr of this option
   102  func (o *GpuOpts) String() string {
   103  	gpus := []string{}
   104  	for _, gpu := range o.values {
   105  		gpus = append(gpus, fmt.Sprintf("%v", gpu))
   106  	}
   107  	return strings.Join(gpus, ", ")
   108  }
   109  
   110  // Value returns the mounts
   111  func (o *GpuOpts) Value() []container.DeviceRequest {
   112  	return o.values
   113  }