github.com/Ilhicas/nomad@v1.0.4-0.20210304152020-e86851182bc3/devices/gpu/nvidia/device_test.go (about)

     1  package nvidia
     2  
     3  import (
     4  	"testing"
     5  
     6  	hclog "github.com/hashicorp/go-hclog"
     7  	"github.com/hashicorp/nomad/devices/gpu/nvidia/nvml"
     8  	"github.com/hashicorp/nomad/plugins/device"
     9  	"github.com/stretchr/testify/require"
    10  )
    11  
    12  type MockNvmlClient struct {
    13  	FingerprintError            error
    14  	FingerprintResponseReturned *nvml.FingerprintData
    15  
    16  	StatsError            error
    17  	StatsResponseReturned []*nvml.StatsData
    18  }
    19  
    20  func (c *MockNvmlClient) GetFingerprintData() (*nvml.FingerprintData, error) {
    21  	return c.FingerprintResponseReturned, c.FingerprintError
    22  }
    23  
    24  func (c *MockNvmlClient) GetStatsData() ([]*nvml.StatsData, error) {
    25  	return c.StatsResponseReturned, c.StatsError
    26  }
    27  
    28  func TestReserve(t *testing.T) {
    29  	cases := []struct {
    30  		Name                string
    31  		ExpectedReservation *device.ContainerReservation
    32  		ExpectedError       error
    33  		Device              *NvidiaDevice
    34  		RequestedIDs        []string
    35  	}{
    36  		{
    37  			Name:                "All RequestedIDs are not managed by Device",
    38  			ExpectedReservation: nil,
    39  			ExpectedError: &reservationError{[]string{
    40  				"UUID1",
    41  				"UUID2",
    42  				"UUID3",
    43  			}},
    44  			RequestedIDs: []string{
    45  				"UUID1",
    46  				"UUID2",
    47  				"UUID3",
    48  			},
    49  			Device: &NvidiaDevice{
    50  				logger:  hclog.NewNullLogger(),
    51  				enabled: true,
    52  			},
    53  		},
    54  		{
    55  			Name:                "Some RequestedIDs are not managed by Device",
    56  			ExpectedReservation: nil,
    57  			ExpectedError: &reservationError{[]string{
    58  				"UUID1",
    59  				"UUID2",
    60  			}},
    61  			RequestedIDs: []string{
    62  				"UUID1",
    63  				"UUID2",
    64  				"UUID3",
    65  			},
    66  			Device: &NvidiaDevice{
    67  				devices: map[string]struct{}{
    68  					"UUID3": {},
    69  				},
    70  				logger:  hclog.NewNullLogger(),
    71  				enabled: true,
    72  			},
    73  		},
    74  		{
    75  			Name: "All RequestedIDs are managed by Device",
    76  			ExpectedReservation: &device.ContainerReservation{
    77  				Envs: map[string]string{
    78  					NvidiaVisibleDevices: "UUID1,UUID2,UUID3",
    79  				},
    80  			},
    81  			ExpectedError: nil,
    82  			RequestedIDs: []string{
    83  				"UUID1",
    84  				"UUID2",
    85  				"UUID3",
    86  			},
    87  			Device: &NvidiaDevice{
    88  				devices: map[string]struct{}{
    89  					"UUID1": {},
    90  					"UUID2": {},
    91  					"UUID3": {},
    92  				},
    93  				logger:  hclog.NewNullLogger(),
    94  				enabled: true,
    95  			},
    96  		},
    97  		{
    98  			Name:                "No IDs requested",
    99  			ExpectedReservation: &device.ContainerReservation{},
   100  			ExpectedError:       nil,
   101  			RequestedIDs:        nil,
   102  			Device: &NvidiaDevice{
   103  				devices: map[string]struct{}{
   104  					"UUID1": {},
   105  					"UUID2": {},
   106  					"UUID3": {},
   107  				},
   108  				logger:  hclog.NewNullLogger(),
   109  				enabled: true,
   110  			},
   111  		},
   112  		{
   113  			Name:                "Device is disabled",
   114  			ExpectedReservation: nil,
   115  			ExpectedError:       device.ErrPluginDisabled,
   116  			RequestedIDs: []string{
   117  				"UUID1",
   118  				"UUID2",
   119  				"UUID3",
   120  			},
   121  			Device: &NvidiaDevice{
   122  				devices: map[string]struct{}{
   123  					"UUID1": {},
   124  					"UUID2": {},
   125  					"UUID3": {},
   126  				},
   127  				logger:  hclog.NewNullLogger(),
   128  				enabled: false,
   129  			},
   130  		},
   131  	}
   132  
   133  	for _, c := range cases {
   134  		t.Run(c.Name, func(t *testing.T) {
   135  			actualReservation, actualError := c.Device.Reserve(c.RequestedIDs)
   136  			require.Equal(t, c.ExpectedReservation, actualReservation)
   137  			require.Equal(t, c.ExpectedError, actualError)
   138  		})
   139  	}
   140  }