github.com/matrixorigin/matrixone@v1.2.0/pkg/common/morpc/server_test.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package morpc
    16  
    17  import (
    18  	"context"
    19  	"os"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/fagongzi/goetty/v2"
    25  	"github.com/matrixorigin/matrixone/pkg/logutil"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  	"go.uber.org/zap"
    29  )
    30  
    31  func TestCreateServerWithOptions(t *testing.T) {
    32  	testRPCServer(t, func(rs *server) {
    33  		assert.Equal(t, 100, rs.options.batchSendSize)
    34  		assert.Equal(t, 200, rs.options.bufferSize)
    35  	}, WithServerBatchSendSize(100),
    36  		WithServerSessionBufferSize(200))
    37  }
    38  
    39  func TestHandleServer(t *testing.T) {
    40  	testRPCServer(t, func(rs *server) {
    41  		c := newTestClient(t)
    42  		defer func() {
    43  			assert.NoError(t, c.Close())
    44  		}()
    45  
    46  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*10000)
    47  		defer cancel()
    48  
    49  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, sequence uint64, cs ClientSession) error {
    50  			return cs.Write(ctx, request.Message)
    51  		})
    52  
    53  		req := newTestMessage(1)
    54  		f, err := c.Send(ctx, testAddr, req)
    55  		assert.NoError(t, err)
    56  
    57  		defer f.Close()
    58  		resp, err := f.Get()
    59  		assert.NoError(t, err)
    60  		assert.Equal(t, req, resp)
    61  	})
    62  }
    63  
    64  func TestHandleServerWithPayloadMessage(t *testing.T) {
    65  	testRPCServer(t, func(rs *server) {
    66  		c := newTestClient(t)
    67  		defer func() {
    68  			assert.NoError(t, c.Close())
    69  		}()
    70  
    71  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
    72  		defer cancel()
    73  
    74  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, sequence uint64, cs ClientSession) error {
    75  			return cs.Write(ctx, request.Message)
    76  		})
    77  
    78  		req := &testMessage{id: 1, payload: []byte("payload")}
    79  		f, err := c.Send(ctx, testAddr, req)
    80  		assert.NoError(t, err)
    81  
    82  		defer f.Close()
    83  		resp, err := f.Get()
    84  		assert.NoError(t, err)
    85  		assert.Equal(t, req, resp)
    86  	})
    87  }
    88  
    89  func TestHandleServerWriteWithClosedSession(t *testing.T) {
    90  	wc := make(chan struct{}, 1)
    91  	defer close(wc)
    92  
    93  	testRPCServer(t, func(rs *server) {
    94  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
    95  		defer cancel()
    96  
    97  		c := newTestClient(t)
    98  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, _ uint64, cs ClientSession) error {
    99  			assert.NoError(t, c.Close())
   100  			err := cs.Write(ctx, request.Message)
   101  			assert.Error(t, err)
   102  			return err
   103  		})
   104  
   105  		req := newTestMessage(1)
   106  		f, err := c.Send(ctx, testAddr, req)
   107  		assert.NoError(t, err)
   108  
   109  		defer f.Close()
   110  		resp, err := f.Get()
   111  		assert.Error(t, ctx.Err(), err)
   112  		assert.Nil(t, resp)
   113  	})
   114  }
   115  
   116  func TestHandleServerWriteWithClosedClientSession(t *testing.T) {
   117  	wc := make(chan struct{}, 1)
   118  	defer close(wc)
   119  
   120  	testRPCServer(t, func(rs *server) {
   121  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
   122  		defer cancel()
   123  
   124  		c := newTestClient(t)
   125  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, _ uint64, cs ClientSession) error {
   126  			assert.NoError(t, cs.Close())
   127  			return cs.Write(ctx, request.Message)
   128  		})
   129  
   130  		req := newTestMessage(1)
   131  		f, err := c.Send(ctx, testAddr, req)
   132  		assert.NoError(t, err)
   133  
   134  		defer f.Close()
   135  		_, err = f.Get()
   136  		assert.Error(t, err)
   137  		assert.Equal(t, backendClosed, err)
   138  	})
   139  }
   140  
   141  func TestStreamServer(t *testing.T) {
   142  	testRPCServer(t, func(rs *server) {
   143  		ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
   144  		defer cancel()
   145  
   146  		c := newTestClient(t)
   147  		defer func() {
   148  			assert.NoError(t, c.Close())
   149  		}()
   150  
   151  		wg := sync.WaitGroup{}
   152  		wg.Add(1)
   153  		n := 10
   154  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, _ uint64, cs ClientSession) error {
   155  			go func() {
   156  				defer wg.Done()
   157  				for i := 0; i < n; i++ {
   158  					assert.NoError(t, cs.Write(ctx, request.Message))
   159  				}
   160  			}()
   161  			return nil
   162  		})
   163  
   164  		st, err := c.NewStream(testAddr, false)
   165  		assert.NoError(t, err)
   166  		defer func() {
   167  			assert.NoError(t, st.Close(false))
   168  		}()
   169  
   170  		req := newTestMessage(st.ID())
   171  		assert.NoError(t, st.Send(ctx, req))
   172  
   173  		rc, err := st.Receive()
   174  		assert.NoError(t, err)
   175  		for i := 0; i < n; i++ {
   176  			assert.Equal(t, req, <-rc)
   177  		}
   178  
   179  		wg.Wait()
   180  	})
   181  }
   182  
   183  func TestStreamServerWithCache(t *testing.T) {
   184  	testRPCServer(t, func(rs *server) {
   185  		ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
   186  		defer cancel()
   187  
   188  		c := newTestClient(t)
   189  		defer func() {
   190  			assert.NoError(t, c.Close())
   191  		}()
   192  
   193  		rs.RegisterRequestHandler(func(ctx context.Context, msg RPCMessage, seq uint64, cs ClientSession) error {
   194  			request := msg.Message
   195  			if seq == 1 {
   196  				cache, err := cs.CreateCache(ctx, request.GetID())
   197  				if err != nil {
   198  					return err
   199  				}
   200  				m := newTestMessage(request.GetID())
   201  				return cache.Add(m)
   202  			} else {
   203  				cache, err := cs.GetCache(request.GetID())
   204  				if err != nil {
   205  					return err
   206  				}
   207  				m, _, err := cache.Pop()
   208  				if err != nil {
   209  					return err
   210  				}
   211  				if err := cs.Write(ctx, m); err != nil {
   212  					return err
   213  				}
   214  				if err := cs.Write(ctx, request); err != nil {
   215  					return err
   216  				}
   217  			}
   218  			return nil
   219  		})
   220  
   221  		st, err := c.NewStream(testAddr, false)
   222  		assert.NoError(t, err)
   223  		defer func() {
   224  			assert.NoError(t, st.Close(false))
   225  		}()
   226  
   227  		req1 := newTestMessage(st.ID())
   228  		req1.payload = []byte{1}
   229  		assert.NoError(t, st.Send(ctx, req1))
   230  
   231  		req2 := newTestMessage(st.ID())
   232  		req2.payload = []byte{2}
   233  		assert.NoError(t, st.Send(ctx, req2))
   234  
   235  		cc, err := st.Receive()
   236  		require.NoError(t, err)
   237  		for i := 0; i < 2; i++ {
   238  			select {
   239  			case <-ctx.Done():
   240  				assert.Fail(t, "message failed")
   241  			case <-cc:
   242  			}
   243  		}
   244  	})
   245  }
   246  
   247  func TestServerTimeoutCacheWillRemoved(t *testing.T) {
   248  	testRPCServer(t, func(rs *server) {
   249  		ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
   250  		defer cancel()
   251  
   252  		c := newTestClient(t)
   253  		defer func() {
   254  			assert.NoError(t, c.Close())
   255  		}()
   256  
   257  		cc := make(chan struct{})
   258  		rs.RegisterRequestHandler(func(ctx context.Context, msg RPCMessage, seq uint64, cs ClientSession) error {
   259  			request := msg.Message
   260  			cache, err := cs.CreateCache(ctx, request.GetID())
   261  			if err != nil {
   262  				return err
   263  			}
   264  			close(cc)
   265  			return cache.Add(request)
   266  		})
   267  
   268  		st, err := c.NewStream(testAddr, false)
   269  		assert.NoError(t, err)
   270  		defer func() {
   271  			assert.NoError(t, st.Close(false))
   272  		}()
   273  
   274  		assert.NoError(t, st.Send(ctx, newTestMessage(1)))
   275  		<-cc
   276  		v, ok := rs.sessions.Load(uint64(1))
   277  		if ok {
   278  			cs := v.(*clientSession)
   279  			for {
   280  				cs.mu.RLock()
   281  				if len(cs.mu.caches) == 0 {
   282  					cs.mu.RUnlock()
   283  					return
   284  				}
   285  				cs.mu.RUnlock()
   286  			}
   287  		}
   288  	})
   289  }
   290  
   291  func TestStreamServerWithSequenceNotMatch(t *testing.T) {
   292  	testRPCServer(t, func(rs *server) {
   293  		ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
   294  		defer cancel()
   295  
   296  		c := newTestClient(t)
   297  		defer func() {
   298  			assert.NoError(t, c.Close())
   299  		}()
   300  
   301  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, _ uint64, cs ClientSession) error {
   302  			return cs.Write(ctx, request.Message)
   303  		})
   304  
   305  		v, err := c.NewStream(testAddr, false)
   306  		assert.NoError(t, err)
   307  		st := v.(*stream)
   308  		defer func() {
   309  			assert.NoError(t, st.Close(false))
   310  		}()
   311  
   312  		st.sequence = 2
   313  		req := newTestMessage(st.ID())
   314  		assert.NoError(t, st.Send(ctx, req))
   315  
   316  		rc, err := st.Receive()
   317  		assert.NoError(t, err)
   318  		assert.NotNil(t, rc)
   319  		resp := <-rc
   320  		assert.Nil(t, resp)
   321  	})
   322  }
   323  
   324  func TestStreamReadCannotBlockWrite(t *testing.T) {
   325  	testRPCServer(t, func(rs *server) {
   326  		ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
   327  		defer cancel()
   328  
   329  		c := newTestClient(t)
   330  		defer func() {
   331  			assert.NoError(t, c.Close())
   332  		}()
   333  
   334  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, _ uint64, cs ClientSession) error {
   335  			return cs.Write(ctx, request.Message)
   336  		})
   337  
   338  		st, err := c.NewStream(testAddr, false)
   339  		assert.NoError(t, err)
   340  		defer func() {
   341  			assert.NoError(t, st.Close(false))
   342  		}()
   343  
   344  		ch, err := st.Receive()
   345  		require.NoError(t, err)
   346  
   347  		cc := make(chan struct{})
   348  		n := 1000
   349  		go func() {
   350  			defer close(cc)
   351  			i := 0
   352  			for {
   353  				<-ch
   354  				i++
   355  				if i == n {
   356  					return
   357  				}
   358  				time.Sleep(time.Millisecond)
   359  			}
   360  		}()
   361  		for i := 0; i < n; i++ {
   362  			require.NoError(t, st.Send(ctx, newTestMessage(st.ID())))
   363  		}
   364  		<-cc
   365  	})
   366  }
   367  
   368  func TestCannotGetClosedBackend(t *testing.T) {
   369  	testRPCServer(t, func(rs *server) {
   370  		ctx, cancel := context.WithTimeout(context.TODO(), time.Second*10)
   371  		defer cancel()
   372  
   373  		c := newTestClient(t, WithClientMaxBackendPerHost(2))
   374  		defer func() {
   375  			assert.NoError(t, c.Close())
   376  		}()
   377  
   378  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, _ uint64, cs ClientSession) error {
   379  			return cs.Write(ctx, request.Message)
   380  		})
   381  
   382  		st, err := c.NewStream(testAddr, true)
   383  		require.NoError(t, err)
   384  		require.NoError(t, st.Close(true))
   385  
   386  		require.NoError(t, c.Ping(ctx, testAddr))
   387  	})
   388  }
   389  
   390  func TestPingError(t *testing.T) {
   391  	testRPCServer(t, func(rs *server) {
   392  		c := newTestClient(t, WithClientMaxBackendPerHost(2))
   393  		defer func() {
   394  			assert.NoError(t, c.Close())
   395  		}()
   396  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, _ uint64, cs ClientSession) error {
   397  			return cs.Write(context.Background(), request.Message)
   398  		})
   399  		ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
   400  		cancel()
   401  		require.Error(t, c.Ping(ctx, testAddr))
   402  	})
   403  }
   404  
   405  func BenchmarkSend(b *testing.B) {
   406  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   407  	defer cancel()
   408  
   409  	testRPCServer(b, func(rs *server) {
   410  		c := newTestClient(b,
   411  			WithClientMaxBackendPerHost(1),
   412  			WithClientInitBackends([]string{testAddr}, []int{1}))
   413  		defer func() {
   414  			assert.NoError(b, c.Close())
   415  		}()
   416  
   417  		rs.RegisterRequestHandler(func(_ context.Context, request RPCMessage, sequence uint64, cs ClientSession) error {
   418  			return cs.Write(ctx, request.Message)
   419  		})
   420  
   421  		req := newTestMessage(1)
   422  
   423  		b.ResetTimer()
   424  		for i := 0; i < b.N; i++ {
   425  			f, err := c.Send(ctx, testAddr, req)
   426  			if err == nil {
   427  				_, err := f.Get()
   428  				if err != nil {
   429  					assert.Equal(b, ctx.Err(), err)
   430  				}
   431  				f.Close()
   432  			}
   433  		}
   434  	}, WithServerGoettyOptions(goetty.WithSessionReleaseMsgFunc(func(i interface{}) {
   435  		msg := i.(RPCMessage)
   436  		if !msg.InternalMessage() {
   437  			messagePool.Put(msg.Message)
   438  		}
   439  	})))
   440  }
   441  
   442  func testRPCServer(t assert.TestingT, testFunc func(*server), options ...ServerOption) {
   443  	assert.NoError(t, os.RemoveAll(testUnixFile))
   444  
   445  	options = append(options,
   446  		WithServerLogger(logutil.GetPanicLoggerWithLevel(zap.InfoLevel)))
   447  	s, err := NewRPCServer("test", testAddr, newTestCodec(), options...)
   448  	assert.NoError(t, err)
   449  	assert.NoError(t, s.Start())
   450  	defer func() {
   451  		assert.NoError(t, s.Close())
   452  	}()
   453  
   454  	testFunc(s.(*server))
   455  }
   456  
   457  func newTestClient(t assert.TestingT, options ...ClientOption) RPCClient {
   458  	bf := NewGoettyBasedBackendFactory(newTestCodec())
   459  	c, err := NewClient(
   460  		"",
   461  		bf,
   462  		options...)
   463  	assert.NoError(t, err)
   464  	return c
   465  }
   466  
   467  func TestPing(t *testing.T) {
   468  	testRPCServer(t, func(rs *server) {
   469  		c := newTestClient(t)
   470  		defer func() {
   471  			assert.NoError(t, c.Close())
   472  		}()
   473  
   474  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   475  		defer cancel()
   476  
   477  		assert.NoError(t, c.Ping(ctx, testAddr))
   478  	})
   479  }