trpc.group/trpc-go/trpc-go@v1.0.3/transport/client_transport_stream_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package transport_test 15 16 import ( 17 "context" 18 "encoding/binary" 19 "encoding/json" 20 "fmt" 21 "io" 22 "sync" 23 "sync/atomic" 24 "testing" 25 "time" 26 27 "trpc.group/trpc-go/trpc-go/codec" 28 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 32 "trpc.group/trpc-go/trpc-go/transport" 33 34 _ "trpc.group/trpc-go/trpc-go" 35 ) 36 37 // TestClientStreamNetworkError test client decode error. 38 func TestClientStreamNetworkError(t *testing.T) { 39 st := transport.NewServerStreamTransport() 40 svrCtx, close := context.WithTimeout(context.Background(), 100*time.Millisecond) 41 defer close() 42 go func() { 43 err := st.ListenAndServe(svrCtx, 44 transport.WithListenNetwork("tcp"), 45 transport.WithListenAddress(":12017"), 46 transport.WithHandler(&echoHandler{}), 47 transport.WithServerFramerBuilder(&multiplexedFramerBuilder{}), 48 ) 49 require.Nil(t, err) 50 }() 51 52 roundTripOpts := []transport.RoundTripOption{ 53 transport.WithDialNetwork("tcp"), 54 transport.WithDialAddress(":12017"), 55 } 56 57 time.Sleep(20 * time.Millisecond) 58 req := &helloRequest{ 59 Name: "trpc", 60 Msg: "HelloWorld", 61 } 62 63 data, err := json.Marshal(req) 64 require.Nil(t, err) 65 66 lenData := make([]byte, 4) 67 binary.BigEndian.PutUint32(lenData, uint32(len(data))) 68 69 headData := make([]byte, 8) 70 binary.BigEndian.PutUint32(headData[:4], defaultStreamID) 71 binary.BigEndian.PutUint32(headData[4:8], uint32(len(data))) 72 73 ctx := 
context.Background() 74 ctx, msg := codec.WithNewMessage(ctx) 75 msg.WithStreamID(100) 76 77 // test IO EOF. 78 ct := transport.NewClientStreamTransport() 79 fb := &multiplexedFramerBuilder{} 80 fb.SetError(io.EOF) 81 roundTripOpts = append(roundTripOpts, transport.WithClientFramerBuilder(fb), transport.WithMsg(msg)) 82 err = ct.Init(ctx, roundTripOpts...) 83 assert.Nil(t, err) 84 rsp, err := ct.Recv(ctx) 85 assert.Equal(t, io.EOF, err, fmt.Sprintf("current err: %+v", err)) 86 assert.Nil(t, rsp) 87 88 // test ctx canceled. 89 msg.WithStreamID(101) 90 fb = &multiplexedFramerBuilder{} 91 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) 92 cancel() 93 roundTripOpts = append(roundTripOpts, transport.WithClientFramerBuilder(fb), transport.WithMsg(msg)) 94 err = ct.Init(ctx, roundTripOpts...) 95 assert.NotNil(t, err) 96 97 // test ctx timeout. 98 msg.WithStreamID(102) 99 fb = &multiplexedFramerBuilder{} 100 ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond) 101 defer cancel() 102 <-ctx.Done() 103 roundTripOpts = append(roundTripOpts, transport.WithClientFramerBuilder(fb), transport.WithMsg(msg)) 104 err = ct.Init(ctx, roundTripOpts...) 
105 assert.NotNil(t, err) 106 } 107 108 func TestConcurrent(t *testing.T) { 109 st := transport.NewServerStreamTransport() 110 serverFinish := make(chan int) 111 go func() { 112 err := st.ListenAndServe(context.Background(), 113 transport.WithListenNetwork("tcp,udp"), 114 transport.WithListenAddress(":12015"), 115 transport.WithHandler(&echoHandler{}), 116 transport.WithServerFramerBuilder(&multiplexedFramerBuilder{}), 117 ) 118 require.Nil(t, err) 119 serverFinish <- 1 120 }() 121 <-serverFinish 122 123 req := &helloRequest{ 124 Name: "trpc", 125 Msg: "HelloWorld", 126 } 127 data, err := json.Marshal(req) 128 require.Nil(t, err) 129 headData := make([]byte, 8) // head = streamID + data length 130 binary.BigEndian.PutUint32(headData[4:8], uint32(len(data))) 131 reqData := append(headData, data...) 132 133 ct := transport.NewClientStreamTransport(transport.WithMaxConcurrentStreams(20), transport.WithMaxIdleConnsPerHost(2)) 134 135 // close stream send and receive. 136 var wg sync.WaitGroup 137 var index uint32 138 for i := 0; i < 200; i++ { 139 wg.Add(1) 140 go func() { 141 ctx := context.Background() 142 ctx, msg := codec.WithNewMessage(ctx) 143 newIndex := atomic.AddUint32(&index, 1) 144 streamID := defaultStreamID + newIndex 145 msg.WithStreamID(streamID) 146 msg.WithRequestID(streamID) 147 148 err = ct.Init(ctx, transport.WithDialNetwork("tcp"), transport.WithDialAddress(":12015"), 149 transport.WithClientFramerBuilder(&multiplexedFramerBuilder{}), 150 transport.WithMsg(msg)) 151 assert.Nil(t, err) 152 153 copyData := make([]byte, len(reqData)) 154 copy(copyData, reqData) 155 binary.BigEndian.PutUint32(copyData, streamID) 156 157 err = ct.Send(ctx, copyData) 158 assert.Nil(t, err) 159 rspData, err := ct.Recv(ctx) 160 assert.Nil(t, err) 161 assert.Equal(t, copyData, rspData) 162 ct.Close(ctx) 163 wg.Done() 164 }() 165 if i%50 == 0 { 166 time.Sleep(50 * time.Millisecond) 167 } 168 } 169 wg.Wait() 170 } 171 172 /* --------------------------------------------------- 
mock multiplexed framer ---------------------------------------------- */

// multiplexedFramerBuilder builds framers for the test wire format
// [4-byte big-endian stream ID][4-byte big-endian payload length][payload].
// After SetError, every framer built from it (including ones built earlier)
// fails ReadFrame with that error, which lets tests inject read failures.
type multiplexedFramerBuilder struct {
	errSet bool  // whether ReadFrame should fail without reading
	err    error // error handed back by ReadFrame while errSet is true
	safe   bool  // value reported by multiplexedFramer.IsSafe
}

// SetError arms the builder so that ReadFrame on its framers returns err.
func (fb *multiplexedFramerBuilder) SetError(err error) {
	fb.errSet, fb.err = true, err
}

// ClearError disarms a previous SetError.
func (fb *multiplexedFramerBuilder) ClearError() {
	fb.errSet, fb.err = false, nil
}

// New returns a framer that reads frames from r and consults fb for any
// injected error.
func (fb *multiplexedFramerBuilder) New(r io.Reader) codec.Framer {
	framer := &multiplexedFramer{fb: fb, r: r}
	return framer
}

// Parse reads one frame from rc and returns the stream ID found in its first
// four bytes together with the whole frame, header included.
func (fb *multiplexedFramerBuilder) Parse(rc io.Reader) (vid uint32, buf []byte, err error) {
	frame, readErr := fb.New(rc).ReadFrame()
	if readErr != nil {
		return 0, nil, readErr
	}
	return binary.BigEndian.Uint32(frame), frame, nil
}

// multiplexedFramer reads frames in multiplexedFramerBuilder's format from r.
type multiplexedFramer struct {
	fb *multiplexedFramerBuilder
	r  io.Reader
}

// ReadFrame returns the next complete frame, header included. If the builder
// has an error set, that error is returned without reading from r.
func (f *multiplexedFramer) ReadFrame() ([]byte, error) {
	if fb := f.fb; fb.errSet {
		return nil, fb.err
	}

	const headLen = 8 // stream ID (4 bytes) + payload length (4 bytes)
	var head [headLen]byte
	if _, err := io.ReadFull(f.r, head[:]); err != nil {
		return nil, err
	}

	bodyLen := int(binary.BigEndian.Uint32(head[4:]))
	frame := make([]byte, headLen+bodyLen)
	n := copy(frame, head[:])
	if _, err := io.ReadFull(f.r, frame[n:]); err != nil {
		return nil, err
	}
	return frame, nil
}

// IsSafe reports the safe flag of the builder that created f.
func (f *multiplexedFramer) IsSafe() bool { return f.fb.safe }