github.com/matrixorigin/matrixone@v1.2.0/pkg/common/morpc/handler_test.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package morpc
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"os"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/fagongzi/goetty/v2/buf"
    25  	"github.com/lni/goutils/leaktest"
    26  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    27  	"github.com/matrixorigin/matrixone/pkg/common/runtime"
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  )
    31  
    32  func TestRPCSend(t *testing.T) {
    33  	runRPCTests(
    34  		t,
    35  		func(
    36  			addr string,
    37  			c RPCClient,
    38  			h MessageHandler[*testMethodBasedMessage, *testMethodBasedMessage]) {
    39  			fn := func(
    40  				ctx context.Context,
    41  				req, resp *testMethodBasedMessage) error {
    42  				resp.payload = []byte{byte(req.method)}
    43  				return nil
    44  			}
    45  			h.RegisterHandleFunc(
    46  				1,
    47  				fn,
    48  				false)
    49  			h.RegisterHandleFunc(
    50  				2,
    51  				fn,
    52  				false)
    53  			ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
    54  			defer cancel()
    55  
    56  			for i := uint32(0); i <= 2; i++ {
    57  				f, err := c.Send(ctx, addr, &testMethodBasedMessage{method: i})
    58  				require.NoError(t, err)
    59  				defer f.Close()
    60  				v, err := f.Get()
    61  				require.NoError(t, err)
    62  				resp := v.(*testMethodBasedMessage)
    63  				assert.Equal(t, i, resp.method)
    64  				if i == 0 {
    65  					assert.Error(t, resp.UnwrapError())
    66  				} else {
    67  					assert.Equal(t, []byte{byte(i)}, resp.payload)
    68  				}
    69  			}
    70  		},
    71  	)
    72  }
    73  
    74  func TestRequestCanBeFilter(t *testing.T) {
    75  	runRPCTests(
    76  		t,
    77  		func(
    78  			addr string,
    79  			c RPCClient,
    80  			h MessageHandler[*testMethodBasedMessage, *testMethodBasedMessage]) {
    81  			fn := func(
    82  				ctx context.Context,
    83  				req, resp *testMethodBasedMessage) error {
    84  				resp.payload = []byte{byte(req.method)}
    85  				return nil
    86  			}
    87  			h.RegisterHandleFunc(
    88  				1,
    89  				fn,
    90  				false)
    91  			ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
    92  			defer cancel()
    93  
    94  			f, err := c.Send(ctx, addr, &testMethodBasedMessage{method: 1})
    95  			require.NoError(t, err)
    96  			defer f.Close()
    97  			_, err = f.Get()
    98  			require.Error(t, err)
    99  		},
   100  		WithHandleMessageFilter[*testMethodBasedMessage, *testMethodBasedMessage](func(tmbm *testMethodBasedMessage) bool {
   101  			return false
   102  		}),
   103  	)
   104  }
   105  
   106  func runRPCTests(
   107  	t *testing.T,
   108  	fn func(string, RPCClient, MessageHandler[*testMethodBasedMessage, *testMethodBasedMessage]),
   109  	opts ...HandlerOption[*testMethodBasedMessage, *testMethodBasedMessage]) {
   110  	defer leaktest.AfterTest(t)()
   111  	testSockets := fmt.Sprintf("unix:///tmp/%d.sock", time.Now().Nanosecond())
   112  	assert.NoError(t, os.RemoveAll(testSockets[7:]))
   113  	runtime.SetupProcessLevelRuntime(runtime.DefaultRuntime())
   114  
   115  	s, err := NewMessageHandler(
   116  		"test",
   117  		testSockets,
   118  		Config{},
   119  		NewMessagePool(
   120  			func() *testMethodBasedMessage { return &testMethodBasedMessage{} },
   121  			func() *testMethodBasedMessage { return &testMethodBasedMessage{} }),
   122  		opts...)
   123  	require.NoError(t, err)
   124  	defer func() {
   125  		assert.NoError(t, s.Close())
   126  	}()
   127  	require.NoError(t, s.Start())
   128  
   129  	cfg := Config{}
   130  	c, err := cfg.NewClient("ctlservice",
   131  		getLogger().RawLogger(),
   132  		func() Message { return &testMethodBasedMessage{} })
   133  	require.NoError(t, err)
   134  	defer func() {
   135  		assert.NoError(t, c.Close())
   136  	}()
   137  
   138  	fn(testSockets, c, s)
   139  }
   140  
   141  type testMethodBasedMessage struct {
   142  	testMessage
   143  	method uint32
   144  	err    []byte
   145  }
   146  
   147  func (m *testMethodBasedMessage) Reset() {
   148  	*m = testMethodBasedMessage{}
   149  }
   150  
   151  func (m *testMethodBasedMessage) Method() uint32 {
   152  	return m.method
   153  }
   154  
   155  func (m *testMethodBasedMessage) SetMethod(v uint32) {
   156  	m.method = v
   157  }
   158  
   159  func (m *testMethodBasedMessage) WrapError(err error) {
   160  	me := moerr.ConvertGoError(context.TODO(), err).(*moerr.Error)
   161  	data, e := me.MarshalBinary()
   162  	if e != nil {
   163  		panic(e)
   164  	}
   165  	m.err = data
   166  }
   167  
   168  func (m *testMethodBasedMessage) UnwrapError() error {
   169  	if len(m.err) == 0 {
   170  		return nil
   171  	}
   172  
   173  	err := &moerr.Error{}
   174  	if e := err.UnmarshalBinary(m.err); e != nil {
   175  		panic(e)
   176  	}
   177  	return err
   178  }
   179  
   180  func (m *testMethodBasedMessage) Size() int {
   181  	return 12 + len(m.err) + len(m.payload)
   182  }
   183  
   184  func (m *testMethodBasedMessage) MarshalTo(data []byte) (int, error) {
   185  	buf.Uint64ToBytesTo(m.id, data)
   186  	buf.Uint32ToBytesTo(m.method, data[8:])
   187  	if len(m.err) > 0 {
   188  		copy(data[12:], m.err)
   189  	}
   190  	return 12 + len(m.err), nil
   191  }
   192  
   193  func (m *testMethodBasedMessage) Unmarshal(data []byte) error {
   194  	m.id = buf.Byte2Uint64(data)
   195  	m.method = buf.Byte2Uint32(data[8:])
   196  	if len(data) > 12 {
   197  		err := data[12:]
   198  		m.err = make([]byte, len(err))
   199  		copy(m.err, err)
   200  	}
   201  	return nil
   202  }