github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/msg/producer/writer/consumer_service_writer_test.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package writer
    22  
    23  import (
    24  	"errors"
    25  	"net"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/m3db/m3/src/cluster/kv"
    31  	"github.com/m3db/m3/src/cluster/kv/mem"
    32  	"github.com/m3db/m3/src/cluster/placement"
    33  	"github.com/m3db/m3/src/cluster/placement/service"
    34  	"github.com/m3db/m3/src/cluster/placement/storage"
    35  	"github.com/m3db/m3/src/cluster/services"
    36  	"github.com/m3db/m3/src/cluster/shard"
    37  	"github.com/m3db/m3/src/msg/generated/proto/msgpb"
    38  	"github.com/m3db/m3/src/msg/producer"
    39  	"github.com/m3db/m3/src/msg/protocol/proto"
    40  	"github.com/m3db/m3/src/msg/topic"
    41  	xtest "github.com/m3db/m3/src/x/test"
    42  
    43  	"github.com/fortytw2/leaktest"
    44  	"github.com/golang/mock/gomock"
    45  	"github.com/stretchr/testify/require"
    46  )
    47  
    48  func TestConsumerServiceWriterWithSharedConsumerWithNonShardedPlacement(t *testing.T) {
    49  	defer leaktest.Check(t)()
    50  
    51  	ctrl := xtest.NewController(t)
    52  	defer ctrl.Finish()
    53  
    54  	sid := services.NewServiceID().SetName("foo")
    55  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
    56  	sd := services.NewMockServices(ctrl)
    57  	ps := testPlacementService(mem.NewStore(), sid)
    58  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
    59  
    60  	opts := testOptions().SetServiceDiscovery(sd)
    61  	w, err := newConsumerServiceWriter(cs, 2, opts)
    62  	require.NoError(t, err)
    63  
    64  	csw := w.(*consumerServiceWriterImpl)
    65  
    66  	var (
    67  		lock               sync.Mutex
    68  		numConsumerWriters int
    69  	)
    70  	csw.processFn = func(p interface{}) error {
    71  		err := csw.process(p)
    72  		lock.Lock()
    73  		numConsumerWriters = len(csw.consumerWriters)
    74  		lock.Unlock()
    75  		return err
    76  	}
    77  
    78  	require.NoError(t, csw.Init(allowInitValueError))
    79  	lock.Lock()
    80  	require.Equal(t, 0, numConsumerWriters)
    81  	lock.Unlock()
    82  
    83  	lis, err := net.Listen("tcp", "127.0.0.1:0")
    84  	require.NoError(t, err)
    85  	defer lis.Close()
    86  
    87  	p1 := placement.NewPlacement().
    88  		SetInstances([]placement.Instance{
    89  			placement.NewInstance().
    90  				SetID("i1").
    91  				SetEndpoint(lis.Addr().String()),
    92  			placement.NewInstance().
    93  				SetID("i2").
    94  				SetEndpoint("addr2"),
    95  			placement.NewInstance().
    96  				SetID("i3").
    97  				SetEndpoint("addr3"),
    98  		}).
    99  		SetIsSharded(false)
   100  	_, err = ps.Set(p1)
   101  	require.NoError(t, err)
   102  
   103  	for {
   104  		lock.Lock()
   105  		l := numConsumerWriters
   106  		lock.Unlock()
   107  		if l == 3 {
   108  			break
   109  		}
   110  		time.Sleep(100 * time.Millisecond)
   111  	}
   112  
   113  	var wg sync.WaitGroup
   114  	defer wg.Wait()
   115  
   116  	wg.Add(1)
   117  	go func() {
   118  		testConsumeAndAckOnConnectionListener(t, lis, opts.EncoderOptions(), opts.DecoderOptions())
   119  		wg.Done()
   120  	}()
   121  
   122  	mm := producer.NewMockMessage(ctrl)
   123  	mm.EXPECT().Shard().Return(uint32(1))
   124  	mm.EXPECT().Bytes().Return([]byte("foo"))
   125  	mm.EXPECT().Size().Return(3)
   126  	mm.EXPECT().Finalize(producer.Consumed)
   127  
   128  	rm := producer.NewRefCountedMessage(mm, nil)
   129  	csw.Write(rm)
   130  	for {
   131  		if rm.IsDroppedOrConsumed() {
   132  			break
   133  		}
   134  		time.Sleep(100 * time.Millisecond)
   135  	}
   136  
   137  	for _, sw := range w.(*consumerServiceWriterImpl).shardWriters {
   138  		require.Equal(t, 3, len(sw.(*sharedShardWriter).mw.consumerWriters))
   139  	}
   140  
   141  	p2 := placement.NewPlacement().
   142  		SetInstances([]placement.Instance{
   143  			placement.NewInstance().
   144  				SetID("i1").
   145  				SetEndpoint(lis.Addr().String()),
   146  			placement.NewInstance().
   147  				SetID("i2").
   148  				SetEndpoint("addr2"),
   149  		}).
   150  		SetIsSharded(false)
   151  	_, err = ps.Set(p2)
   152  	require.NoError(t, err)
   153  
   154  	for {
   155  		lock.Lock()
   156  		l := numConsumerWriters
   157  		lock.Unlock()
   158  		if l == 2 {
   159  			break
   160  		}
   161  		time.Sleep(100 * time.Millisecond)
   162  	}
   163  
   164  	for _, sw := range w.(*consumerServiceWriterImpl).shardWriters {
   165  		require.Equal(t, 2, len(sw.(*sharedShardWriter).mw.consumerWriters))
   166  	}
   167  
   168  	csw.Close()
   169  	csw.Close()
   170  }
   171  
   172  func TestConsumerServiceWriterWithSharedConsumerWithShardedPlacement(t *testing.T) {
   173  	defer leaktest.Check(t)()
   174  
   175  	ctrl := xtest.NewController(t)
   176  	defer ctrl.Finish()
   177  
   178  	sid := services.NewServiceID().SetName("foo")
   179  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   180  	sd := services.NewMockServices(ctrl)
   181  	ps := testPlacementService(mem.NewStore(), sid)
   182  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   183  
   184  	opts := testOptions().SetServiceDiscovery(sd)
   185  	w, err := newConsumerServiceWriter(cs, 3, opts)
   186  	require.NoError(t, err)
   187  
   188  	csw := w.(*consumerServiceWriterImpl)
   189  
   190  	var (
   191  		lock               sync.Mutex
   192  		numConsumerWriters int
   193  	)
   194  	csw.processFn = func(p interface{}) error {
   195  		err := csw.process(p)
   196  		lock.Lock()
   197  		numConsumerWriters = len(csw.consumerWriters)
   198  		lock.Unlock()
   199  		return err
   200  	}
   201  
   202  	require.NoError(t, csw.Init(allowInitValueError))
   203  	lock.Lock()
   204  	require.Equal(t, 0, numConsumerWriters)
   205  	lock.Unlock()
   206  
   207  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   208  	require.NoError(t, err)
   209  	defer lis.Close()
   210  
   211  	p1 := placement.NewPlacement().
   212  		SetInstances([]placement.Instance{
   213  			placement.NewInstance().
   214  				SetID("i1").
   215  				SetEndpoint(lis.Addr().String()).
   216  				SetShards(shard.NewShards([]shard.Shard{
   217  					shard.NewShard(1).SetState(shard.Available),
   218  					shard.NewShard(2).SetState(shard.Available),
   219  				})),
   220  			placement.NewInstance().
   221  				SetID("i2").
   222  				SetEndpoint("addr2").
   223  				SetShards(shard.NewShards([]shard.Shard{
   224  					shard.NewShard(0).SetState(shard.Available),
   225  					shard.NewShard(2).SetState(shard.Available),
   226  				})),
   227  			placement.NewInstance().
   228  				SetID("i3").
   229  				SetEndpoint("addr3").
   230  				SetShards(shard.NewShards([]shard.Shard{
   231  					shard.NewShard(0).SetState(shard.Available),
   232  					shard.NewShard(1).SetState(shard.Available),
   233  				})),
   234  		}).
   235  		SetShards([]uint32{0, 1, 2}).
   236  		SetReplicaFactor(2).
   237  		SetIsSharded(true)
   238  	_, err = ps.Set(p1)
   239  	require.NoError(t, err)
   240  
   241  	for {
   242  		lock.Lock()
   243  		l := numConsumerWriters
   244  		lock.Unlock()
   245  		if l == 3 {
   246  			break
   247  		}
   248  		time.Sleep(100 * time.Millisecond)
   249  	}
   250  
   251  	var wg sync.WaitGroup
   252  	defer wg.Wait()
   253  
   254  	wg.Add(1)
   255  	go func() {
   256  		testConsumeAndAckOnConnectionListener(t, lis, opts.EncoderOptions(), opts.DecoderOptions())
   257  		wg.Done()
   258  	}()
   259  
   260  	mm := producer.NewMockMessage(ctrl)
   261  	mm.EXPECT().Shard().Return(uint32(1))
   262  	mm.EXPECT().Bytes().Return([]byte("foo"))
   263  	mm.EXPECT().Size().Return(3)
   264  	mm.EXPECT().Finalize(producer.Consumed)
   265  
   266  	rm := producer.NewRefCountedMessage(mm, nil)
   267  	csw.Write(rm)
   268  	for {
   269  		if rm.IsDroppedOrConsumed() {
   270  			break
   271  		}
   272  		time.Sleep(100 * time.Millisecond)
   273  	}
   274  	p2 := placement.NewPlacement().
   275  		SetInstances([]placement.Instance{
   276  			placement.NewInstance().
   277  				SetID("i1").
   278  				SetEndpoint(lis.Addr().String()).
   279  				SetShards(shard.NewShards([]shard.Shard{
   280  					shard.NewShard(1).SetState(shard.Available),
   281  				})),
   282  			placement.NewInstance().
   283  				SetID("i2").
   284  				SetEndpoint("addr2").
   285  				SetShards(shard.NewShards([]shard.Shard{
   286  					shard.NewShard(0).SetState(shard.Available),
   287  					shard.NewShard(2).SetState(shard.Available),
   288  				})),
   289  		}).
   290  		SetShards([]uint32{0, 1, 2}).
   291  		SetReplicaFactor(1).
   292  		SetIsSharded(true)
   293  	_, err = ps.Set(p2)
   294  	require.NoError(t, err)
   295  
   296  	for {
   297  		lock.Lock()
   298  		l := numConsumerWriters
   299  		lock.Unlock()
   300  		if l == 2 {
   301  			break
   302  		}
   303  		time.Sleep(100 * time.Millisecond)
   304  	}
   305  
   306  	csw.Close()
   307  	csw.Close()
   308  }
   309  
   310  func TestConsumerServiceWriterWithReplicatedConsumerWithShardedPlacement(t *testing.T) {
   311  	defer leaktest.Check(t)()
   312  
   313  	ctrl := xtest.NewController(t)
   314  	defer ctrl.Finish()
   315  
   316  	sid := services.NewServiceID().SetName("foo")
   317  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Replicated)
   318  	sd := services.NewMockServices(ctrl)
   319  	ps := testPlacementService(mem.NewStore(), sid)
   320  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   321  
   322  	lis1, err := net.Listen("tcp", "127.0.0.1:0")
   323  	require.NoError(t, err)
   324  	defer lis1.Close()
   325  
   326  	lis2, err := net.Listen("tcp", "127.0.0.1:0")
   327  	require.NoError(t, err)
   328  	defer lis2.Close()
   329  
   330  	p1 := placement.NewPlacement().
   331  		SetInstances([]placement.Instance{
   332  			placement.NewInstance().
   333  				SetID("i1").
   334  				SetEndpoint(lis1.Addr().String()).
   335  				SetShards(shard.NewShards([]shard.Shard{
   336  					shard.NewShard(0).SetState(shard.Available),
   337  					shard.NewShard(1).SetState(shard.Available),
   338  				})),
   339  			placement.NewInstance().
   340  				SetID("i2").
   341  				SetEndpoint(lis2.Addr().String()).
   342  				SetShards(shard.NewShards([]shard.Shard{
   343  					shard.NewShard(1).SetState(shard.Available),
   344  				})),
   345  			placement.NewInstance().
   346  				SetID("i3").
   347  				SetEndpoint("addr3").
   348  				SetShards(shard.NewShards([]shard.Shard{
   349  					shard.NewShard(0).SetState(shard.Available),
   350  				})),
   351  		}).
   352  		SetShards([]uint32{0, 1}).
   353  		SetReplicaFactor(2).
   354  		SetIsSharded(true)
   355  	_, err = ps.Set(p1)
   356  	require.NoError(t, err)
   357  
   358  	opts := testOptions().SetServiceDiscovery(sd)
   359  	w, err := newConsumerServiceWriter(cs, 2, opts)
   360  	csw := w.(*consumerServiceWriterImpl)
   361  	require.NoError(t, err)
   362  	require.NotNil(t, csw)
   363  
   364  	var (
   365  		lock               sync.Mutex
   366  		numConsumerWriters int
   367  	)
   368  	csw.processFn = func(p interface{}) error {
   369  		err := csw.process(p)
   370  		lock.Lock()
   371  		numConsumerWriters = len(csw.consumerWriters)
   372  		lock.Unlock()
   373  		return err
   374  	}
   375  	require.NoError(t, csw.Init(allowInitValueError))
   376  
   377  	for {
   378  		lock.Lock()
   379  		l := numConsumerWriters
   380  		lock.Unlock()
   381  		if l == 3 {
   382  			break
   383  		}
   384  		time.Sleep(100 * time.Millisecond)
   385  	}
   386  
   387  	mm := producer.NewMockMessage(ctrl)
   388  	mm.EXPECT().Shard().Return(uint32(1)).AnyTimes()
   389  	mm.EXPECT().Bytes().Return([]byte("foo")).AnyTimes()
   390  	mm.EXPECT().Size().Return(3)
   391  	mm.EXPECT().Finalize(producer.Consumed)
   392  
   393  	rm := producer.NewRefCountedMessage(mm, nil)
   394  	csw.Write(rm)
   395  	var wg sync.WaitGroup
   396  	wg.Add(1)
   397  	go func() {
   398  		testConsumeAndAckOnConnectionListener(t, lis1, opts.EncoderOptions(), opts.DecoderOptions())
   399  		wg.Done()
   400  	}()
   401  
   402  	wg.Add(1)
   403  	go func() {
   404  		testConsumeAndAckOnConnectionListener(t, lis2, opts.EncoderOptions(), opts.DecoderOptions())
   405  		wg.Done()
   406  	}()
   407  	wg.Wait()
   408  
   409  	for {
   410  		if rm.IsDroppedOrConsumed() {
   411  			break
   412  		}
   413  		time.Sleep(100 * time.Millisecond)
   414  	}
   415  
   416  	p2 := placement.NewPlacement().
   417  		SetInstances([]placement.Instance{
   418  			placement.NewInstance().
   419  				SetID("i1").
   420  				SetEndpoint(lis1.Addr().String()).
   421  				SetShards(shard.NewShards([]shard.Shard{
   422  					shard.NewShard(0).SetState(shard.Available),
   423  				})),
   424  			placement.NewInstance().
   425  				SetID("i2").
   426  				SetEndpoint(lis2.Addr().String()).
   427  				SetShards(shard.NewShards([]shard.Shard{
   428  					shard.NewShard(1).SetState(shard.Available),
   429  				})),
   430  		}).
   431  		SetShards([]uint32{0, 1}).
   432  		SetReplicaFactor(1).
   433  		SetIsSharded(true)
   434  	_, err = ps.Set(p2)
   435  	require.NoError(t, err)
   436  
   437  	for {
   438  		lock.Lock()
   439  		l := numConsumerWriters
   440  		lock.Unlock()
   441  		if l == 2 {
   442  			break
   443  		}
   444  		time.Sleep(100 * time.Millisecond)
   445  	}
   446  
   447  	go func() {
   448  		for {
   449  			conn, err := lis2.Accept()
   450  			if err != nil {
   451  				return
   452  			}
   453  			serverEncoder := proto.NewEncoder(opts.EncoderOptions())
   454  			serverDecoder := proto.NewDecoder(conn, opts.DecoderOptions(), 10)
   455  
   456  			var msg msgpb.Message
   457  			err = serverDecoder.Decode(&msg)
   458  			if err != nil {
   459  				conn.Close()
   460  				continue
   461  			}
   462  			require.NoError(t, serverEncoder.Encode(&msgpb.Ack{
   463  				Metadata: []msgpb.Metadata{
   464  					msg.Metadata,
   465  				},
   466  			}))
   467  			_, err = conn.Write(serverEncoder.Bytes())
   468  			require.NoError(t, err)
   469  			conn.Close()
   470  		}
   471  	}()
   472  
   473  	mm.EXPECT().Finalize(producer.Consumed)
   474  	mm.EXPECT().Size().Return(3)
   475  	rm = producer.NewRefCountedMessage(mm, nil)
   476  	csw.Write(rm)
   477  	for {
   478  		if rm.IsDroppedOrConsumed() {
   479  			break
   480  		}
   481  		time.Sleep(100 * time.Millisecond)
   482  	}
   483  
   484  	csw.Close()
   485  	csw.Close()
   486  }
   487  
   488  func TestConsumerServiceWriterFilter(t *testing.T) {
   489  	defer leaktest.Check(t)()
   490  
   491  	ctrl := xtest.NewController(t)
   492  	defer ctrl.Finish()
   493  
   494  	sid := services.NewServiceID().SetName("foo")
   495  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Replicated)
   496  	sd := services.NewMockServices(ctrl)
   497  	ps := testPlacementService(mem.NewStore(), sid)
   498  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   499  
   500  	opts := testOptions().SetServiceDiscovery(sd)
   501  	csw, err := newConsumerServiceWriter(cs, 3, opts)
   502  	require.NoError(t, err)
   503  
   504  	sw0 := NewMockshardWriter(ctrl)
   505  	sw1 := NewMockshardWriter(ctrl)
   506  	csw.(*consumerServiceWriterImpl).shardWriters[0] = sw0
   507  	csw.(*consumerServiceWriterImpl).shardWriters[1] = sw1
   508  
   509  	mm0 := producer.NewMockMessage(ctrl)
   510  	mm0.EXPECT().Shard().Return(uint32(0)).AnyTimes()
   511  	mm0.EXPECT().Size().Return(3).AnyTimes()
   512  	mm1 := producer.NewMockMessage(ctrl)
   513  	mm1.EXPECT().Shard().Return(uint32(1)).AnyTimes()
   514  	mm1.EXPECT().Size().Return(3).AnyTimes()
   515  	mm2 := producer.NewMockMessage(ctrl)
   516  	mm2.EXPECT().Shard().Return(uint32(0)).AnyTimes()
   517  	mm2.EXPECT().Size().Return(4).AnyTimes()
   518  
   519  	sw0.EXPECT().Write(gomock.Any())
   520  	csw.Write(producer.NewRefCountedMessage(mm0, nil))
   521  	sw1.EXPECT().Write(gomock.Any())
   522  	csw.Write(producer.NewRefCountedMessage(mm1, nil))
   523  
   524  	csw.RegisterFilter(func(m producer.Message) bool { return m.Shard() == uint32(0) })
   525  	// Write is not expected due to mm1 shard != 0
   526  	csw.Write(producer.NewRefCountedMessage(mm1, nil))
   527  
   528  	sw0.EXPECT().Write(gomock.Any())
   529  	// Write is expected due to mm0 shard == 0
   530  	csw.Write(producer.NewRefCountedMessage(mm0, nil))
   531  
   532  	csw.RegisterFilter(func(m producer.Message) bool { return m.Size() == 3 })
   533  	sw0.EXPECT().Write(gomock.Any())
   534  	// Write is expected because to mm0 shard == 0 and mm0 size == 3
   535  	csw.Write(producer.NewRefCountedMessage(mm0, nil))
   536  
   537  	// Write is not expected because to mm2 size != 3
   538  	csw.Write(producer.NewRefCountedMessage(mm2, nil))
   539  
   540  	// All messages are expected to write after unregistering filters
   541  	csw.UnregisterFilters()
   542  	sw0.EXPECT().Write(gomock.Any())
   543  	csw.Write(producer.NewRefCountedMessage(mm0, nil))
   544  	sw1.EXPECT().Write(gomock.Any())
   545  	csw.Write(producer.NewRefCountedMessage(mm1, nil))
   546  	sw0.EXPECT().Write(gomock.Any())
   547  	csw.Write(producer.NewRefCountedMessage(mm2, nil))
   548  }
   549  
   550  func TestConsumerServiceWriterAllowInitValueErrorWithCreateWatchError(t *testing.T) {
   551  	defer leaktest.Check(t)()
   552  
   553  	ctrl := xtest.NewController(t)
   554  	defer ctrl.Finish()
   555  
   556  	sid := services.NewServiceID().SetName("foo")
   557  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   558  
   559  	ps := placement.NewMockService(ctrl)
   560  	ps.EXPECT().Watch().Return(nil, errors.New("mock err")).AnyTimes()
   561  
   562  	sd := services.NewMockServices(ctrl)
   563  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   564  
   565  	opts := testOptions().SetServiceDiscovery(sd)
   566  	w, err := newConsumerServiceWriter(cs, 3, opts)
   567  	require.NoError(t, err)
   568  	defer w.Close()
   569  
   570  	require.Error(t, w.Init(allowInitValueError))
   571  }
   572  
   573  func TestConsumerServiceWriterAllowInitValueErrorWithInitValueError(t *testing.T) {
   574  	defer leaktest.Check(t)()
   575  
   576  	ctrl := xtest.NewController(t)
   577  	defer ctrl.Finish()
   578  
   579  	sid := services.NewServiceID().SetName("foo")
   580  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   581  
   582  	ps := testPlacementService(mem.NewStore(), sid)
   583  	sd := services.NewMockServices(ctrl)
   584  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   585  
   586  	opts := testOptions().SetServiceDiscovery(sd)
   587  	w, err := newConsumerServiceWriter(cs, 3, opts)
   588  	require.NoError(t, err)
   589  	defer w.Close()
   590  
   591  	require.NoError(t, w.Init(allowInitValueError))
   592  }
   593  
   594  func TestConsumerServiceWriterInitError(t *testing.T) {
   595  	defer leaktest.Check(t)()
   596  
   597  	ctrl := xtest.NewController(t)
   598  	defer ctrl.Finish()
   599  
   600  	sid := services.NewServiceID().SetName("foo")
   601  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   602  
   603  	ps := placement.NewMockService(ctrl)
   604  	ps.EXPECT().Watch().Return(nil, errors.New("mock err")).AnyTimes()
   605  
   606  	sd := services.NewMockServices(ctrl)
   607  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   608  
   609  	opts := testOptions().SetServiceDiscovery(sd)
   610  	w, err := newConsumerServiceWriter(cs, 3, opts)
   611  	require.NoError(t, err)
   612  	defer w.Close()
   613  
   614  	err = w.Init(failOnError)
   615  	require.Error(t, err)
   616  	require.Contains(t, err.Error(), "consumer service writer init error")
   617  }
   618  
   619  func TestConsumerServiceWriterUpdateNonShardedPlacementWithReplicatedConsumptionType(t *testing.T) {
   620  	defer leaktest.Check(t)()
   621  
   622  	ctrl := xtest.NewController(t)
   623  	defer ctrl.Finish()
   624  
   625  	sid := services.NewServiceID().SetName("foo")
   626  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Replicated)
   627  	sd := services.NewMockServices(ctrl)
   628  	pOpts := placement.NewOptions().SetIsSharded(false)
   629  	ps := service.NewPlacementService(storage.NewPlacementStorage(mem.NewStore(), sid.String(), pOpts),
   630  		service.WithPlacementOptions(pOpts))
   631  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   632  	_, err := ps.BuildInitialPlacement([]placement.Instance{
   633  		placement.NewInstance().SetID("i1").SetEndpoint("i1").SetWeight(1),
   634  	}, 0, 1)
   635  	require.NoError(t, err)
   636  	opts := testOptions().SetServiceDiscovery(sd)
   637  	w, err := newConsumerServiceWriter(cs, 2, opts)
   638  	require.NoError(t, err)
   639  	err = w.Init(failOnError)
   640  	require.Error(t, err)
   641  	require.Contains(t, err.Error(), "non-sharded placement for replicated consumer")
   642  	w.Close()
   643  }
   644  
   645  func TestConsumerServiceCloseShardWritersConcurrently(t *testing.T) {
   646  	defer leaktest.Check(t)()
   647  
   648  	ctrl := xtest.NewController(t)
   649  	defer ctrl.Finish()
   650  
   651  	sid := services.NewServiceID().SetName("foo")
   652  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   653  	sd := services.NewMockServices(ctrl)
   654  	ps := testPlacementService(mem.NewStore(), sid)
   655  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   656  	opts := testOptions().SetServiceDiscovery(sd).SetCloseCheckInterval(time.Second)
   657  
   658  	numShards := uint32(1024)
   659  	w, err := newConsumerServiceWriter(cs, numShards, opts)
   660  	require.NoError(t, err)
   661  	require.NoError(t, w.Init(allowInitValueError))
   662  
   663  	// Write one message to each shard, so each shard needs to tick
   664  	// and wait for the queue to be cleaned up.
   665  	b := []byte{}
   666  	for i := uint32(0); i < numShards; i++ {
   667  		mm := producer.NewMockMessage(ctrl)
   668  		mm.EXPECT().Shard().Return(i)
   669  		mm.EXPECT().Bytes().Return(b).AnyTimes()
   670  		mm.EXPECT().Size().Return(0).AnyTimes()
   671  		mm.EXPECT().Finalize(gomock.Any())
   672  		w.Write(producer.NewRefCountedMessage(mm, nil))
   673  	}
   674  
   675  	ch := make(chan struct{})
   676  	go func() {
   677  		w.Close()
   678  		close(ch)
   679  	}()
   680  
   681  	select {
   682  	case <-ch:
   683  		return
   684  	case <-time.After(10 * time.Second):
   685  		require.FailNow(t, "taking too long to close consumer service writer")
   686  	}
   687  }
   688  
   689  func testPlacementService(store kv.Store, sid services.ServiceID) placement.Service {
   690  	return service.NewPlacementService(
   691  		storage.NewPlacementStorage(store, sid.String(), placement.NewOptions()),
   692  	)
   693  }