github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/common/msgparser/parser_test.go (about) 1 package msgparser 2 3 import ( 4 "encoding/json" 5 "errors" 6 "github.com/nyan233/littlerpc/core/common/jsonrpc2" 7 "github.com/nyan233/littlerpc/core/container" 8 message2 "github.com/nyan233/littlerpc/core/protocol/message" 9 "github.com/nyan233/littlerpc/core/protocol/message/gen" 10 mux2 "github.com/nyan233/littlerpc/core/protocol/message/mux" 11 "github.com/nyan233/littlerpc/core/utils/random" 12 "github.com/stretchr/testify/assert" 13 "github.com/stretchr/testify/suite" 14 "math" 15 "strconv" 16 "strings" 17 "sync" 18 "testing" 19 ) 20 21 type ParserFullTest struct { 22 parser Parser 23 suite.Suite 24 } 25 26 func TestParser(t *testing.T) { 27 suite.Run(t, new(ParserFullTest)) 28 } 29 30 func (f *ParserFullTest) SetupTest() { 31 f.parser = Get(DefaultParser)(NewDefaultSimpleAllocTor(), 4096).(*lRPCTrait) 32 } 33 34 func (f *ParserFullTest) TestLRPCParser() { 35 t := f.T() 36 t.Run("ParseOnData", testParser(f.parser.(*lRPCTrait), func(data []byte) ([]ParserMessage, error) { 37 return f.parser.Parse(data) 38 })) 39 t.Run("ParseOnReader", testParser(f.parser.(*lRPCTrait), func(data []byte) ([]ParserMessage, error) { 40 var read bool 41 return f.parser.ParseOnReader(func(bytes []byte) (n int, err error) { 42 if read { 43 return -1, errors.New("already read") 44 } 45 read = true 46 return copy(bytes, data), nil 47 }) 48 })) 49 } 50 51 func testParser(p *lRPCTrait, parseFunc func(data []byte) ([]ParserMessage, error)) func(t *testing.T) { 52 return func(t *testing.T) { 53 msg := message2.New() 54 msg.SetMsgId(uint64(random.FastRand())) 55 msg.SetServiceName("TestParser/LocalTest") 56 msg.MetaData.Store("Key", "Value") 57 msg.MetaData.Store("Key2", "Value2") 58 msg.MetaData.Store("Key3", "Value3") 59 msg.AppendPayloads([]byte("hello world")) 60 msg.AppendPayloads([]byte("65536")) 61 msg.Length() 62 var marshalBytes []byte 63 err := message2.Marshal(msg, (*container.Slice[byte])(&marshalBytes)) 64 assert.NoError(t, err) 65 muxBlock := mux2.Block{ 66 Flags: mux2.Enabled, 67 StreamId: random.FastRand(), 68 MsgId: uint64(random.FastRand()), 69 } 70 muxBlock.SetPayloads(marshalBytes) 71 var muxMarshalBytes []byte 72 mux2.Marshal(&muxBlock, (*container.Slice[byte])(&muxMarshalBytes)) 73 marshalBytes = append(marshalBytes, muxMarshalBytes...) 74 _, err = parseFunc(marshalBytes[:11]) 75 assert.NoError(t, err) 76 allMasg, err := parseFunc(marshalBytes[11 : msg.Length()+20]) 77 assert.NoError(t, err) 78 assert.Equal(t, len(allMasg), 1) 79 allMasg, err = parseFunc(marshalBytes[msg.Length()+20:]) 80 assert.NoError(t, err) 81 assert.Equal(t, len(allMasg), 1) 82 assert.Equal(t, len(p.halfBuffer), 0) 83 assert.Equal(t, p.startOffset, 0) 84 assert.Equal(t, p.endOffset, 0) 85 assert.Equal(t, p.clickInterval, 1) 86 } 87 } 88 89 func (f *ParserFullTest) TestConcurrentHalfParse() { 90 const ( 91 ConsumerSize = 16 92 ChanBufferSize = 8 93 CycleSize = 1000 94 OnePushMax = 20 95 ) 96 t := f.T() 97 producer := func(channels []chan []byte, data []byte, cycleSize int) { 98 for i := 0; i < cycleSize; i++ { 99 tmpData := data 100 for len(tmpData) > 0 { 101 var readN int 102 if len(tmpData) >= OnePushMax { 103 readN = OnePushMax 104 } else { 105 readN = len(tmpData) 106 } 107 for _, channel := range channels { 108 channel <- tmpData[:readN] 109 } 110 tmpData = tmpData[readN:] 111 } 112 } 113 for _, channel := range channels { 114 close(channel) 115 } 116 } 117 consumer := func(parser Parser, channel chan []byte, checkHeader byte, wg *sync.WaitGroup) { 118 defer wg.Done() 119 for { 120 select { 121 case data, ok := <-channel: 122 if !ok { 123 return 124 } 125 msgs, err := parser.Parse(data) 126 if err != nil { 127 t.Error(err) 128 } 129 if msgs != nil && len(msgs) > 0 { 130 for _, msg := range msgs { 131 assert.Equal(t, checkHeader, msg.Header) 132 parser.Free(msg.Message) 133 } 134 } 135 } 136 } 137 } 138 consumerChannels := make([]chan []byte, ConsumerSize) 139 for k := range consumerChannels { 140 consumerChannels[k] = make(chan []byte, ChanBufferSize) 141 } 142 var wg sync.WaitGroup 143 wg.Add(ConsumerSize) 144 for _, v := range consumerChannels { 145 go consumer(NewLRPCTrait(NewDefaultSimpleAllocTor(), MaxBufferSize), v, message2.MagicNumber, &wg) 146 } 147 go producer(consumerChannels, gen.NoMuxToBytes(gen.Big), CycleSize) 148 wg.Wait() 149 } 150 151 func (f *ParserFullTest) TestJsonRPC2Parser() { 152 t := f.T() 153 request := new(jsonrpc2.Request) 154 request.Version = jsonrpc2.Version 155 request.MessageType = int(message2.Call) 156 request.Method = "Test.JsonRPC2Case1" 157 request.MetaData = map[string]string{ 158 "context-id": strconv.FormatInt(int64(random.FastRand()), 10), 159 "streamId": strconv.FormatInt(int64(random.FastRand()), 10), 160 "codec": "json", 161 "packer": "text", 162 } 163 request.Id = uint64(random.FastRand()) 164 request.Params = []byte("[1203,\"hello world\",3563]") 165 bytes, err := json.Marshal(request) 166 assert.NoError(t, err) 167 parser := f.parser 168 msg, err := parser.Parse(bytes) 169 assert.Nil(t, err, err) 170 assert.Equal(t, len(msg), 1) 171 172 iter := msg[0].Message.PayloadsIterator() 173 assert.Equal(t, iter.Tail(), 3) 174 var i int 175 for iter.Next() { 176 i++ 177 switch i { 178 case 1: 179 assert.Equal(t, string(iter.Take()), "1203") 180 case 2: 181 assert.Equal(t, string(iter.Take()), "\"hello world\"") 182 case 3: 183 assert.Equal(t, string(iter.Take()), "3563") 184 } 185 } 186 assert.Equal(t, msg[0].Message.GetServiceName(), "Test.JsonRPC2Case1") 187 188 // 测试是否能够处理错误的消息类型 189 request.MessageType = 0x889839 190 bytes, err = json.Marshal(request) 191 assert.Nil(t, err, err) 192 msg, err = parser.Parse(bytes) 193 assert.NotNil(t, err, "input error data but marshal able") 194 } 195 196 func TestHandler(t *testing.T) { 197 for i := uint8(0); true; i++ { 198 GetHandler(i) 199 if i == math.MaxUint8 { 200 break 201 } 202 } 203 defer func() { 204 assert.NotNil(t, recover()) 205 }() 206 RegisterHandler(nil) 207 } 208 209 func parserOnBytes(s string) []byte { 210 s = s[1 : len(s)-1] 211 sp := strings.Split(s, " ") 212 bs := make([]byte, 0, len(sp)) 213 for _, ss := range sp { 214 b, _ := strconv.Atoi(ss) 215 bs = append(bs, byte(b)) 216 } 217 return bs 218 }