github.com/manicqin/nomad@v0.9.5/plugins/device/plugin_test.go (about)

     1  package device
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"testing"
     7  	"time"
     8  
     9  	pb "github.com/golang/protobuf/proto"
    10  	plugin "github.com/hashicorp/go-plugin"
    11  	"github.com/hashicorp/nomad/helper"
    12  	"github.com/hashicorp/nomad/nomad/structs"
    13  	"github.com/hashicorp/nomad/plugins/base"
    14  	"github.com/hashicorp/nomad/plugins/shared/hclspec"
    15  	psstructs "github.com/hashicorp/nomad/plugins/shared/structs"
    16  	"github.com/hashicorp/nomad/testutil"
    17  	"github.com/stretchr/testify/require"
    18  	"github.com/zclconf/go-cty/cty"
    19  	"github.com/zclconf/go-cty/cty/msgpack"
    20  	"google.golang.org/grpc/status"
    21  )
    22  
    23  func TestDevicePlugin_PluginInfo(t *testing.T) {
    24  	t.Parallel()
    25  	require := require.New(t)
    26  
    27  	var (
    28  		apiVersions = []string{"v0.1.0", "v0.2.0"}
    29  	)
    30  
    31  	const (
    32  		pluginVersion = "v0.2.1"
    33  		pluginName    = "mock_device"
    34  	)
    35  
    36  	knownType := func() (*base.PluginInfoResponse, error) {
    37  		info := &base.PluginInfoResponse{
    38  			Type:              base.PluginTypeDevice,
    39  			PluginApiVersions: apiVersions,
    40  			PluginVersion:     pluginVersion,
    41  			Name:              pluginName,
    42  		}
    43  		return info, nil
    44  	}
    45  	unknownType := func() (*base.PluginInfoResponse, error) {
    46  		info := &base.PluginInfoResponse{
    47  			Type:              "bad",
    48  			PluginApiVersions: apiVersions,
    49  			PluginVersion:     pluginVersion,
    50  			Name:              pluginName,
    51  		}
    52  		return info, nil
    53  	}
    54  
    55  	mock := &MockDevicePlugin{
    56  		MockPlugin: &base.MockPlugin{
    57  			PluginInfoF: knownType,
    58  		},
    59  	}
    60  
    61  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
    62  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
    63  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
    64  	})
    65  	defer server.Stop()
    66  	defer client.Close()
    67  
    68  	raw, err := client.Dispense(base.PluginTypeDevice)
    69  	if err != nil {
    70  		t.Fatalf("err: %s", err)
    71  	}
    72  
    73  	impl, ok := raw.(DevicePlugin)
    74  	if !ok {
    75  		t.Fatalf("bad: %#v", raw)
    76  	}
    77  
    78  	resp, err := impl.PluginInfo()
    79  	require.NoError(err)
    80  	require.Equal(apiVersions, resp.PluginApiVersions)
    81  	require.Equal(pluginVersion, resp.PluginVersion)
    82  	require.Equal(pluginName, resp.Name)
    83  	require.Equal(base.PluginTypeDevice, resp.Type)
    84  
    85  	// Swap the implementation to return an unknown type
    86  	mock.PluginInfoF = unknownType
    87  	_, err = impl.PluginInfo()
    88  	require.Error(err)
    89  	require.Contains(err.Error(), "unknown type")
    90  }
    91  
    92  func TestDevicePlugin_ConfigSchema(t *testing.T) {
    93  	t.Parallel()
    94  	require := require.New(t)
    95  
    96  	mock := &MockDevicePlugin{
    97  		MockPlugin: &base.MockPlugin{
    98  			ConfigSchemaF: func() (*hclspec.Spec, error) {
    99  				return base.TestSpec, nil
   100  			},
   101  		},
   102  	}
   103  
   104  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   105  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   106  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   107  	})
   108  	defer server.Stop()
   109  	defer client.Close()
   110  
   111  	raw, err := client.Dispense(base.PluginTypeDevice)
   112  	if err != nil {
   113  		t.Fatalf("err: %s", err)
   114  	}
   115  
   116  	impl, ok := raw.(DevicePlugin)
   117  	if !ok {
   118  		t.Fatalf("bad: %#v", raw)
   119  	}
   120  
   121  	specOut, err := impl.ConfigSchema()
   122  	require.NoError(err)
   123  	require.True(pb.Equal(base.TestSpec, specOut))
   124  }
   125  
   126  func TestDevicePlugin_SetConfig(t *testing.T) {
   127  	t.Parallel()
   128  	require := require.New(t)
   129  
   130  	var receivedData []byte
   131  	mock := &MockDevicePlugin{
   132  		MockPlugin: &base.MockPlugin{
   133  			PluginInfoF: func() (*base.PluginInfoResponse, error) {
   134  				return &base.PluginInfoResponse{
   135  					Type:              base.PluginTypeDevice,
   136  					PluginApiVersions: []string{"v0.0.1"},
   137  					PluginVersion:     "v0.0.1",
   138  					Name:              "mock_device",
   139  				}, nil
   140  			},
   141  			ConfigSchemaF: func() (*hclspec.Spec, error) {
   142  				return base.TestSpec, nil
   143  			},
   144  			SetConfigF: func(cfg *base.Config) error {
   145  				receivedData = cfg.PluginConfig
   146  				return nil
   147  			},
   148  		},
   149  	}
   150  
   151  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   152  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   153  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   154  	})
   155  	defer server.Stop()
   156  	defer client.Close()
   157  
   158  	raw, err := client.Dispense(base.PluginTypeDevice)
   159  	if err != nil {
   160  		t.Fatalf("err: %s", err)
   161  	}
   162  
   163  	impl, ok := raw.(DevicePlugin)
   164  	if !ok {
   165  		t.Fatalf("bad: %#v", raw)
   166  	}
   167  
   168  	config := cty.ObjectVal(map[string]cty.Value{
   169  		"foo": cty.StringVal("v1"),
   170  		"bar": cty.NumberIntVal(1337),
   171  		"baz": cty.BoolVal(true),
   172  	})
   173  	cdata, err := msgpack.Marshal(config, config.Type())
   174  	require.NoError(err)
   175  	require.NoError(impl.SetConfig(&base.Config{PluginConfig: cdata}))
   176  	require.Equal(cdata, receivedData)
   177  
   178  	// Decode the value back
   179  	var actual base.TestConfig
   180  	require.NoError(structs.Decode(receivedData, &actual))
   181  	require.Equal("v1", actual.Foo)
   182  	require.EqualValues(1337, actual.Bar)
   183  	require.True(actual.Baz)
   184  }
   185  
   186  func TestDevicePlugin_Fingerprint(t *testing.T) {
   187  	t.Parallel()
   188  	require := require.New(t)
   189  
   190  	devices1 := []*DeviceGroup{
   191  		{
   192  			Vendor: "nvidia",
   193  			Type:   DeviceTypeGPU,
   194  			Name:   "foo",
   195  			Attributes: map[string]*psstructs.Attribute{
   196  				"memory": {
   197  					Int:  helper.Int64ToPtr(4),
   198  					Unit: "GiB",
   199  				},
   200  			},
   201  		},
   202  	}
   203  	devices2 := []*DeviceGroup{
   204  		{
   205  			Vendor: "nvidia",
   206  			Type:   DeviceTypeGPU,
   207  			Name:   "foo",
   208  		},
   209  		{
   210  			Vendor: "nvidia",
   211  			Type:   DeviceTypeGPU,
   212  			Name:   "bar",
   213  		},
   214  	}
   215  
   216  	mock := &MockDevicePlugin{
   217  		FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
   218  			outCh := make(chan *FingerprintResponse, 1)
   219  			go func() {
   220  				// Send two messages
   221  				for _, devs := range [][]*DeviceGroup{devices1, devices2} {
   222  					select {
   223  					case <-ctx.Done():
   224  						return
   225  					case outCh <- &FingerprintResponse{Devices: devs}:
   226  					}
   227  				}
   228  				close(outCh)
   229  				return
   230  			}()
   231  			return outCh, nil
   232  		},
   233  	}
   234  
   235  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   236  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   237  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   238  	})
   239  	defer server.Stop()
   240  	defer client.Close()
   241  
   242  	raw, err := client.Dispense(base.PluginTypeDevice)
   243  	if err != nil {
   244  		t.Fatalf("err: %s", err)
   245  	}
   246  
   247  	impl, ok := raw.(DevicePlugin)
   248  	if !ok {
   249  		t.Fatalf("bad: %#v", raw)
   250  	}
   251  
   252  	// Create a context
   253  	ctx, cancel := context.WithCancel(context.Background())
   254  	defer cancel()
   255  
   256  	// Get the stream
   257  	stream, err := impl.Fingerprint(ctx)
   258  	require.NoError(err)
   259  
   260  	// Get the first message
   261  	var first *FingerprintResponse
   262  	select {
   263  	case <-time.After(1 * time.Second):
   264  		t.Fatal("timeout")
   265  	case first = <-stream:
   266  	}
   267  
   268  	require.NoError(first.Error)
   269  	require.EqualValues(devices1, first.Devices)
   270  
   271  	// Get the second message
   272  	var second *FingerprintResponse
   273  	select {
   274  	case <-time.After(1 * time.Second):
   275  		t.Fatal("timeout")
   276  	case second = <-stream:
   277  	}
   278  
   279  	require.NoError(second.Error)
   280  	require.EqualValues(devices2, second.Devices)
   281  
   282  	select {
   283  	case _, ok := <-stream:
   284  		require.False(ok)
   285  	case <-time.After(1 * time.Second):
   286  		t.Fatal("stream should be closed")
   287  	}
   288  }
   289  
   290  func TestDevicePlugin_Fingerprint_StreamErr(t *testing.T) {
   291  	t.Parallel()
   292  	require := require.New(t)
   293  
   294  	ferr := fmt.Errorf("mock fingerprinting failed")
   295  	mock := &MockDevicePlugin{
   296  		FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
   297  			outCh := make(chan *FingerprintResponse, 1)
   298  			go func() {
   299  				// Send the error
   300  				select {
   301  				case <-ctx.Done():
   302  					return
   303  				case outCh <- &FingerprintResponse{Error: ferr}:
   304  				}
   305  
   306  				close(outCh)
   307  				return
   308  			}()
   309  			return outCh, nil
   310  		},
   311  	}
   312  
   313  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   314  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   315  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   316  	})
   317  	defer server.Stop()
   318  	defer client.Close()
   319  
   320  	raw, err := client.Dispense(base.PluginTypeDevice)
   321  	if err != nil {
   322  		t.Fatalf("err: %s", err)
   323  	}
   324  
   325  	impl, ok := raw.(DevicePlugin)
   326  	if !ok {
   327  		t.Fatalf("bad: %#v", raw)
   328  	}
   329  
   330  	// Create a context
   331  	ctx, cancel := context.WithCancel(context.Background())
   332  	defer cancel()
   333  
   334  	// Get the stream
   335  	stream, err := impl.Fingerprint(ctx)
   336  	require.NoError(err)
   337  
   338  	// Get the first message
   339  	var first *FingerprintResponse
   340  	select {
   341  	case <-time.After(1 * time.Second):
   342  		t.Fatal("timeout")
   343  	case first = <-stream:
   344  	}
   345  
   346  	errStatus := status.Convert(ferr)
   347  	require.EqualError(first.Error, errStatus.Err().Error())
   348  }
   349  
   350  func TestDevicePlugin_Fingerprint_CancelCtx(t *testing.T) {
   351  	t.Parallel()
   352  	require := require.New(t)
   353  
   354  	mock := &MockDevicePlugin{
   355  		FingerprintF: func(ctx context.Context) (<-chan *FingerprintResponse, error) {
   356  			outCh := make(chan *FingerprintResponse, 1)
   357  			go func() {
   358  				<-ctx.Done()
   359  				close(outCh)
   360  				return
   361  			}()
   362  			return outCh, nil
   363  		},
   364  	}
   365  
   366  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   367  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   368  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   369  	})
   370  	defer server.Stop()
   371  	defer client.Close()
   372  
   373  	raw, err := client.Dispense(base.PluginTypeDevice)
   374  	if err != nil {
   375  		t.Fatalf("err: %s", err)
   376  	}
   377  
   378  	impl, ok := raw.(DevicePlugin)
   379  	if !ok {
   380  		t.Fatalf("bad: %#v", raw)
   381  	}
   382  
   383  	// Create a context
   384  	ctx, cancel := context.WithCancel(context.Background())
   385  
   386  	// Get the stream
   387  	stream, err := impl.Fingerprint(ctx)
   388  	require.NoError(err)
   389  
   390  	// Get the first message
   391  	select {
   392  	case <-time.After(testutil.Timeout(10 * time.Millisecond)):
   393  	case _ = <-stream:
   394  		t.Fatal("bad value")
   395  	}
   396  
   397  	// Cancel the context
   398  	cancel()
   399  
   400  	// Make sure we are done
   401  	select {
   402  	case <-time.After(100 * time.Millisecond):
   403  		t.Fatalf("timeout")
   404  	case v := <-stream:
   405  		require.Error(v.Error)
   406  		require.EqualError(v.Error, context.Canceled.Error())
   407  	}
   408  }
   409  
   410  func TestDevicePlugin_Reserve(t *testing.T) {
   411  	t.Parallel()
   412  	require := require.New(t)
   413  
   414  	reservation := &ContainerReservation{
   415  		Envs: map[string]string{
   416  			"foo": "bar",
   417  		},
   418  		Mounts: []*Mount{
   419  			{
   420  				TaskPath: "foo",
   421  				HostPath: "bar",
   422  				ReadOnly: true,
   423  			},
   424  		},
   425  		Devices: []*DeviceSpec{
   426  			{
   427  				TaskPath:    "foo",
   428  				HostPath:    "bar",
   429  				CgroupPerms: "rx",
   430  			},
   431  		},
   432  	}
   433  
   434  	var received []string
   435  	mock := &MockDevicePlugin{
   436  		ReserveF: func(devices []string) (*ContainerReservation, error) {
   437  			received = devices
   438  			return reservation, nil
   439  		},
   440  	}
   441  
   442  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   443  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   444  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   445  	})
   446  	defer server.Stop()
   447  	defer client.Close()
   448  
   449  	raw, err := client.Dispense(base.PluginTypeDevice)
   450  	if err != nil {
   451  		t.Fatalf("err: %s", err)
   452  	}
   453  
   454  	impl, ok := raw.(DevicePlugin)
   455  	if !ok {
   456  		t.Fatalf("bad: %#v", raw)
   457  	}
   458  
   459  	req := []string{"a", "b"}
   460  	containerRes, err := impl.Reserve(req)
   461  	require.NoError(err)
   462  	require.EqualValues(req, received)
   463  	require.EqualValues(reservation, containerRes)
   464  }
   465  
   466  func TestDevicePlugin_Stats(t *testing.T) {
   467  	t.Parallel()
   468  	require := require.New(t)
   469  
   470  	devices1 := []*DeviceGroupStats{
   471  		{
   472  			Vendor: "nvidia",
   473  			Type:   DeviceTypeGPU,
   474  			Name:   "foo",
   475  			InstanceStats: map[string]*DeviceStats{
   476  				"1": {
   477  					Summary: &psstructs.StatValue{
   478  						IntNumeratorVal:   helper.Int64ToPtr(10),
   479  						IntDenominatorVal: helper.Int64ToPtr(20),
   480  						Unit:              "MB",
   481  						Desc:              "Unit test",
   482  					},
   483  				},
   484  			},
   485  		},
   486  	}
   487  	devices2 := []*DeviceGroupStats{
   488  		{
   489  			Vendor: "nvidia",
   490  			Type:   DeviceTypeGPU,
   491  			Name:   "foo",
   492  			InstanceStats: map[string]*DeviceStats{
   493  				"1": {
   494  					Summary: &psstructs.StatValue{
   495  						FloatNumeratorVal:   helper.Float64ToPtr(10.0),
   496  						FloatDenominatorVal: helper.Float64ToPtr(20.0),
   497  						Unit:                "MB",
   498  						Desc:                "Unit test",
   499  					},
   500  				},
   501  			},
   502  		},
   503  		{
   504  			Vendor: "nvidia",
   505  			Type:   DeviceTypeGPU,
   506  			Name:   "bar",
   507  			InstanceStats: map[string]*DeviceStats{
   508  				"1": {
   509  					Summary: &psstructs.StatValue{
   510  						StringVal: helper.StringToPtr("foo"),
   511  						Unit:      "MB",
   512  						Desc:      "Unit test",
   513  					},
   514  				},
   515  			},
   516  		},
   517  		{
   518  			Vendor: "nvidia",
   519  			Type:   DeviceTypeGPU,
   520  			Name:   "baz",
   521  			InstanceStats: map[string]*DeviceStats{
   522  				"1": {
   523  					Summary: &psstructs.StatValue{
   524  						BoolVal: helper.BoolToPtr(true),
   525  						Unit:    "MB",
   526  						Desc:    "Unit test",
   527  					},
   528  				},
   529  			},
   530  		},
   531  	}
   532  
   533  	mock := &MockDevicePlugin{
   534  		StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
   535  			outCh := make(chan *StatsResponse, 1)
   536  			go func() {
   537  				// Send two messages
   538  				for _, devs := range [][]*DeviceGroupStats{devices1, devices2} {
   539  					select {
   540  					case <-ctx.Done():
   541  						return
   542  					case outCh <- &StatsResponse{Groups: devs}:
   543  					}
   544  				}
   545  				close(outCh)
   546  				return
   547  			}()
   548  			return outCh, nil
   549  		},
   550  	}
   551  
   552  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   553  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   554  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   555  	})
   556  	defer server.Stop()
   557  	defer client.Close()
   558  
   559  	raw, err := client.Dispense(base.PluginTypeDevice)
   560  	if err != nil {
   561  		t.Fatalf("err: %s", err)
   562  	}
   563  
   564  	impl, ok := raw.(DevicePlugin)
   565  	if !ok {
   566  		t.Fatalf("bad: %#v", raw)
   567  	}
   568  
   569  	// Create a context
   570  	ctx, cancel := context.WithCancel(context.Background())
   571  	defer cancel()
   572  
   573  	// Get the stream
   574  	stream, err := impl.Stats(ctx, time.Millisecond)
   575  	require.NoError(err)
   576  
   577  	// Get the first message
   578  	var first *StatsResponse
   579  	select {
   580  	case <-time.After(1 * time.Second):
   581  		t.Fatal("timeout")
   582  	case first = <-stream:
   583  	}
   584  
   585  	require.NoError(first.Error)
   586  	require.EqualValues(devices1, first.Groups)
   587  
   588  	// Get the second message
   589  	var second *StatsResponse
   590  	select {
   591  	case <-time.After(1 * time.Second):
   592  		t.Fatal("timeout")
   593  	case second = <-stream:
   594  	}
   595  
   596  	require.NoError(second.Error)
   597  	require.EqualValues(devices2, second.Groups)
   598  
   599  	select {
   600  	case _, ok := <-stream:
   601  		require.False(ok)
   602  	case <-time.After(1 * time.Second):
   603  		t.Fatal("stream should be closed")
   604  	}
   605  }
   606  
   607  func TestDevicePlugin_Stats_StreamErr(t *testing.T) {
   608  	t.Parallel()
   609  	require := require.New(t)
   610  
   611  	ferr := fmt.Errorf("mock stats failed")
   612  	mock := &MockDevicePlugin{
   613  		StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
   614  			outCh := make(chan *StatsResponse, 1)
   615  			go func() {
   616  				// Send the error
   617  				select {
   618  				case <-ctx.Done():
   619  					return
   620  				case outCh <- &StatsResponse{Error: ferr}:
   621  				}
   622  
   623  				close(outCh)
   624  				return
   625  			}()
   626  			return outCh, nil
   627  		},
   628  	}
   629  
   630  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   631  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   632  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   633  	})
   634  	defer server.Stop()
   635  	defer client.Close()
   636  
   637  	raw, err := client.Dispense(base.PluginTypeDevice)
   638  	if err != nil {
   639  		t.Fatalf("err: %s", err)
   640  	}
   641  
   642  	impl, ok := raw.(DevicePlugin)
   643  	if !ok {
   644  		t.Fatalf("bad: %#v", raw)
   645  	}
   646  
   647  	// Create a context
   648  	ctx, cancel := context.WithCancel(context.Background())
   649  	defer cancel()
   650  
   651  	// Get the stream
   652  	stream, err := impl.Stats(ctx, time.Millisecond)
   653  	require.NoError(err)
   654  
   655  	// Get the first message
   656  	var first *StatsResponse
   657  	select {
   658  	case <-time.After(1 * time.Second):
   659  		t.Fatal("timeout")
   660  	case first = <-stream:
   661  	}
   662  
   663  	errStatus := status.Convert(ferr)
   664  	require.EqualError(first.Error, errStatus.Err().Error())
   665  }
   666  
   667  func TestDevicePlugin_Stats_CancelCtx(t *testing.T) {
   668  	t.Parallel()
   669  	require := require.New(t)
   670  
   671  	mock := &MockDevicePlugin{
   672  		StatsF: func(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
   673  			outCh := make(chan *StatsResponse, 1)
   674  			go func() {
   675  				<-ctx.Done()
   676  				close(outCh)
   677  				return
   678  			}()
   679  			return outCh, nil
   680  		},
   681  	}
   682  
   683  	client, server := plugin.TestPluginGRPCConn(t, map[string]plugin.Plugin{
   684  		base.PluginTypeBase:   &base.PluginBase{Impl: mock},
   685  		base.PluginTypeDevice: &PluginDevice{Impl: mock},
   686  	})
   687  	defer server.Stop()
   688  	defer client.Close()
   689  
   690  	raw, err := client.Dispense(base.PluginTypeDevice)
   691  	if err != nil {
   692  		t.Fatalf("err: %s", err)
   693  	}
   694  
   695  	impl, ok := raw.(DevicePlugin)
   696  	if !ok {
   697  		t.Fatalf("bad: %#v", raw)
   698  	}
   699  
   700  	// Create a context
   701  	ctx, cancel := context.WithCancel(context.Background())
   702  
   703  	// Get the stream
   704  	stream, err := impl.Stats(ctx, time.Millisecond)
   705  	require.NoError(err)
   706  
   707  	// Get the first message
   708  	select {
   709  	case <-time.After(testutil.Timeout(10 * time.Millisecond)):
   710  	case _ = <-stream:
   711  		t.Fatal("bad value")
   712  	}
   713  
   714  	// Cancel the context
   715  	cancel()
   716  
   717  	// Make sure we are done
   718  	select {
   719  	case <-time.After(100 * time.Millisecond):
   720  		t.Fatalf("timeout")
   721  	case v := <-stream:
   722  		require.Error(v.Error)
   723  		require.EqualError(v.Error, context.Canceled.Error())
   724  	}
   725  }