github.com/m3db/m3@v1.5.1-0.20231129193456-75a402aa583b/src/msg/producer/writer/shard_writer_test.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package writer
    22  
    23  import (
    24  	"net"
    25  	"sync"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/m3db/m3/src/cluster/placement"
    30  	"github.com/m3db/m3/src/cluster/shard"
    31  	"github.com/m3db/m3/src/msg/generated/proto/msgpb"
    32  	"github.com/m3db/m3/src/msg/producer"
    33  	"github.com/m3db/m3/src/msg/protocol/proto"
    34  	xtest "github.com/m3db/m3/src/x/test"
    35  
    36  	"github.com/fortytw2/leaktest"
    37  	"github.com/stretchr/testify/require"
    38  )
    39  
    40  func TestSharedShardWriter(t *testing.T) {
    41  	defer leaktest.Check(t)()
    42  
    43  	a := newAckRouter(2)
    44  	opts := testOptions()
    45  	sw := newSharedShardWriter(1, a, newMessagePool(), opts, testMessageWriterMetrics())
    46  	defer sw.Close()
    47  
    48  	cw1 := newConsumerWriter("i1", a, opts, testConsumerWriterMetrics())
    49  	cw1.Init()
    50  	defer cw1.Close()
    51  
    52  	lis, err := net.Listen("tcp", "127.0.0.1:0")
    53  	require.NoError(t, err)
    54  	defer lis.Close()
    55  
    56  	addr2 := lis.Addr().String()
    57  	cw2 := newConsumerWriter(addr2, a, opts, testConsumerWriterMetrics())
    58  	cw2.Init()
    59  	defer cw2.Close()
    60  
    61  	cws := make(map[string]consumerWriter)
    62  	cws["i1"] = cw1
    63  	cws[addr2] = cw2
    64  
    65  	i1 := placement.NewInstance().SetEndpoint("i1")
    66  	i2 := placement.NewInstance().SetEndpoint(addr2)
    67  
    68  	var wg sync.WaitGroup
    69  	defer wg.Wait()
    70  
    71  	wg.Add(1)
    72  	go func() {
    73  		testConsumeAndAckOnConnectionListener(t, lis, opts.EncoderOptions(), opts.DecoderOptions())
    74  		wg.Done()
    75  	}()
    76  
    77  	sw.UpdateInstances(
    78  		[]placement.Instance{i1},
    79  		cws,
    80  	)
    81  
    82  	ctrl := xtest.NewController(t)
    83  	defer ctrl.Finish()
    84  
    85  	mm := producer.NewMockMessage(ctrl)
    86  	mm.EXPECT().Bytes().Return([]byte("foo"))
    87  	mm.EXPECT().Finalize(producer.Consumed)
    88  	mm.EXPECT().Size().Return(3)
    89  
    90  	sw.Write(producer.NewRefCountedMessage(mm, nil))
    91  
    92  	mw := sw.(*sharedShardWriter).mw
    93  	mw.RLock()
    94  	require.Equal(t, 1, len(mw.consumerWriters))
    95  	require.Equal(t, 1, mw.queue.Len())
    96  	mw.RUnlock()
    97  
    98  	sw.UpdateInstances(
    99  		[]placement.Instance{i1, i2},
   100  		cws,
   101  	)
   102  	mw.RLock()
   103  	require.Equal(t, 2, len(mw.consumerWriters))
   104  	mw.RUnlock()
   105  	for {
   106  		mw.RLock()
   107  		l := mw.queue.Len()
   108  		mw.RUnlock()
   109  		if l == 0 {
   110  			break
   111  		}
   112  		time.Sleep(200 * time.Millisecond)
   113  	}
   114  }
   115  
   116  func TestReplicatedShardWriter(t *testing.T) {
   117  	defer leaktest.Check(t)()
   118  
   119  	a := newAckRouter(3)
   120  	opts := testOptions()
   121  	sw := newReplicatedShardWriter(1, 200, a, newMessagePool(), opts, testMessageWriterMetrics()).(*replicatedShardWriter)
   122  	defer sw.Close()
   123  
   124  	lis1, err := net.Listen("tcp", "127.0.0.1:0")
   125  	require.NoError(t, err)
   126  	defer lis1.Close()
   127  
   128  	lis2, err := net.Listen("tcp", "127.0.0.1:0")
   129  	require.NoError(t, err)
   130  	defer lis2.Close()
   131  
   132  	addr1 := lis1.Addr().String()
   133  	cw1 := newConsumerWriter(addr1, a, opts, testConsumerWriterMetrics())
   134  	cw1.Init()
   135  	defer cw1.Close()
   136  
   137  	addr2 := lis2.Addr().String()
   138  	cw2 := newConsumerWriter(addr2, a, opts, testConsumerWriterMetrics())
   139  	cw2.Init()
   140  	defer cw2.Close()
   141  
   142  	cw3 := newConsumerWriter("i3", a, opts, testConsumerWriterMetrics())
   143  	cw3.Init()
   144  	defer cw3.Close()
   145  
   146  	cws := make(map[string]consumerWriter)
   147  	cws[addr1] = cw1
   148  	cws[addr2] = cw2
   149  	cws["i3"] = cw3
   150  	i1 := placement.NewInstance().
   151  		SetEndpoint(addr1).
   152  		SetShards(shard.NewShards([]shard.Shard{shard.NewShard(1)}))
   153  	i2 := placement.NewInstance().
   154  		SetEndpoint(addr2).
   155  		SetShards(shard.NewShards([]shard.Shard{shard.NewShard(1)}))
   156  	i3 := placement.NewInstance().
   157  		SetEndpoint("i3").
   158  		SetShards(shard.NewShards([]shard.Shard{shard.NewShard(1)}))
   159  
   160  	sw.UpdateInstances(
   161  		[]placement.Instance{i1, i3},
   162  		cws,
   163  	)
   164  	require.Equal(t, 2, len(sw.messageWriters))
   165  
   166  	ctrl := xtest.NewController(t)
   167  	defer ctrl.Finish()
   168  
   169  	mm := producer.NewMockMessage(ctrl)
   170  	mm.EXPECT().Size().Return(3)
   171  	mm.EXPECT().Bytes().Return([]byte("foo")).Times(2)
   172  
   173  	sw.Write(producer.NewRefCountedMessage(mm, nil))
   174  
   175  	mw1 := sw.messageWriters[i1.Endpoint()]
   176  	require.Equal(t, 1, mw1.queue.Len())
   177  	mw3 := sw.messageWriters[i3.Endpoint()]
   178  	require.Equal(t, 1, mw3.queue.Len())
   179  
   180  	var wg sync.WaitGroup
   181  	defer wg.Wait()
   182  
   183  	wg.Add(1)
   184  	go func() {
   185  		testConsumeAndAckOnConnectionListener(t, lis1, opts.EncoderOptions(), opts.DecoderOptions())
   186  		wg.Done()
   187  	}()
   188  
   189  	for {
   190  		mw1.RLock()
   191  		l := mw1.queue.Len()
   192  		mw1.RUnlock()
   193  		if l == 0 {
   194  			break
   195  		}
   196  		time.Sleep(100 * time.Millisecond)
   197  	}
   198  	require.Equal(t, 1, mw3.queue.Len())
   199  
   200  	mm.EXPECT().Finalize(producer.Consumed)
   201  	sw.UpdateInstances(
   202  		[]placement.Instance{i1, i2},
   203  		cws,
   204  	)
   205  
   206  	wg.Add(1)
   207  	go func() {
   208  		testConsumeAndAckOnConnectionListener(t, lis2, opts.EncoderOptions(), opts.DecoderOptions())
   209  		wg.Done()
   210  	}()
   211  
   212  	for {
   213  		mw3.RLock()
   214  		l := mw3.queue.Len()
   215  		mw3.RUnlock()
   216  		if l == 0 {
   217  			break
   218  		}
   219  		time.Sleep(100 * time.Millisecond)
   220  	}
   221  
   222  	mw2 := sw.messageWriters[i2.Endpoint()]
   223  	require.Equal(t, mw3, mw2)
   224  	_, ok := sw.messageWriters[i3.Endpoint()]
   225  	require.False(t, ok)
   226  }
   227  
   228  func TestReplicatedShardWriterRemoveMessageWriter(t *testing.T) {
   229  	defer leaktest.Check(t)()
   230  
   231  	router := newAckRouter(2)
   232  	opts := testOptions()
   233  	sw := newReplicatedShardWriter(
   234  		1, 200, router, newMessagePool(), opts, testMessageWriterMetrics(),
   235  	).(*replicatedShardWriter)
   236  
   237  	lis1, err := net.Listen("tcp", "127.0.0.1:0")
   238  	require.NoError(t, err)
   239  	defer lis1.Close()
   240  
   241  	lis2, err := net.Listen("tcp", "127.0.0.1:0")
   242  	require.NoError(t, err)
   243  	defer lis2.Close()
   244  
   245  	addr1 := lis1.Addr().String()
   246  	cw1 := newConsumerWriter(addr1, router, opts, testConsumerWriterMetrics())
   247  	cw1.Init()
   248  	defer cw1.Close()
   249  
   250  	addr2 := lis2.Addr().String()
   251  	cw2 := newConsumerWriter(addr2, router, opts, testConsumerWriterMetrics())
   252  	cw2.Init()
   253  	defer cw2.Close()
   254  
   255  	cws := make(map[string]consumerWriter)
   256  	cws[addr1] = cw1
   257  	cws[addr2] = cw2
   258  	i1 := placement.NewInstance().
   259  		SetEndpoint(addr1).
   260  		SetShards(shard.NewShards([]shard.Shard{shard.NewShard(1)}))
   261  	i2 := placement.NewInstance().
   262  		SetEndpoint(addr2).
   263  		SetShards(shard.NewShards([]shard.Shard{shard.NewShard(1)}))
   264  
   265  	sw.UpdateInstances(
   266  		[]placement.Instance{i1, i2},
   267  		cws,
   268  	)
   269  
   270  	require.Equal(t, 2, len(sw.messageWriters))
   271  
   272  	mw1 := sw.messageWriters[i1.Endpoint()]
   273  	mw2 := sw.messageWriters[i2.Endpoint()]
   274  	require.Equal(t, 0, mw1.queue.Len())
   275  	require.Equal(t, 0, mw2.queue.Len())
   276  
   277  	ctrl := xtest.NewController(t)
   278  	defer ctrl.Finish()
   279  
   280  	mm := producer.NewMockMessage(ctrl)
   281  	mm.EXPECT().Size().Return(3)
   282  	mm.EXPECT().Bytes().Return([]byte("foo")).Times(2)
   283  
   284  	sw.Write(producer.NewRefCountedMessage(mm, nil))
   285  	require.Equal(t, 1, mw1.queue.Len())
   286  	require.Equal(t, 1, mw2.queue.Len())
   287  
   288  	var wg sync.WaitGroup
   289  	defer wg.Wait()
   290  
   291  	wg.Add(1)
   292  	go func() {
   293  		testConsumeAndAckOnConnectionListener(t, lis1, opts.EncoderOptions(), opts.DecoderOptions())
   294  		wg.Done()
   295  	}()
   296  
   297  	for {
   298  		mw1.RLock()
   299  		l := mw1.queue.Len()
   300  		mw1.RUnlock()
   301  		if l == 0 {
   302  			break
   303  		}
   304  		time.Sleep(100 * time.Millisecond)
   305  	}
   306  
   307  	require.Equal(t, 1, mw2.queue.Len())
   308  
   309  	conn, err := lis2.Accept()
   310  	require.NoError(t, err)
   311  	defer conn.Close()
   312  
   313  	serverEncoder := proto.NewEncoder(opts.EncoderOptions())
   314  	serverDecoder := proto.NewDecoder(conn, opts.DecoderOptions(), 10)
   315  
   316  	var msg msgpb.Message
   317  	require.NoError(t, serverDecoder.Decode(&msg))
   318  	sw.UpdateInstances(
   319  		[]placement.Instance{i1},
   320  		cws,
   321  	)
   322  
   323  	require.Equal(t, 1, len(sw.messageWriters))
   324  
   325  	mm.EXPECT().Finalize(producer.Consumed)
   326  	require.NoError(t, serverEncoder.Encode(&msgpb.Ack{Metadata: []msgpb.Metadata{msg.Metadata}}))
   327  	_, err = conn.Write(serverEncoder.Bytes())
   328  	require.NoError(t, err)
   329  	// Make sure mw2 is closed and removed from router.
   330  	for {
   331  		router.RLock()
   332  		l := len(router.messageWriters)
   333  		router.RUnlock()
   334  		if l == 1 {
   335  			break
   336  		}
   337  		time.Sleep(100 * time.Millisecond)
   338  	}
   339  	mw2.RLock()
   340  	require.Equal(t, 0, mw2.queue.Len())
   341  	mw2.RUnlock()
   342  
   343  	sw.Close()
   344  }
   345  
   346  func TestReplicatedShardWriterUpdate(t *testing.T) {
   347  	defer leaktest.Check(t)()
   348  
   349  	a := newAckRouter(4)
   350  	opts := testOptions()
   351  	sw := newReplicatedShardWriter(1, 200, a, newMessagePool(), opts, testMessageWriterMetrics()).(*replicatedShardWriter)
   352  	defer sw.Close()
   353  
   354  	cw1 := newConsumerWriter("i1", a, opts, testConsumerWriterMetrics())
   355  	cw2 := newConsumerWriter("i2", a, opts, testConsumerWriterMetrics())
   356  	cw3 := newConsumerWriter("i3", a, opts, testConsumerWriterMetrics())
   357  	cw4 := newConsumerWriter("i4", a, opts, testConsumerWriterMetrics())
   358  	cws := make(map[string]consumerWriter)
   359  	cws["i1"] = cw1
   360  	cws["i2"] = cw2
   361  	cws["i3"] = cw3
   362  	cws["i4"] = cw4
   363  
   364  	i1 := placement.NewInstance().
   365  		SetEndpoint("i1").
   366  		SetShards(shard.NewShards([]shard.Shard{shard.NewShard(1).SetCutoffNanos(801).SetCutoverNanos(401)}))
   367  	i2 := placement.NewInstance().
   368  		SetEndpoint("i2").
   369  		SetShards(shard.NewShards([]shard.Shard{shard.NewShard(1).SetCutoffNanos(802).SetCutoverNanos(402)}))
   370  	i3 := placement.NewInstance().
   371  		SetEndpoint("i3").
   372  		SetShards(shard.NewShards([]shard.Shard{shard.NewShard(1).SetCutoffNanos(803).SetCutoverNanos(403)}))
   373  	i4 := placement.NewInstance().
   374  		SetEndpoint("i4").
   375  		SetShards(shard.NewShards([]shard.Shard{shard.NewShard(1).SetCutoffNanos(804).SetCutoverNanos(404)}))
   376  
   377  	sw.UpdateInstances([]placement.Instance{i1, i2}, cws)
   378  	require.Equal(t, 2, int(sw.replicaID))
   379  	require.Equal(t, 2, len(sw.messageWriters))
   380  	mw1 := sw.messageWriters[i1.Endpoint()]
   381  	require.NotNil(t, mw1)
   382  	require.Equal(t, 801, int(mw1.CutoffNanos()))
   383  	require.Equal(t, 401, int(mw1.CutoverNanos()))
   384  	require.NotNil(t, sw.messageWriters[i2.Endpoint()])
   385  	require.Equal(t, 0, int(mw1.MessageTTLNanos()))
   386  
   387  	sw.SetMessageTTLNanos(500)
   388  	require.Equal(t, 500, int(mw1.MessageTTLNanos()))
   389  
   390  	sw.UpdateInstances([]placement.Instance{i2, i3}, cws)
   391  	require.Equal(t, 2, int(sw.replicaID))
   392  	require.Equal(t, 2, len(sw.messageWriters))
   393  	mw2 := sw.messageWriters[i2.Endpoint()]
   394  	require.NotNil(t, mw2)
   395  	mw3 := sw.messageWriters[i3.Endpoint()]
   396  	require.NotNil(t, mw3)
   397  	require.Equal(t, mw1, mw3)
   398  	require.Equal(t, 803, int(mw3.CutoffNanos()))
   399  	require.Equal(t, 403, int(mw3.CutoverNanos()))
   400  	m := make(map[uint64]int, 2)
   401  	m[mw2.ReplicatedShardID()] = 1
   402  	m[mw3.ReplicatedShardID()] = 1
   403  	require.Equal(t, map[uint64]int{1: 1, 201: 1}, m)
   404  	require.Equal(t, 500, int(mw2.MessageTTLNanos()))
   405  	require.Equal(t, 500, int(mw3.MessageTTLNanos()))
   406  
   407  	sw.UpdateInstances([]placement.Instance{i3}, cws)
   408  	require.Equal(t, 2, int(sw.replicaID))
   409  	require.Equal(t, 1, len(sw.messageWriters))
   410  	require.NotNil(t, sw.messageWriters[i3.Endpoint()])
   411  	require.Equal(t, 500, int(mw3.MessageTTLNanos()))
   412  	for {
   413  		mw2.RLock()
   414  		isClosed := mw2.isClosed
   415  		mw2.RUnlock()
   416  		if isClosed {
   417  			break
   418  		}
   419  		time.Sleep(100 * time.Millisecond)
   420  	}
   421  
   422  	sw.SetMessageTTLNanos(800)
   423  	require.Equal(t, 800, int(mw3.MessageTTLNanos()))
   424  	sw.UpdateInstances([]placement.Instance{i1, i2, i3}, cws)
   425  	require.Equal(t, 4, int(sw.replicaID))
   426  	require.Equal(t, 3, len(sw.messageWriters))
   427  	newmw1 := sw.messageWriters[i1.Endpoint()]
   428  	require.NotNil(t, newmw1)
   429  	require.NotEqual(t, &mw1, &newmw1)
   430  	newmw2 := sw.messageWriters[i2.Endpoint()]
   431  	require.NotNil(t, newmw2)
   432  	require.NotEqual(t, &mw2, &newmw2)
   433  	newmw3 := sw.messageWriters[i3.Endpoint()]
   434  	require.NotNil(t, newmw3)
   435  	require.Equal(t, &mw3, &newmw3)
   436  	m = make(map[uint64]int, 3)
   437  	m[newmw1.ReplicatedShardID()] = 1
   438  	m[newmw2.ReplicatedShardID()] = 1
   439  	m[newmw3.ReplicatedShardID()] = 1
   440  	require.Equal(t, map[uint64]int{601: 1, 401: 1, mw3.ReplicatedShardID(): 1}, m)
   441  	require.Equal(t, 800, int(newmw1.MessageTTLNanos()))
   442  	require.Equal(t, 800, int(newmw2.MessageTTLNanos()))
   443  	require.Equal(t, 800, int(newmw3.MessageTTLNanos()))
   444  
   445  	sw.UpdateInstances([]placement.Instance{i2, i4}, cws)
   446  	require.Equal(t, 4, int(sw.replicaID))
   447  	require.Equal(t, 2, len(sw.messageWriters))
   448  	require.NotNil(t, sw.messageWriters[i2.Endpoint()])
   449  	require.NotNil(t, sw.messageWriters[i4.Endpoint()])
   450  
   451  	sw.UpdateInstances([]placement.Instance{i1}, cws)
   452  	require.Equal(t, 4, int(sw.replicaID))
   453  	require.Equal(t, 1, len(sw.messageWriters))
   454  	require.NotNil(t, sw.messageWriters[i1.Endpoint()])
   455  }