github.com/panekj/cli@v0.0.0-20230304125325-467dd2f3797e/opts/gpus_test.go (about) 1 package opts 2 3 import ( 4 "testing" 5 6 "github.com/docker/docker/api/types/container" 7 "gotest.tools/v3/assert" 8 is "gotest.tools/v3/assert/cmp" 9 ) 10 11 func TestGpusOptAll(t *testing.T) { 12 for _, testcase := range []string{ 13 "all", 14 "-1", 15 "count=all", 16 "count=-1", 17 } { 18 var gpus GpuOpts 19 gpus.Set(testcase) 20 gpuReqs := gpus.Value() 21 assert.Assert(t, is.Len(gpuReqs, 1)) 22 assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{ 23 Count: -1, 24 Capabilities: [][]string{{"gpu"}}, 25 Options: map[string]string{}, 26 })) 27 } 28 } 29 30 func TestGpusOpts(t *testing.T) { 31 for _, testcase := range []string{ 32 "driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"", 33 "1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"", 34 "count=1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"", 35 "driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\",count=1", 36 } { 37 var gpus GpuOpts 38 gpus.Set(testcase) 39 gpuReqs := gpus.Value() 40 assert.Assert(t, is.Len(gpuReqs, 1)) 41 assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{ 42 Driver: "nvidia", 43 Count: 1, 44 Capabilities: [][]string{{"compute", "utility", "gpu"}}, 45 Options: map[string]string{"foo": "bar", "baz": "qux"}, 46 })) 47 } 48 }