github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/plugins/device/plugin_test.go (about)

     1  package device
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	pb "github.com/golang/protobuf/proto"
    10  	plugin "github.com/hashicorp/go-plugin"
    11  	"github.com/hashicorp/nomad/ci"
    12  	"github.com/hashicorp/nomad/helper/pointer"
    13  	"github.com/hashicorp/nomad/nomad/structs"
    14  	"github.com/hashicorp/nomad/plugins/base"
    15  	"github.com/hashicorp/nomad/plugins/shared/hclspec"
    16  	psstructs "github.com/hashicorp/nomad/plugins/shared/structs"
    17  	"github.com/hashicorp/nomad/testutil"
    18  	"github.com/stretchr/testify/require"
    19  	"github.com/zclconf/go-cty/cty"
    20  	"github.com/zclconf/go-cty/cty/msgpack"
    21  	"google.golang.org/grpc/status"
    22  )
    23  
    24  func TestDevicePlugin_PluginInfo(t *testing.T) {
    25  	ci.Parallel(t)
    26  	require := require.New(t)
    27  
    28  	var (
    29  		apiVersions = []string{"v0.1.0", "v0.2.0"}
    30  	)
    31  
    32  	const (
    33  		pluginVersion = "v0.2.1"
    34  		pluginName    = "mock_device"
    35  	)
    36  
    37  	knownType := func() (*base.PluginInfoResponse, error) {
    38  		info := &base.PluginInfoResponse{
    39  			Type:              base.PluginTypeDevice,
    40  			PluginApiVersions: apiVersions,
    41  			PluginVersion:     pluginVersion,
    42  			Name:              pluginName,
    43  		}
    44  		return info, nil
    45  	}
    46  	unknownType := func() (*base.PluginInfoResponse, error) {
    47  		info := &base.PluginInfoResponse{
    48  			Type:              "bad",
    49  			PluginApiVersions: apiVersions,
    50  			PluginVersion:     pluginVersion,
    51  			Name:              pluginName,
    52  		}
    53  		return info, nil
    54  	}
    55  
    56  	mock := &MockDevicePlugin{
    57  		MockPlugin: &base.MockPlugin{
    58  			PluginInfoF: knownType,
    59  		},
    60  	}
    61  
    62  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
    63  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
    64  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
    65  	})
    66  	defer server.Stop()
    67  	defer client.Close()
    68  
    69  	raw, err := client.Dispense(base.PluginTypeDevice)
    70  	if err != nil {
    71  		t.Fatalf("err: %s", err)
    72  	}
    73  
    74  	impl, ok := raw.(DevicePlugin)
    75  	if !ok {
    76  		t.Fatalf("bad: %#v", raw)
    77  	}
    78  
    79  	resp, err := impl.PluginInfo()
    80  	require.NoError(err)
    81  	require.Equal(apiVersions, resp.PluginApiVersions)
    82  	require.Equal(pluginVersion, resp.PluginVersion)
    83  	require.Equal(pluginName, resp.Name)
    84  	require.Equal(base.PluginTypeDevice, resp.Type)
    85  
    86  	// Swap the implementation to return an unknown type
    87  	mock.PluginInfoF = unknownType
    88  	_, err = impl.PluginInfo()
    89  	require.Error(err)
    90  	require.Contains(err.Error(), "unknown type")
    91  }
    92  
    93  func TestDevicePlugin_ConfigSchema(t *testing.T) {
    94  	ci.Parallel(t)
    95  	require := require.New(t)
    96  
    97  	mock := &MockDevicePlugin{
    98  		MockPlugin: &base.MockPlugin{
    99  			ConfigSchemaF: func() (*hclspec.Spec, error) {
   100  				return base.TestSpec, nil
   101  			},
   102  		},
   103  	}
   104  
   105  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   106  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   107  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   108  	})
   109  	defer server.Stop()
   110  	defer client.Close()
   111  
   112  	raw, err := client.Dispense(base.PluginTypeDevice)
   113  	if err != nil {
   114  		t.Fatalf("err: %s", err)
   115  	}
   116  
   117  	impl, ok := raw.(DevicePlugin)
   118  	if !ok {
   119  		t.Fatalf("bad: %#v", raw)
   120  	}
   121  
   122  	specOut, err := impl.ConfigSchema()
   123  	require.NoError(err)
   124  	require.True(pb.Equal(base.TestSpec, specOut))
   125  }
   126  
   127  func TestDevicePlugin_SetConfig(t *testing.T) {
   128  	ci.Parallel(t)
   129  	require := require.New(t)
   130  
   131  	var receivedData []byte
   132  	mock := &MockDevicePlugin{
   133  		MockPlugin: &base.MockPlugin{
   134  			PluginInfoF: func() (*base.PluginInfoResponse, error) {
   135  				return &base.PluginInfoResponse{
   136  					Type:              base.PluginTypeDevice,
   137  					PluginApiVersions: []string{"v0.0.1"},
   138  					PluginVersion:     "v0.0.1",
   139  					Name:              "mock_device",
   140  				}, nil
   141  			},
   142  			ConfigSchemaF: func() (*hclspec.Spec, error) {
   143  				return base.TestSpec, nil
   144  			},
   145  			SetConfigF: func(cfg *base.Config) error {
   146  				receivedData = cfg.PluginConfig
   147  				return nil
   148  			},
   149  		},
   150  	}
   151  
   152  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   153  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   154  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   155  	})
   156  	defer server.Stop()
   157  	defer client.Close()
   158  
   159  	raw, err := client.Dispense(base.PluginTypeDevice)
   160  	if err != nil {
   161  		t.Fatalf("err: %s", err)
   162  	}
   163  
   164  	impl, ok := raw.(DevicePlugin)
   165  	if !ok {
   166  		t.Fatalf("bad: %#v", raw)
   167  	}
   168  
   169  	config := cty.ObjectVal(map[string]cty.Value{
   170  		"foo": cty.StringVal("v1"),
   171  		"bar": cty.NumberIntVal(1337),
   172  		"baz": cty.BoolVal(true),
   173  	})
   174  	cdata, err := msgpack.Marshal(config, config.Type())
   175  	require.NoError(err)
   176  	require.NoError(impl.SetConfig(&base.Config{PluginConfig: cdata}))
   177  	require.Equal(cdata, receivedData)
   178  
   179  	// Decode the value back
   180  	var actual base.TestConfig
   181  	require.NoError(structs.Decode(receivedData, &actual))
   182  	require.Equal("v1", actual.Foo)
   183  	require.EqualValues(1337, actual.Bar)
   184  	require.True(actual.Baz)
   185  }
   186  
   187  func TestDevicePlugin_Fingerprint(t *testing.T) {
   188  	ci.Parallel(t)
   189  	require := require.New(t)
   190  
   191  	devices1 := []*DeviceGroup{
   192  		{
   193  			Vendor: "nvidia",
   194  			Type:   DeviceTypeGPU,
   195  			Name:   "foo",
   196  			Attributes: map[string]*psstructs.Attribute{
   197  				"memory": {
   198  					Int:  pointer.Of(int64(4)),
   199  					Unit: "GiB",
   200  				},
   201  			},
   202  		},
   203  	}
   204  	devices2 := []*DeviceGroup{
   205  		{
   206  			Vendor: "nvidia",
   207  			Type:   DeviceTypeGPU,
   208  			Name:   "foo",
   209  		},
   210  		{
   211  			Vendor: "nvidia",
   212  			Type:   DeviceTypeGPU,
   213  			Name:   "bar",
   214  		},
   215  	}
   216  
   217  	mock := &MockDevicePlugin{
   218  		FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
   219  			outCh := make(chan *FingerprintResponse, 1)
   220  			go func() {
   221  				// Send two messages
   222  				for _, devs := range [][]*DeviceGroup{devices1, devices2} {
   223  					select {
   224  					case <-ctx.Done():
   225  						return
   226  					case outCh <- &FingerprintResponse{Devices: devs}:
   227  					}
   228  				}
   229  				close(outCh)
   230  				return
   231  			}()
   232  			return outCh, nil
   233  		},
   234  	}
   235  
   236  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   237  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   238  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   239  	})
   240  	defer server.Stop()
   241  	defer client.Close()
   242  
   243  	raw, err := client.Dispense(base.PluginTypeDevice)
   244  	if err != nil {
   245  		t.Fatalf("err: %s", err)
   246  	}
   247  
   248  	impl, ok := raw.(DevicePlugin)
   249  	if !ok {
   250  		t.Fatalf("bad: %#v", raw)
   251  	}
   252  
   253  	// Create a context
   254  	ctx, cancel := context.WithCancel(context.Background())
   255  	defer cancel()
   256  
   257  	// Get the stream
   258  	stream, err := impl.Fingerprint(ctx)
   259  	require.NoError(err)
   260  
   261  	// Get the first message
   262  	var first *FingerprintResponse
   263  	select {
   264  	case <-time.After(1 * time.Second):
   265  		t.Fatal("timeout")
   266  	case first = <-stream:
   267  	}
   268  
   269  	require.NoError(first.Error)
   270  	require.EqualValues(devices1, first.Devices)
   271  
   272  	// Get the second message
   273  	var second *FingerprintResponse
   274  	select {
   275  	case <-time.After(1 * time.Second):
   276  		t.Fatal("timeout")
   277  	case second = <-stream:
   278  	}
   279  
   280  	require.NoError(second.Error)
   281  	require.EqualValues(devices2, second.Devices)
   282  
   283  	select {
   284  	case _, ok := <-stream:
   285  		require.False(ok)
   286  	case <-time.After(1 * time.Second):
   287  		t.Fatal("stream should be closed")
   288  	}
   289  }
   290  
   291  func TestDevicePlugin_Fingerprint_StreamErr(t *testing.T) {
   292  	ci.Parallel(t)
   293  	require := require.New(t)
   294  
   295  	ferr := fmt.Errorf("mock fingerprinting failed")
   296  	mock := &MockDevicePlugin{
   297  		FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
   298  			outCh := make(chan *FingerprintResponse, 1)
   299  			go func() {
   300  				// Send the error
   301  				select {
   302  				case <-ctx.Done():
   303  					return
   304  				case outCh <- &FingerprintResponse{Error: ferr}:
   305  				}
   306  
   307  				close(outCh)
   308  				return
   309  			}()
   310  			return outCh, nil
   311  		},
   312  	}
   313  
   314  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   315  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   316  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   317  	})
   318  	defer server.Stop()
   319  	defer client.Close()
   320  
   321  	raw, err := client.Dispense(base.PluginTypeDevice)
   322  	if err != nil {
   323  		t.Fatalf("err: %s", err)
   324  	}
   325  
   326  	impl, ok := raw.(DevicePlugin)
   327  	if !ok {
   328  		t.Fatalf("bad: %#v", raw)
   329  	}
   330  
   331  	// Create a context
   332  	ctx, cancel := context.WithCancel(context.Background())
   333  	defer cancel()
   334  
   335  	// Get the stream
   336  	stream, err := impl.Fingerprint(ctx)
   337  	require.NoError(err)
   338  
   339  	// Get the first message
   340  	var first *FingerprintResponse
   341  	select {
   342  	case <-time.After(1 * time.Second):
   343  		t.Fatal("timeout")
   344  	case first = <-stream:
   345  	}
   346  
   347  	errStatus := status.Convert(ferr)
   348  	require.EqualError(first.Error, errStatus.Err().Error())
   349  }
   350  
   351  func TestDevicePlugin_Fingerprint_CancelCtx(t *testing.T) {
   352  	ci.Parallel(t)
   353  	require := require.New(t)
   354  
   355  	mock := &MockDevicePlugin{
   356  		FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
   357  			outCh := make(chan *FingerprintResponse, 1)
   358  			go func() {
   359  				<-ctx.Done()
   360  				close(outCh)
   361  				return
   362  			}()
   363  			return outCh, nil
   364  		},
   365  	}
   366  
   367  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   368  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   369  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   370  	})
   371  	defer server.Stop()
   372  	defer client.Close()
   373  
   374  	raw, err := client.Dispense(base.PluginTypeDevice)
   375  	if err != nil {
   376  		t.Fatalf("err: %s", err)
   377  	}
   378  
   379  	impl, ok := raw.(DevicePlugin)
   380  	if !ok {
   381  		t.Fatalf("bad: %#v", raw)
   382  	}
   383  
   384  	// Create a context
   385  	ctx, cancel := context.WithCancel(context.Background())
   386  
   387  	// Get the stream
   388  	stream, err := impl.Fingerprint(ctx)
   389  	require.NoError(err)
   390  
   391  	// Get the first message
   392  	select {
   393  	case <-time.After(testutil.Timeout(10 * time.Millisecond)):
   394  	case _ = <-stream:
   395  		t.Fatal("bad value")
   396  	}
   397  
   398  	// Cancel the context
   399  	cancel()
   400  
   401  	// Make sure we are done
   402  	select {
   403  	case <-time.After(100 * time.Millisecond):
   404  		t.Fatalf("timeout")
   405  	case v := <-stream:
   406  		require.Error(v.Error)
   407  		require.EqualError(v.Error, context.Canceled.Error())
   408  	}
   409  }
   410  
   411  func TestDevicePlugin_Reserve(t *testing.T) {
   412  	ci.Parallel(t)
   413  	require := require.New(t)
   414  
   415  	reservation := &ContainerReservation{
   416  		Envs: map[string]string{
   417  			"foo": "bar",
   418  		},
   419  		Mounts: []*Mount{
   420  			{
   421  				TaskPath: "foo",
   422  				HostPath: "bar",
   423  				ReadOnly: true,
   424  			},
   425  		},
   426  		Devices: []*DeviceSpec{
   427  			{
   428  				TaskPath:    "foo",
   429  				HostPath:    "bar",
   430  				CgroupPerms: "rx",
   431  			},
   432  		},
   433  	}
   434  
   435  	var received []string
   436  	mock := &MockDevicePlugin{
   437  		ReserveF: func(devices []string) (*ContainerReservation, error) {
   438  			received = devices
   439  			return reservation, nil
   440  		},
   441  	}
   442  
   443  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   444  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   445  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   446  	})
   447  	defer server.Stop()
   448  	defer client.Close()
   449  
   450  	raw, err := client.Dispense(base.PluginTypeDevice)
   451  	if err != nil {
   452  		t.Fatalf("err: %s", err)
   453  	}
   454  
   455  	impl, ok := raw.(DevicePlugin)
   456  	if !ok {
   457  		t.Fatalf("bad: %#v", raw)
   458  	}
   459  
   460  	req := []string{"a", "b"}
   461  	containerRes, err := impl.Reserve(req)
   462  	require.NoError(err)
   463  	require.EqualValues(req, received)
   464  	require.EqualValues(reservation, containerRes)
   465  }
   466  
   467  func TestDevicePlugin_Stats(t *testing.T) {
   468  	ci.Parallel(t)
   469  	require := require.New(t)
   470  
   471  	devices1 := []*DeviceGroupStats{
   472  		{
   473  			Vendor: "nvidia",
   474  			Type:   DeviceTypeGPU,
   475  			Name:   "foo",
   476  			InstanceStats: map[string]*DeviceStats{
   477  				"1": {
   478  					Summary: &psstructs.StatValue{
   479  						IntNumeratorVal:   pointer.Of(int64(10)),
   480  						IntDenominatorVal: pointer.Of(int64(20)),
   481  						Unit:              "MB",
   482  						Desc:              "Unit test",
   483  					},
   484  				},
   485  			},
   486  		},
   487  	}
   488  	devices2 := []*DeviceGroupStats{
   489  		{
   490  			Vendor: "nvidia",
   491  			Type:   DeviceTypeGPU,
   492  			Name:   "foo",
   493  			InstanceStats: map[string]*DeviceStats{
   494  				"1": {
   495  					Summary: &psstructs.StatValue{
   496  						FloatNumeratorVal:   pointer.Of(float64(10.0)),
   497  						FloatDenominatorVal: pointer.Of(float64(20.0)),
   498  						Unit:                "MB",
   499  						Desc:                "Unit test",
   500  					},
   501  				},
   502  			},
   503  		},
   504  		{
   505  			Vendor: "nvidia",
   506  			Type:   DeviceTypeGPU,
   507  			Name:   "bar",
   508  			InstanceStats: map[string]*DeviceStats{
   509  				"1": {
   510  					Summary: &psstructs.StatValue{
   511  						StringVal: pointer.Of("foo"),
   512  						Unit:      "MB",
   513  						Desc:      "Unit test",
   514  					},
   515  				},
   516  			},
   517  		},
   518  		{
   519  			Vendor: "nvidia",
   520  			Type:   DeviceTypeGPU,
   521  			Name:   "baz",
   522  			InstanceStats: map[string]*DeviceStats{
   523  				"1": {
   524  					Summary: &psstructs.StatValue{
   525  						BoolVal: pointer.Of(true),
   526  						Unit:    "MB",
   527  						Desc:    "Unit test",
   528  					},
   529  				},
   530  			},
   531  		},
   532  	}
   533  
   534  	mock := &MockDevicePlugin{
   535  		StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
   536  			outCh := make(chan *StatsResponse, 1)
   537  			go func() {
   538  				// Send two messages
   539  				for _, devs := range [][]*DeviceGroupStats{devices1, devices2} {
   540  					select {
   541  					case <-ctx.Done():
   542  						return
   543  					case outCh <- &StatsResponse{Groups: devs}:
   544  					}
   545  				}
   546  				close(outCh)
   547  				return
   548  			}()
   549  			return outCh, nil
   550  		},
   551  	}
   552  
   553  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   554  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   555  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   556  	})
   557  	defer server.Stop()
   558  	defer client.Close()
   559  
   560  	raw, err := client.Dispense(base.PluginTypeDevice)
   561  	if err != nil {
   562  		t.Fatalf("err: %s", err)
   563  	}
   564  
   565  	impl, ok := raw.(DevicePlugin)
   566  	if !ok {
   567  		t.Fatalf("bad: %#v", raw)
   568  	}
   569  
   570  	// Create a context
   571  	ctx, cancel := context.WithCancel(context.Background())
   572  	defer cancel()
   573  
   574  	// Get the stream
   575  	stream, err := impl.Stats(ctx, time.Millisecond)
   576  	require.NoError(err)
   577  
   578  	// Get the first message
   579  	var first *StatsResponse
   580  	select {
   581  	case <-time.After(1 * time.Second):
   582  		t.Fatal("timeout")
   583  	case first = <-stream:
   584  	}
   585  
   586  	require.NoError(first.Error)
   587  	require.EqualValues(devices1, first.Groups)
   588  
   589  	// Get the second message
   590  	var second *StatsResponse
   591  	select {
   592  	case <-time.After(1 * time.Second):
   593  		t.Fatal("timeout")
   594  	case second = <-stream:
   595  	}
   596  
   597  	require.NoError(second.Error)
   598  	require.EqualValues(devices2, second.Groups)
   599  
   600  	select {
   601  	case _, ok := <-stream:
   602  		require.False(ok)
   603  	case <-time.After(1 * time.Second):
   604  		t.Fatal("stream should be closed")
   605  	}
   606  }
   607  
   608  func TestDevicePlugin_Stats_StreamErr(t *testing.T) {
   609  	ci.Parallel(t)
   610  	require := require.New(t)
   611  
   612  	ferr := fmt.Errorf("mock stats failed")
   613  	mock := &MockDevicePlugin{
   614  		StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
   615  			outCh := make(chan *StatsResponse, 1)
   616  			go func() {
   617  				// Send the error
   618  				select {
   619  				case <-ctx.Done():
   620  					return
   621  				case outCh <- &StatsResponse{Error: ferr}:
   622  				}
   623  
   624  				close(outCh)
   625  				return
   626  			}()
   627  			return outCh, nil
   628  		},
   629  	}
   630  
   631  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   632  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   633  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   634  	})
   635  	defer server.Stop()
   636  	defer client.Close()
   637  
   638  	raw, err := client.Dispense(base.PluginTypeDevice)
   639  	if err != nil {
   640  		t.Fatalf("err: %s", err)
   641  	}
   642  
   643  	impl, ok := raw.(DevicePlugin)
   644  	if !ok {
   645  		t.Fatalf("bad: %#v", raw)
   646  	}
   647  
   648  	// Create a context
   649  	ctx, cancel := context.WithCancel(context.Background())
   650  	defer cancel()
   651  
   652  	// Get the stream
   653  	stream, err := impl.Stats(ctx, time.Millisecond)
   654  	require.NoError(err)
   655  
   656  	// Get the first message
   657  	var first *StatsResponse
   658  	select {
   659  	case <-time.After(1 * time.Second):
   660  		t.Fatal("timeout")
   661  	case first = <-stream:
   662  	}
   663  
   664  	errStatus := status.Convert(ferr)
   665  	require.EqualError(first.Error, errStatus.Err().Error())
   666  }
   667  
   668  func TestDevicePlugin_Stats_CancelCtx(t *testing.T) {
   669  	ci.Parallel(t)
   670  	require := require.New(t)
   671  
   672  	mock := &MockDevicePlugin{
   673  		StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
   674  			outCh := make(chan *StatsResponse, 1)
   675  			go func() {
   676  				<-ctx.Done()
   677  				close(outCh)
   678  				return
   679  			}()
   680  			return outCh, nil
   681  		},
   682  	}
   683  
   684  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   685  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   686  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   687  	})
   688  	defer server.Stop()
   689  	defer client.Close()
   690  
   691  	raw, err := client.Dispense(base.PluginTypeDevice)
   692  	if err != nil {
   693  		t.Fatalf("err: %s", err)
   694  	}
   695  
   696  	impl, ok := raw.(DevicePlugin)
   697  	if !ok {
   698  		t.Fatalf("bad: %#v", raw)
   699  	}
   700  
   701  	// Create a context
   702  	ctx, cancel := context.WithCancel(context.Background())
   703  
   704  	// Get the stream
   705  	stream, err := impl.Stats(ctx, time.Millisecond)
   706  	require.NoError(err)
   707  
   708  	// Get the first message
   709  	select {
   710  	case <-time.After(testutil.Timeout(10 * time.Millisecond)):
   711  	case _ = <-stream:
   712  		t.Fatal("bad value")
   713  	}
   714  
   715  	// Cancel the context
   716  	cancel()
   717  
   718  	// Make sure we are done
   719  	select {
   720  	case <-time.After(100 * time.Millisecond):
   721  		t.Fatalf("timeout")
   722  	case v := <-stream:
   723  		require.Error(v.Error)
   724  		require.EqualError(v.Error, context.Canceled.Error())
   725  	}
   726  }