github.com/m3db/m3@v1.5.0/src/msg/producer/writer/consumer_service_writer_test.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package writer
    22  
    23  import (
    24  	"errors"
    25  	"net"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/m3db/m3/src/cluster/kv"
    31  	"github.com/m3db/m3/src/cluster/kv/mem"
    32  	"github.com/m3db/m3/src/cluster/placement"
    33  	"github.com/m3db/m3/src/cluster/placement/service"
    34  	"github.com/m3db/m3/src/cluster/placement/storage"
    35  	"github.com/m3db/m3/src/cluster/services"
    36  	"github.com/m3db/m3/src/cluster/shard"
    37  	"github.com/m3db/m3/src/msg/generated/proto/msgpb"
    38  	"github.com/m3db/m3/src/msg/producer"
    39  	"github.com/m3db/m3/src/msg/protocol/proto"
    40  	"github.com/m3db/m3/src/msg/topic"
    41  	xtest "github.com/m3db/m3/src/x/test"
    42  
    43  	"github.com/fortytw2/leaktest"
    44  	"github.com/golang/mock/gomock"
    45  	"github.com/stretchr/testify/require"
    46  )
    47  
    48  func TestConsumerServiceWriterWithSharedConsumerWithNonShardedPlacement(t *testing.T) {
    49  	defer leaktest.Check(t)()
    50  
    51  	ctrl := xtest.NewController(t)
    52  	defer ctrl.Finish()
    53  
    54  	sid := services.NewServiceID().SetName("foo")
    55  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
    56  	sd := services.NewMockServices(ctrl)
    57  	ps := testPlacementService(mem.NewStore(), sid)
    58  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
    59  
    60  	opts := testOptions().SetServiceDiscovery(sd)
    61  	w, err := newConsumerServiceWriter(cs, 2, opts)
    62  	require.NoError(t, err)
    63  
    64  	csw := w.(*consumerServiceWriterImpl)
    65  
    66  	var (
    67  		lock               sync.Mutex
    68  		numConsumerWriters int
    69  	)
    70  	csw.processFn = func(p interface{}) error {
    71  		err := csw.process(p)
    72  		lock.Lock()
    73  		numConsumerWriters = len(csw.consumerWriters)
    74  		lock.Unlock()
    75  		return err
    76  	}
    77  
    78  	require.NoError(t, csw.Init(allowInitValueError))
    79  	lock.Lock()
    80  	require.Equal(t, 0, numConsumerWriters)
    81  	lock.Unlock()
    82  
    83  	lis, err := net.Listen("tcp", "127.0.0.1:0")
    84  	require.NoError(t, err)
    85  	defer lis.Close()
    86  
    87  	p1 := placement.NewPlacement().
    88  		SetInstances([]placement.Instance{
    89  			placement.NewInstance().
    90  				SetID("i1").
    91  				SetEndpoint(lis.Addr().String()),
    92  			placement.NewInstance().
    93  				SetID("i2").
    94  				SetEndpoint("addr2"),
    95  			placement.NewInstance().
    96  				SetID("i3").
    97  				SetEndpoint("addr3"),
    98  		}).
    99  		SetIsSharded(false)
   100  	_, err = ps.Set(p1)
   101  	require.NoError(t, err)
   102  
   103  	for {
   104  		lock.Lock()
   105  		l := numConsumerWriters
   106  		lock.Unlock()
   107  		if l == 3 {
   108  			break
   109  		}
   110  		time.Sleep(100 * time.Millisecond)
   111  	}
   112  
   113  	var wg sync.WaitGroup
   114  	defer wg.Wait()
   115  
   116  	wg.Add(1)
   117  	go func() {
   118  		testConsumeAndAckOnConnectionListener(t, lis, opts.EncoderOptions(), opts.DecoderOptions())
   119  		wg.Done()
   120  	}()
   121  
   122  	mm := producer.NewMockMessage(ctrl)
   123  	mm.EXPECT().Shard().Return(uint32(1))
   124  	mm.EXPECT().Bytes().Return([]byte("foo"))
   125  	mm.EXPECT().Size().Return(3)
   126  	mm.EXPECT().Finalize(producer.Consumed)
   127  
   128  	rm := producer.NewRefCountedMessage(mm, nil)
   129  	csw.Write(rm)
   130  	for {
   131  		if rm.IsDroppedOrConsumed() {
   132  			break
   133  		}
   134  		time.Sleep(100 * time.Millisecond)
   135  	}
   136  
   137  	for _, sw := range w.(*consumerServiceWriterImpl).shardWriters {
   138  		require.Equal(t, 3, len(sw.(*sharedShardWriter).mw.(*messageWriterImpl).consumerWriters))
   139  	}
   140  
   141  	p2 := placement.NewPlacement().
   142  		SetInstances([]placement.Instance{
   143  			placement.NewInstance().
   144  				SetID("i1").
   145  				SetEndpoint(lis.Addr().String()),
   146  			placement.NewInstance().
   147  				SetID("i2").
   148  				SetEndpoint("addr2"),
   149  		}).
   150  		SetIsSharded(false)
   151  	_, err = ps.Set(p2)
   152  	require.NoError(t, err)
   153  
   154  	for {
   155  		lock.Lock()
   156  		l := numConsumerWriters
   157  		lock.Unlock()
   158  		if l == 2 {
   159  			break
   160  		}
   161  		time.Sleep(100 * time.Millisecond)
   162  	}
   163  
   164  	for _, sw := range w.(*consumerServiceWriterImpl).shardWriters {
   165  		require.Equal(t, 2, len(sw.(*sharedShardWriter).mw.(*messageWriterImpl).consumerWriters))
   166  	}
   167  
   168  	csw.Close()
   169  	csw.Close()
   170  }
   171  
   172  func TestConsumerServiceWriterWithSharedConsumerWithShardedPlacement(t *testing.T) {
   173  	defer leaktest.Check(t)()
   174  
   175  	ctrl := xtest.NewController(t)
   176  	defer ctrl.Finish()
   177  
   178  	sid := services.NewServiceID().SetName("foo")
   179  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   180  	sd := services.NewMockServices(ctrl)
   181  	ps := testPlacementService(mem.NewStore(), sid)
   182  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   183  
   184  	opts := testOptions().SetServiceDiscovery(sd)
   185  	w, err := newConsumerServiceWriter(cs, 3, opts)
   186  	require.NoError(t, err)
   187  
   188  	csw := w.(*consumerServiceWriterImpl)
   189  
   190  	var (
   191  		lock               sync.Mutex
   192  		numConsumerWriters int
   193  	)
   194  	csw.processFn = func(p interface{}) error {
   195  		err := csw.process(p)
   196  		lock.Lock()
   197  		numConsumerWriters = len(csw.consumerWriters)
   198  		lock.Unlock()
   199  		return err
   200  	}
   201  
   202  	require.NoError(t, csw.Init(allowInitValueError))
   203  	lock.Lock()
   204  	require.Equal(t, 0, numConsumerWriters)
   205  	lock.Unlock()
   206  
   207  	lis, err := net.Listen("tcp", "127.0.0.1:0")
   208  	require.NoError(t, err)
   209  	defer lis.Close()
   210  
   211  	p1 := placement.NewPlacement().
   212  		SetInstances([]placement.Instance{
   213  			placement.NewInstance().
   214  				SetID("i1").
   215  				SetEndpoint(lis.Addr().String()).
   216  				SetShards(shard.NewShards([]shard.Shard{
   217  					shard.NewShard(1).SetState(shard.Available),
   218  					shard.NewShard(2).SetState(shard.Available),
   219  				})),
   220  			placement.NewInstance().
   221  				SetID("i2").
   222  				SetEndpoint("addr2").
   223  				SetShards(shard.NewShards([]shard.Shard{
   224  					shard.NewShard(0).SetState(shard.Available),
   225  					shard.NewShard(2).SetState(shard.Available),
   226  				})),
   227  			placement.NewInstance().
   228  				SetID("i3").
   229  				SetEndpoint("addr3").
   230  				SetShards(shard.NewShards([]shard.Shard{
   231  					shard.NewShard(0).SetState(shard.Available),
   232  					shard.NewShard(1).SetState(shard.Available),
   233  				})),
   234  		}).
   235  		SetShards([]uint32{0, 1, 2}).
   236  		SetReplicaFactor(2).
   237  		SetIsSharded(true)
   238  	_, err = ps.Set(p1)
   239  	require.NoError(t, err)
   240  
   241  	for {
   242  		lock.Lock()
   243  		l := numConsumerWriters
   244  		lock.Unlock()
   245  		if l == 3 {
   246  			break
   247  		}
   248  		time.Sleep(100 * time.Millisecond)
   249  	}
   250  
   251  	var wg sync.WaitGroup
   252  	defer wg.Wait()
   253  
   254  	wg.Add(1)
   255  	go func() {
   256  		testConsumeAndAckOnConnectionListener(t, lis, opts.EncoderOptions(), opts.DecoderOptions())
   257  		wg.Done()
   258  	}()
   259  
   260  	mm := producer.NewMockMessage(ctrl)
   261  	mm.EXPECT().Shard().Return(uint32(1))
   262  	mm.EXPECT().Bytes().Return([]byte("foo"))
   263  	mm.EXPECT().Size().Return(3)
   264  	mm.EXPECT().Finalize(producer.Consumed)
   265  
   266  	rm := producer.NewRefCountedMessage(mm, nil)
   267  	csw.Write(rm)
   268  	for {
   269  		if rm.IsDroppedOrConsumed() {
   270  			break
   271  		}
   272  		time.Sleep(100 * time.Millisecond)
   273  	}
   274  	p2 := placement.NewPlacement().
   275  		SetInstances([]placement.Instance{
   276  			placement.NewInstance().
   277  				SetID("i1").
   278  				SetEndpoint(lis.Addr().String()).
   279  				SetShards(shard.NewShards([]shard.Shard{
   280  					shard.NewShard(1).SetState(shard.Available),
   281  				})),
   282  			placement.NewInstance().
   283  				SetID("i2").
   284  				SetEndpoint("addr2").
   285  				SetShards(shard.NewShards([]shard.Shard{
   286  					shard.NewShard(0).SetState(shard.Available),
   287  					shard.NewShard(2).SetState(shard.Available),
   288  				})),
   289  		}).
   290  		SetShards([]uint32{0, 1, 2}).
   291  		SetReplicaFactor(1).
   292  		SetIsSharded(true)
   293  	_, err = ps.Set(p2)
   294  	require.NoError(t, err)
   295  
   296  	for {
   297  		lock.Lock()
   298  		l := numConsumerWriters
   299  		lock.Unlock()
   300  		if l == 2 {
   301  			break
   302  		}
   303  		time.Sleep(100 * time.Millisecond)
   304  	}
   305  
   306  	csw.Close()
   307  	csw.Close()
   308  }
   309  
   310  func TestConsumerServiceWriterWithReplicatedConsumerWithShardedPlacement(t *testing.T) {
   311  	defer leaktest.Check(t)()
   312  
   313  	ctrl := xtest.NewController(t)
   314  	defer ctrl.Finish()
   315  
   316  	sid := services.NewServiceID().SetName("foo")
   317  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Replicated)
   318  	sd := services.NewMockServices(ctrl)
   319  	ps := testPlacementService(mem.NewStore(), sid)
   320  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   321  
   322  	lis1, err := net.Listen("tcp", "127.0.0.1:0")
   323  	require.NoError(t, err)
   324  	defer lis1.Close()
   325  
   326  	lis2, err := net.Listen("tcp", "127.0.0.1:0")
   327  	require.NoError(t, err)
   328  	defer lis2.Close()
   329  
   330  	p1 := placement.NewPlacement().
   331  		SetInstances([]placement.Instance{
   332  			placement.NewInstance().
   333  				SetID("i1").
   334  				SetEndpoint(lis1.Addr().String()).
   335  				SetShards(shard.NewShards([]shard.Shard{
   336  					shard.NewShard(0).SetState(shard.Available),
   337  					shard.NewShard(1).SetState(shard.Available),
   338  				})),
   339  			placement.NewInstance().
   340  				SetID("i2").
   341  				SetEndpoint(lis2.Addr().String()).
   342  				SetShards(shard.NewShards([]shard.Shard{
   343  					shard.NewShard(1).SetState(shard.Available),
   344  				})),
   345  			placement.NewInstance().
   346  				SetID("i3").
   347  				SetEndpoint("addr3").
   348  				SetShards(shard.NewShards([]shard.Shard{
   349  					shard.NewShard(0).SetState(shard.Available),
   350  				})),
   351  		}).
   352  		SetShards([]uint32{0, 1}).
   353  		SetReplicaFactor(2).
   354  		SetIsSharded(true)
   355  	_, err = ps.Set(p1)
   356  	require.NoError(t, err)
   357  
   358  	opts := testOptions().SetServiceDiscovery(sd)
   359  	w, err := newConsumerServiceWriter(cs, 2, opts)
   360  	csw := w.(*consumerServiceWriterImpl)
   361  	require.NoError(t, err)
   362  	require.NotNil(t, csw)
   363  
   364  	var (
   365  		lock               sync.Mutex
   366  		numConsumerWriters int
   367  	)
   368  	csw.processFn = func(p interface{}) error {
   369  		err := csw.process(p)
   370  		lock.Lock()
   371  		numConsumerWriters = len(csw.consumerWriters)
   372  		lock.Unlock()
   373  		return err
   374  	}
   375  	require.NoError(t, csw.Init(allowInitValueError))
   376  
   377  	for {
   378  		lock.Lock()
   379  		l := numConsumerWriters
   380  		lock.Unlock()
   381  		if l == 3 {
   382  			break
   383  		}
   384  		time.Sleep(100 * time.Millisecond)
   385  	}
   386  
   387  	mm := producer.NewMockMessage(ctrl)
   388  	mm.EXPECT().Shard().Return(uint32(1)).AnyTimes()
   389  	mm.EXPECT().Bytes().Return([]byte("foo")).AnyTimes()
   390  	mm.EXPECT().Size().Return(3)
   391  	mm.EXPECT().Finalize(producer.Consumed)
   392  
   393  	rm := producer.NewRefCountedMessage(mm, nil)
   394  	csw.Write(rm)
   395  	var wg sync.WaitGroup
   396  	wg.Add(1)
   397  	go func() {
   398  		testConsumeAndAckOnConnectionListener(t, lis1, opts.EncoderOptions(), opts.DecoderOptions())
   399  		wg.Done()
   400  	}()
   401  
   402  	wg.Add(1)
   403  	go func() {
   404  		testConsumeAndAckOnConnectionListener(t, lis2, opts.EncoderOptions(), opts.DecoderOptions())
   405  		wg.Done()
   406  	}()
   407  	wg.Wait()
   408  
   409  	for {
   410  		if rm.IsDroppedOrConsumed() {
   411  			break
   412  		}
   413  		time.Sleep(100 * time.Millisecond)
   414  	}
   415  
   416  	p2 := placement.NewPlacement().
   417  		SetInstances([]placement.Instance{
   418  			placement.NewInstance().
   419  				SetID("i1").
   420  				SetEndpoint(lis1.Addr().String()).
   421  				SetShards(shard.NewShards([]shard.Shard{
   422  					shard.NewShard(0).SetState(shard.Available),
   423  				})),
   424  			placement.NewInstance().
   425  				SetID("i2").
   426  				SetEndpoint(lis2.Addr().String()).
   427  				SetShards(shard.NewShards([]shard.Shard{
   428  					shard.NewShard(1).SetState(shard.Available),
   429  				})),
   430  		}).
   431  		SetShards([]uint32{0, 1}).
   432  		SetReplicaFactor(1).
   433  		SetIsSharded(true)
   434  	_, err = ps.Set(p2)
   435  	require.NoError(t, err)
   436  
   437  	for {
   438  		lock.Lock()
   439  		l := numConsumerWriters
   440  		lock.Unlock()
   441  		if l == 2 {
   442  			break
   443  		}
   444  		time.Sleep(100 * time.Millisecond)
   445  	}
   446  
   447  	go func() {
   448  		for {
   449  			conn, err := lis2.Accept()
   450  			if err != nil {
   451  				return
   452  			}
   453  			serverEncoder := proto.NewEncoder(opts.EncoderOptions())
   454  			serverDecoder := proto.NewDecoder(conn, opts.DecoderOptions(), 10)
   455  
   456  			var msg msgpb.Message
   457  			err = serverDecoder.Decode(&msg)
   458  			if err != nil {
   459  				conn.Close()
   460  				continue
   461  			}
   462  			require.NoError(t, serverEncoder.Encode(&msgpb.Ack{
   463  				Metadata: []msgpb.Metadata{
   464  					msg.Metadata,
   465  				},
   466  			}))
   467  			_, err = conn.Write(serverEncoder.Bytes())
   468  			require.NoError(t, err)
   469  			conn.Close()
   470  		}
   471  	}()
   472  
   473  	mm.EXPECT().Finalize(producer.Consumed)
   474  	mm.EXPECT().Size().Return(3)
   475  	rm = producer.NewRefCountedMessage(mm, nil)
   476  	csw.Write(rm)
   477  	for {
   478  		if rm.IsDroppedOrConsumed() {
   479  			break
   480  		}
   481  		time.Sleep(100 * time.Millisecond)
   482  	}
   483  
   484  	csw.Close()
   485  	csw.Close()
   486  }
   487  
   488  func TestConsumerServiceWriterFilter(t *testing.T) {
   489  	defer leaktest.Check(t)()
   490  
   491  	ctrl := xtest.NewController(t)
   492  	defer ctrl.Finish()
   493  
   494  	sid := services.NewServiceID().SetName("foo")
   495  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Replicated)
   496  	sd := services.NewMockServices(ctrl)
   497  	ps := testPlacementService(mem.NewStore(), sid)
   498  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   499  
   500  	opts := testOptions().SetServiceDiscovery(sd)
   501  	csw, err := newConsumerServiceWriter(cs, 3, opts)
   502  	require.NoError(t, err)
   503  
   504  	sw0 := NewMockshardWriter(ctrl)
   505  	sw1 := NewMockshardWriter(ctrl)
   506  	csw.(*consumerServiceWriterImpl).shardWriters[0] = sw0
   507  	csw.(*consumerServiceWriterImpl).shardWriters[1] = sw1
   508  
   509  	mm0 := producer.NewMockMessage(ctrl)
   510  	mm0.EXPECT().Shard().Return(uint32(0)).AnyTimes()
   511  	mm0.EXPECT().Size().Return(3).AnyTimes()
   512  	mm1 := producer.NewMockMessage(ctrl)
   513  	mm1.EXPECT().Shard().Return(uint32(1)).AnyTimes()
   514  	mm1.EXPECT().Size().Return(3).AnyTimes()
   515  
   516  	sw0.EXPECT().Write(gomock.Any())
   517  	csw.Write(producer.NewRefCountedMessage(mm0, nil))
   518  	sw1.EXPECT().Write(gomock.Any())
   519  	csw.Write(producer.NewRefCountedMessage(mm1, nil))
   520  
   521  	csw.RegisterFilter(func(m producer.Message) bool { return m.Shard() == uint32(0) })
   522  	csw.Write(producer.NewRefCountedMessage(mm1, nil))
   523  
   524  	sw0.EXPECT().Write(gomock.Any())
   525  	csw.Write(producer.NewRefCountedMessage(mm0, nil))
   526  
   527  	csw.UnregisterFilter()
   528  	sw1.EXPECT().Write(gomock.Any())
   529  	csw.Write(producer.NewRefCountedMessage(mm1, nil))
   530  }
   531  
   532  func TestConsumerServiceWriterAllowInitValueErrorWithCreateWatchError(t *testing.T) {
   533  	defer leaktest.Check(t)()
   534  
   535  	ctrl := xtest.NewController(t)
   536  	defer ctrl.Finish()
   537  
   538  	sid := services.NewServiceID().SetName("foo")
   539  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   540  
   541  	ps := placement.NewMockService(ctrl)
   542  	ps.EXPECT().Watch().Return(nil, errors.New("mock err")).AnyTimes()
   543  
   544  	sd := services.NewMockServices(ctrl)
   545  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   546  
   547  	opts := testOptions().SetServiceDiscovery(sd)
   548  	w, err := newConsumerServiceWriter(cs, 3, opts)
   549  	require.NoError(t, err)
   550  	defer w.Close()
   551  
   552  	require.Error(t, w.Init(allowInitValueError))
   553  }
   554  
   555  func TestConsumerServiceWriterAllowInitValueErrorWithInitValueError(t *testing.T) {
   556  	defer leaktest.Check(t)()
   557  
   558  	ctrl := xtest.NewController(t)
   559  	defer ctrl.Finish()
   560  
   561  	sid := services.NewServiceID().SetName("foo")
   562  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   563  
   564  	ps := testPlacementService(mem.NewStore(), sid)
   565  	sd := services.NewMockServices(ctrl)
   566  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   567  
   568  	opts := testOptions().SetServiceDiscovery(sd)
   569  	w, err := newConsumerServiceWriter(cs, 3, opts)
   570  	require.NoError(t, err)
   571  	defer w.Close()
   572  
   573  	require.NoError(t, w.Init(allowInitValueError))
   574  }
   575  
   576  func TestConsumerServiceWriterInitError(t *testing.T) {
   577  	defer leaktest.Check(t)()
   578  
   579  	ctrl := xtest.NewController(t)
   580  	defer ctrl.Finish()
   581  
   582  	sid := services.NewServiceID().SetName("foo")
   583  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   584  
   585  	ps := placement.NewMockService(ctrl)
   586  	ps.EXPECT().Watch().Return(nil, errors.New("mock err")).AnyTimes()
   587  
   588  	sd := services.NewMockServices(ctrl)
   589  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   590  
   591  	opts := testOptions().SetServiceDiscovery(sd)
   592  	w, err := newConsumerServiceWriter(cs, 3, opts)
   593  	require.NoError(t, err)
   594  	defer w.Close()
   595  
   596  	err = w.Init(failOnError)
   597  	require.Error(t, err)
   598  	require.Contains(t, err.Error(), "consumer service writer init error")
   599  }
   600  
   601  func TestConsumerServiceWriterUpdateNonShardedPlacementWithReplicatedConsumptionType(t *testing.T) {
   602  	defer leaktest.Check(t)()
   603  
   604  	ctrl := xtest.NewController(t)
   605  	defer ctrl.Finish()
   606  
   607  	sid := services.NewServiceID().SetName("foo")
   608  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Replicated)
   609  	sd := services.NewMockServices(ctrl)
   610  	pOpts := placement.NewOptions().SetIsSharded(false)
   611  	ps := service.NewPlacementService(storage.NewPlacementStorage(mem.NewStore(), sid.String(), pOpts),
   612  		service.WithPlacementOptions(pOpts))
   613  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   614  	_, err := ps.BuildInitialPlacement([]placement.Instance{
   615  		placement.NewInstance().SetID("i1").SetEndpoint("i1").SetWeight(1),
   616  	}, 0, 1)
   617  	require.NoError(t, err)
   618  	opts := testOptions().SetServiceDiscovery(sd)
   619  	w, err := newConsumerServiceWriter(cs, 2, opts)
   620  	require.NoError(t, err)
   621  	err = w.Init(failOnError)
   622  	require.Error(t, err)
   623  	require.Contains(t, err.Error(), "non-sharded placement for replicated consumer")
   624  	w.Close()
   625  }
   626  
   627  func TestConsumerServiceCloseShardWritersConcurrently(t *testing.T) {
   628  	defer leaktest.Check(t)()
   629  
   630  	ctrl := xtest.NewController(t)
   631  	defer ctrl.Finish()
   632  
   633  	sid := services.NewServiceID().SetName("foo")
   634  	cs := topic.NewConsumerService().SetServiceID(sid).SetConsumptionType(topic.Shared)
   635  	sd := services.NewMockServices(ctrl)
   636  	ps := testPlacementService(mem.NewStore(), sid)
   637  	sd.EXPECT().PlacementService(sid, gomock.Any()).Return(ps, nil)
   638  	opts := testOptions().SetServiceDiscovery(sd).SetCloseCheckInterval(time.Second)
   639  
   640  	numShards := uint32(1024)
   641  	w, err := newConsumerServiceWriter(cs, numShards, opts)
   642  	require.NoError(t, err)
   643  	require.NoError(t, w.Init(allowInitValueError))
   644  
   645  	// Write one message to each shard, so each shard needs to tick
   646  	// and wait for the queue to be cleaned up.
   647  	b := []byte{}
   648  	for i := uint32(0); i < numShards; i++ {
   649  		mm := producer.NewMockMessage(ctrl)
   650  		mm.EXPECT().Shard().Return(i)
   651  		mm.EXPECT().Bytes().Return(b).AnyTimes()
   652  		mm.EXPECT().Size().Return(0).AnyTimes()
   653  		mm.EXPECT().Finalize(gomock.Any())
   654  		w.Write(producer.NewRefCountedMessage(mm, nil))
   655  	}
   656  
   657  	ch := make(chan struct{})
   658  	go func() {
   659  		w.Close()
   660  		close(ch)
   661  	}()
   662  
   663  	select {
   664  	case <-ch:
   665  		return
   666  	case <-time.After(10 * time.Second):
   667  		require.FailNow(t, "taking too long to close consumer service writer")
   668  	}
   669  }
   670  
   671  func testPlacementService(store kv.Store, sid services.ServiceID) placement.Service {
   672  	return service.NewPlacementService(
   673  		storage.NewPlacementStorage(store, sid.String(), placement.NewOptions()),
   674  	)
   675  }