github.com/khulnasoft/cli@v0.0.0-20240402070845-01bcad7beefa/opts/gpus.go (about)

     1  package opts
     2  
     3  import (
     4  	"encoding/csv"
     5  	"fmt"
     6  	"strconv"
     7  	"strings"
     8  
     9  	"github.com/docker/docker/api/types/container"
    10  	"github.com/pkg/errors"
    11  )
    12  
    13  // GpuOpts is a Value type for parsing mounts
    14  type GpuOpts struct {
    15  	values []container.DeviceRequest
    16  }
    17  
    // parseCount parses the value of a GPU request "count" field.
//
// The special value "all" yields -1. Any other value must parse as a base-10
// integer via strconv.Atoi. On failure the returned count is 0, and the error
// wraps the underlying *strconv.NumError so callers can inspect it with
// errors.Is / errors.As.
func parseCount(s string) (int, error) {
	if s == "all" {
		return -1, nil
	}
	i, err := strconv.Atoi(s)
	if err != nil {
		// Return the zero value instead of Atoi's result (Atoi yields a
		// clamped value on range errors) so a failed parse never leaks a
		// plausible-looking count.
		return 0, fmt.Errorf("count must be an integer: %w", err)
	}
	return i, nil
}
    25  
    26  // Set a new mount value
    27  //
    28  //nolint:gocyclo
    29  func (o *GpuOpts) Set(value string) error {
    30  	csvReader := csv.NewReader(strings.NewReader(value))
    31  	fields, err := csvReader.Read()
    32  	if err != nil {
    33  		return err
    34  	}
    35  
    36  	req := container.DeviceRequest{}
    37  
    38  	seen := map[string]struct{}{}
    39  	// Set writable as the default
    40  	for _, field := range fields {
    41  		key, val, withValue := strings.Cut(field, "=")
    42  		if _, ok := seen[key]; ok {
    43  			return fmt.Errorf("gpu request key '%s' can be specified only once", key)
    44  		}
    45  		seen[key] = struct{}{}
    46  
    47  		if !withValue {
    48  			seen["count"] = struct{}{}
    49  			req.Count, err = parseCount(key)
    50  			if err != nil {
    51  				return err
    52  			}
    53  			continue
    54  		}
    55  
    56  		switch key {
    57  		case "driver":
    58  			req.Driver = val
    59  		case "count":
    60  			req.Count, err = parseCount(val)
    61  			if err != nil {
    62  				return err
    63  			}
    64  		case "device":
    65  			req.DeviceIDs = strings.Split(val, ",")
    66  		case "capabilities":
    67  			req.Capabilities = [][]string{append(strings.Split(val, ","), "gpu")}
    68  		case "options":
    69  			r := csv.NewReader(strings.NewReader(val))
    70  			optFields, err := r.Read()
    71  			if err != nil {
    72  				return errors.Wrap(err, "failed to read gpu options")
    73  			}
    74  			req.Options = ConvertKVStringsToMap(optFields)
    75  		default:
    76  			return fmt.Errorf("unexpected key '%s' in '%s'", key, field)
    77  		}
    78  	}
    79  
    80  	if _, ok := seen["count"]; !ok && req.DeviceIDs == nil {
    81  		req.Count = 1
    82  	}
    83  	if req.Options == nil {
    84  		req.Options = make(map[string]string)
    85  	}
    86  	if req.Capabilities == nil {
    87  		req.Capabilities = [][]string{{"gpu"}}
    88  	}
    89  
    90  	o.values = append(o.values, req)
    91  	return nil
    92  }
    93  
    94  // Type returns the type of this option
    95  func (o *GpuOpts) Type() string {
    96  	return "gpu-request"
    97  }
    98  
    99  // String returns a string repr of this option
   100  func (o *GpuOpts) String() string {
   101  	gpus := []string{}
   102  	for _, gpu := range o.values {
   103  		gpus = append(gpus, fmt.Sprintf("%v", gpu))
   104  	}
   105  	return strings.Join(gpus, ", ")
   106  }
   107  
   108  // Value returns the mounts
   109  func (o *GpuOpts) Value() []container.DeviceRequest {
   110  	return o.values
   111  }