trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/multiplex/multiplex_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 //go:build linux || freebsd || dragonfly || darwin 15 // +build linux freebsd dragonfly darwin 16 17 package multiplex_test 18 19 import ( 20 "bytes" 21 "context" 22 "encoding/binary" 23 "errors" 24 "io" 25 "net" 26 "sync" 27 "sync/atomic" 28 "testing" 29 "time" 30 31 "github.com/stretchr/testify/require" 32 33 "trpc.group/trpc-go/trpc-go/pool/connpool" 34 "trpc.group/trpc-go/trpc-go/pool/multiplexed" 35 "trpc.group/trpc-go/trpc-go/transport/tnet" 36 "trpc.group/trpc-go/trpc-go/transport/tnet/multiplex" 37 ) 38 39 var ( 40 helloworld = []byte("hello world") 41 reqID uint32 42 ) 43 44 var ( 45 _ (multiplexed.FrameParser) = (*simpleFrameParser)(nil) 46 ) 47 48 /* 49 | 4 byte | 4 byte | bodyLen byte | 50 | bodyLen | id | body | 51 */ 52 type simpleFrameParser struct { 53 isParseFail bool 54 } 55 56 func (fr *simpleFrameParser) Parse(reader io.Reader) (uint32, []byte, error) { 57 head := make([]byte, 8) 58 n, err := io.ReadFull(reader, head) 59 if err != nil { 60 return 0, nil, err 61 } 62 63 if fr.isParseFail { 64 return 0, nil, errors.New("decode fail") 65 } 66 67 if n != 8 { 68 return 0, nil, errors.New("invalid read full num") 69 } 70 71 bodyLen := binary.BigEndian.Uint32(head[:4]) 72 id := binary.BigEndian.Uint32(head[4:8]) 73 body := make([]byte, int(bodyLen)) 74 75 n, err = io.ReadFull(reader, body) 76 if err != nil { 77 return 0, nil, err 78 } 79 80 if n != int(bodyLen) { 81 return 0, nil, errors.New("invalid read full body") 82 } 83 84 return id, body, nil 85 } 86 87 func encodeFrame(id uint32, body []byte) []byte { 88 bodyLen := len(body) 89 buf := bytes.NewBuffer(make([]byte, 0, 8+bodyLen)) 90 if err := binary.Write(buf, binary.BigEndian, uint32(bodyLen)); err != nil { 91 panic(err) 92 } 93 if err := binary.Write(buf, binary.BigEndian, uint32(id)); err != nil { 94 panic(err) 95 } 96 97 if _, err := buf.Write(body); err != nil { 98 panic(err) 99 } 100 101 return buf.Bytes() 102 } 103 104 func getReqID() uint32 { 105 return atomic.AddUint32(&reqID, 1) 106 } 107 108 func echo(c net.Conn) { 109 io.Copy(c, c) 110 } 111 112 func beginServer(t *testing.T, handle func(net.Conn)) (net.Addr, context.CancelFunc) { 113 ctx, cancel := context.WithCancel(context.Background()) 114 addrCh := make(chan net.Addr, 1) 115 go func() { 116 l, err := net.Listen("tcp", "127.0.0.1:0") 117 require.Nil(t, err) 118 addrCh <- l.Addr() 119 go func() { 120 for { 121 conn, err := l.Accept() 122 if err != nil { 123 require.NotNil(t, ctx.Err()) 124 return 125 } 126 go handle(conn) 127 } 128 }() 129 <-ctx.Done() 130 l.Close() 131 }() 132 addr := <-addrCh 133 return addr, cancel 134 } 135 136 func TestBasic(t *testing.T) { 137 addr, cancel := beginServer(t, echo) 138 defer cancel() 139 140 getOpts := func() (uint32, multiplexed.GetOptions) { 141 id := getReqID() 142 opts := multiplexed.NewGetOptions() 143 opts.WithFrameParser(&simpleFrameParser{}) 144 opts.WithVID(id) 145 return id, opts 146 } 147 148 t.Run("Multiple Conns Concurrent Read Write", func(t *testing.T) { 149 pool := multiplex.NewPool( 150 tnet.Dial, 151 multiplex.WithEnableMetrics(), 152 multiplex.WithMaxConcurrentVirConnsPerConn(500), 153 ) 154 var wg sync.WaitGroup 155 for i := 0; i < 100; i++ { 156 wg.Add(1) 157 go func() { 158 defer wg.Done() 159 for i := 0; i < 100; i++ { 160 id, opts := getOpts() 161 conn, err := pool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 162 require.Nil(t, err) 163 164 err = conn.Write(encodeFrame(id, helloworld)) 165 require.Nil(t, err) 166 b, err := conn.Read() 167 require.Nil(t, err) 168 require.Equal(t, helloworld, b) 169 conn.Close() 170 } 171 }() 172 } 173 wg.Wait() 174 }) 175 } 176 177 func TestGetConnection(t *testing.T) { 178 addr, cancel := beginServer(t, echo) 179 defer cancel() 180 muxPool := multiplex.NewPool(tnet.Dial) 181 182 getOpts := func() multiplexed.GetOptions { 183 opts := multiplexed.NewGetOptions() 184 opts.WithFrameParser(&simpleFrameParser{}) 185 opts.WithVID(getReqID()) 186 return opts 187 } 188 189 t.Run("Get Once", func(t *testing.T) { 190 opts := getOpts() 191 conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 192 require.Nil(t, err) 193 conn.Close() 194 }) 195 t.Run("Get Multiple Succeed", func(t *testing.T) { 196 opts := getOpts() 197 conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 198 require.Nil(t, err) 199 conn.Close() 200 localAddr := conn.LocalAddr() 201 for i := 0; i < 9; i++ { 202 opts := getOpts() 203 conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 204 require.Nil(t, err) 205 require.Equal(t, localAddr, conn.LocalAddr()) 206 conn.Close() 207 } 208 }) 209 t.Run("Exceed MaxConcurrentVirConns", func(t *testing.T) { 210 muxPool := multiplex.NewPool(tnet.Dial, multiplex.WithMaxConcurrentVirConnsPerConn(1)) 211 212 opts := getOpts() 213 c1, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 214 require.Nil(t, err) 215 defer c1.Close() 216 217 opts = getOpts() 218 c2, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 219 require.Nil(t, err) 220 require.NotEqual(t, c1.LocalAddr(), c2.LocalAddr()) 221 defer c2.Close() 222 }) 223 t.Run("Request ID Already Exist", func(t *testing.T) { 224 opts := getOpts() 225 c1, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 226 require.Nil(t, err) 227 defer c1.Close() 228 229 _, err = muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 230 require.Equal(t, multiplex.ErrDuplicateID, err) 231 }) 232 t.Run("Empty FrameParser", func(t *testing.T) { 233 opts := getOpts() 234 opts.WithFrameParser(nil) 235 _, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 236 require.Contains(t, "frame parser is not provided", err.Error()) 237 }) 238 } 239 240 func TestDial(t *testing.T) { 241 addr, cancel := beginServer(t, echo) 242 defer cancel() 243 244 getOpts := func() (context.Context, context.CancelFunc, multiplexed.GetOptions) { 245 ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(200*time.Millisecond)) 246 opts := multiplexed.NewGetOptions() 247 opts.WithFrameParser(&simpleFrameParser{}) 248 opts.WithVID(getReqID()) 249 return ctx, cancel, opts 250 } 251 252 t.Run("Dial Succeed", func(t *testing.T) { 253 muxPool := multiplex.NewPool(tnet.Dial) 254 ctx, cancel, opts := getOpts() 255 defer cancel() 256 conn, err := muxPool.GetMuxConn(ctx, addr.Network(), addr.String(), opts) 257 require.Nil(t, err) 258 conn.Close() 259 }) 260 t.Run("Dial Timeout", func(t *testing.T) { 261 muxPool := multiplex.NewPool(func(opts *connpool.DialOptions) (net.Conn, error) { 262 time.Sleep(time.Second) 263 return nil, errors.New("dial fail") 264 }) 265 ctx, cancel, opts := getOpts() 266 defer cancel() 267 _, err := muxPool.GetMuxConn(ctx, addr.Network(), addr.String(), opts) 268 require.Equal(t, context.DeadlineExceeded, err) 269 }) 270 t.Run("Dial Error", func(t *testing.T) { 271 muxPool := multiplex.NewPool(func(opts *connpool.DialOptions) (net.Conn, error) { 272 return nil, errors.New("dial error") 273 }) 274 ctx, cancel, opts := getOpts() 275 defer cancel() 276 _, err := muxPool.GetMuxConn(ctx, addr.Network(), addr.String(), opts) 277 require.Equal(t, errors.New("dial error"), err) 278 }) 279 t.Run("Dial Gonet", func(t *testing.T) { 280 muxPool := multiplex.NewPool(func(opts *connpool.DialOptions) (net.Conn, error) { 281 return net.Dial(opts.Network, opts.Address) 282 }) 283 ctx, cancel, opts := getOpts() 284 defer cancel() 285 _, err := muxPool.GetMuxConn(ctx, addr.Network(), addr.String(), opts) 286 require.Contains(t, "dialed connection must implements tnet.Conn", err.Error()) 287 }) 288 } 289 290 func TestClose(t *testing.T) { 291 muxPool := multiplex.NewPool(tnet.Dial) 292 getOpts := func() (uint32, multiplexed.GetOptions) { 293 id := getReqID() 294 opts := multiplexed.NewGetOptions() 295 opts.WithFrameParser(&simpleFrameParser{}) 296 opts.WithVID(id) 297 return id, opts 298 } 299 300 t.Run("Server Close Conn After Accept", func(t *testing.T) { 301 addr, cancel := beginServer(t, func(c net.Conn) { 302 c.Close() 303 }) 304 defer cancel() 305 var wg sync.WaitGroup 306 for i := 0; i < 1000; i++ { 307 wg.Add(1) 308 _, opts := getOpts() 309 go func() { 310 defer wg.Done() 311 conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 312 if err != nil { 313 return 314 } 315 _, err = conn.Read() 316 require.Contains(t, err.Error(), multiplex.ErrConnClosed.Error()) 317 err = conn.Write(nil) 318 require.Contains(t, err.Error(), multiplex.ErrConnClosed.Error()) 319 conn.Close() 320 }() 321 } 322 wg.Wait() 323 }) 324 325 t.Run("Decode Fail", func(t *testing.T) { 326 addr, cancel := beginServer(t, echo) 327 defer cancel() 328 // return error when decode fail. 329 for i := 0; i < 5; i++ { 330 id, opts := getOpts() 331 opts.WithFrameParser(&simpleFrameParser{isParseFail: true}) 332 conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 333 require.Nil(t, err) 334 335 err = conn.Write(encodeFrame(id, helloworld)) 336 require.Nil(t, err) 337 _, err = conn.Read() 338 require.Contains(t, err.Error(), "decode fail") 339 conn.Close() 340 } 341 // return nil when decode succeed. 342 for i := 0; i < 5; i++ { 343 id, opts := getOpts() 344 opts.WithFrameParser(&simpleFrameParser{}) 345 conn, err := muxPool.GetMuxConn(context.Background(), addr.Network(), addr.String(), opts) 346 require.Nil(t, err) 347 348 err = conn.Write(encodeFrame(id, helloworld)) 349 require.Nil(t, err) 350 _, err = conn.Read() 351 require.Nil(t, err) 352 conn.Close() 353 } 354 }) 355 }