github.com/panekj/cli@v0.0.0-20230304125325-467dd2f3797e/opts/gpus_test.go (about)

     1  package opts
     2  
     3  import (
     4  	"testing"
     5  
     6  	"github.com/docker/docker/api/types/container"
     7  	"gotest.tools/v3/assert"
     8  	is "gotest.tools/v3/assert/cmp"
     9  )
    10  
    11  func TestGpusOptAll(t *testing.T) {
    12  	for _, testcase := range []string{
    13  		"all",
    14  		"-1",
    15  		"count=all",
    16  		"count=-1",
    17  	} {
    18  		var gpus GpuOpts
    19  		gpus.Set(testcase)
    20  		gpuReqs := gpus.Value()
    21  		assert.Assert(t, is.Len(gpuReqs, 1))
    22  		assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{
    23  			Count:        -1,
    24  			Capabilities: [][]string{{"gpu"}},
    25  			Options:      map[string]string{},
    26  		}))
    27  	}
    28  }
    29  
    30  func TestGpusOpts(t *testing.T) {
    31  	for _, testcase := range []string{
    32  		"driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"",
    33  		"1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"",
    34  		"count=1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"",
    35  		"driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\",count=1",
    36  	} {
    37  		var gpus GpuOpts
    38  		gpus.Set(testcase)
    39  		gpuReqs := gpus.Value()
    40  		assert.Assert(t, is.Len(gpuReqs, 1))
    41  		assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{
    42  			Driver:       "nvidia",
    43  			Count:        1,
    44  			Capabilities: [][]string{{"compute", "utility", "gpu"}},
    45  			Options:      map[string]string{"foo": "bar", "baz": "qux"},
    46  		}))
    47  	}
    48  }