trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/client_transport_tcp_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 //go:build linux || freebsd || dragonfly || darwin 15 // +build linux freebsd dragonfly darwin 16 17 package tnet_test 18 19 import ( 20 "context" 21 "net" 22 "os" 23 "testing" 24 "time" 25 26 "github.com/stretchr/testify/assert" 27 "trpc.group/trpc-go/tnet" 28 29 trpc "trpc.group/trpc-go/trpc-go" 30 "trpc.group/trpc-go/trpc-go/codec" 31 "trpc.group/trpc-go/trpc-go/errs" 32 "trpc.group/trpc-go/trpc-go/pool/connpool" 33 "trpc.group/trpc-go/trpc-go/transport" 34 tnettrans "trpc.group/trpc-go/trpc-go/transport/tnet" 35 ) 36 37 func TestClientTCP(t *testing.T) { 38 startClientTest( 39 t, 40 defaultServerHandle, 41 nil, 42 func(addr string) { 43 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 44 defer cancel() 45 rsp, err := tnetRequest(ctx, helloWorld, 46 transport.WithDialAddress(addr), 47 transport.WithDialTimeout(500*time.Millisecond), 48 ) 49 assert.Equal(t, helloWorld, rsp) 50 assert.Nil(t, err) 51 }, 52 ) 53 } 54 55 func TestClientTCP_NoFrameBuilder(t *testing.T) { 56 startClientTest( 57 t, 58 defaultServerHandle, 59 nil, 60 func(addr string) { 61 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 62 defer cancel() 63 _, err := tnetRequest(ctx, helloWorld, 64 transport.WithDialAddress(addr), 65 transport.WithClientFramerBuilder(nil), 66 ) 67 assert.Equal(t, errs.RetClientConnectFail, errs.Code(err)) 68 }, 69 ) 70 } 71 72 func TestClientTCP_CtxErr(t *testing.T) { 73 startClientTest( 74 t, 75 defaultServerHandle, 76 nil, 77 func(addr string) { 78 // canceled context error 79 ctx, cancel := context.WithCancel(context.Background()) 80 cancel() 81 _, err := tnetRequest(ctx, helloWorld, 82 transport.WithDialAddress(addr), 83 ) 84 assert.Equal(t, errs.RetClientCanceled, errs.Code(err)) 85 86 // timeout context error 87 ctx, cancel = context.WithDeadline(context.Background(), time.Now().Add(time.Nanosecond)) 88 defer cancel() 89 time.Sleep(time.Nanosecond) 90 _, err = tnetRequest(ctx, helloWorld, 91 transport.WithDialAddress(addr), 92 ) 93 assert.Equal(t, errs.RetClientTimeout, errs.Code(err)) 94 }, 95 ) 96 } 97 98 func TestClientTCP_DisableConnPool(t *testing.T) { 99 // success case 100 startClientTest( 101 t, 102 defaultServerHandle, 103 nil, 104 func(addr string) { 105 rsp, err := tnetRequest( 106 context.Background(), 107 helloWorld, 108 transport.WithDialAddress(addr), 109 transport.WithDisableConnectionPool(), 110 ) 111 assert.Nil(t, err) 112 assert.Equal(t, helloWorld, rsp) 113 }, 114 ) 115 // dial wrong address 116 _, err := tnetRequest( 117 context.Background(), 118 helloWorld, 119 transport.WithDialAddress("0"), 120 transport.WithDisableConnectionPool(), 121 ) 122 assert.Equal(t, errs.RetClientConnectFail, errs.Code(err)) 123 } 124 125 func TestClientTCP_ReadTimeout(t *testing.T) { 126 startClientTest( 127 t, 128 func(ctx context.Context, req []byte) ([]byte, error) { 129 time.Sleep(time.Hour) 130 return nil, nil 131 }, 132 nil, 133 func(addr string) { 134 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 135 defer cancel() 136 _, err := tnetRequest( 137 ctx, 138 helloWorld, 139 transport.WithDialAddress(addr), 140 ) 141 assert.Equal(t, errs.RetClientTimeout, errs.Code(err)) 142 }, 143 ) 144 } 145 146 func TestClientTCP_CustomPool(t *testing.T) { 147 startClientTest( 148 t, 149 defaultServerHandle, 150 nil, 151 func(addr string) { 152 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 153 defer cancel() 154 rsp, err := tnetRequest( 155 ctx, 156 helloWorld, 157 transport.WithDialAddress(addr), 158 transport.WithDialPool(&customPool{}), 159 ) 160 assert.Equal(t, helloWorld, rsp) 161 assert.Nil(t, err) 162 }, 163 ) 164 } 165 166 func TestClientUDP(t *testing.T) { 167 // UDP is not supported, but it will switch to gonet default transport to roundtrip. 168 startClientTest( 169 t, 170 defaultServerHandle, 171 []transport.ListenServeOption{transport.WithListenNetwork("udp")}, 172 func(addr string) { 173 rsp, err := tnetRequest( 174 context.Background(), 175 helloWorld, 176 transport.WithDialAddress(addr), 177 transport.WithDialNetwork("udp")) 178 assert.Nil(t, err) 179 assert.Equal(t, helloWorld, rsp) 180 }, 181 ) 182 } 183 184 func TestClientUnix(t *testing.T) { 185 // Unix socket is not supported, but it will switch to gonet default transport to roundtrip. 186 unixAddr := "/tmp/server.sock" 187 os.Remove(unixAddr) 188 startClientTest( 189 t, 190 defaultServerHandle, 191 []transport.ListenServeOption{ 192 transport.WithListenAddress(unixAddr), 193 transport.WithListenNetwork("unix"), 194 }, 195 func(addr string) { 196 rsp, err := tnetRequest( 197 context.Background(), 198 helloWorld, 199 transport.WithDialAddress(unixAddr), 200 transport.WithDialNetwork("unix")) 201 assert.Nil(t, err) 202 assert.Equal(t, helloWorld, rsp) 203 }, 204 ) 205 206 } 207 208 func TestClientTCP_Multiplex(t *testing.T) { 209 startClientTest( 210 t, 211 defaultServerHandle, 212 nil, 213 func(addr string) { 214 req := helloWorld 215 ctx, msg := codec.EnsureMessage(context.Background()) 216 reqFrame, err := trpc.DefaultClientCodec.Encode(codec.Message(ctx), req) 217 assert.Nil(t, err) 218 219 cliOpts := getRoundTripOption( 220 transport.WithDialAddress(addr), 221 transport.WithMultiplexed(true), 222 transport.WithMsg(msg), 223 ) 224 clientTrans := tnettrans.NewClientTransport() 225 rspFrame, err := clientTrans.RoundTrip(ctx, reqFrame, cliOpts...) 226 assert.Nil(t, err) 227 rsp, err := trpc.DefaultClientCodec.Decode(msg, rspFrame) 228 assert.Nil(t, err) 229 assert.Equal(t, helloWorld, rsp) 230 }, 231 ) 232 } 233 234 func TestClientTCP_TLS(t *testing.T) { 235 startClientTest( 236 t, 237 defaultServerHandle, 238 []transport.ListenServeOption{transport.WithServeTLS("../../testdata/server.crt", "../../testdata/server.key", "../../testdata/ca.pem")}, 239 func(addr string) { 240 rsp, err := tnetRequest( 241 context.Background(), 242 helloWorld, 243 transport.WithDialAddress(addr), 244 transport.WithDialTLS("../../testdata/client.crt", "../../testdata/client.key", "../../testdata/ca.pem", "localhost"), 245 ) 246 assert.Nil(t, err) 247 assert.Equal(t, helloWorld, rsp) 248 249 rsp, err = tnetRequest( 250 context.Background(), 251 helloWorld, 252 transport.WithDialAddress(addr), 253 transport.WithDialTLS("../../testdata/client.crt", "../../testdata/client.key", "none", ""), 254 ) 255 assert.Nil(t, err) 256 assert.Equal(t, helloWorld, rsp) 257 }, 258 ) 259 } 260 261 func TestClientTCP_HealthCheck(t *testing.T) { 262 addr := getAddr() 263 s := transport.NewServerTransport() 264 serveOpts := getListenServeOption(transport.WithListenAddress(addr)) 265 err := s.ListenAndServe(context.Background(), serveOpts...) 266 assert.Nil(t, err) 267 268 c, err := net.Dial("tcp", addr) 269 assert.Nil(t, err) 270 assert.True(t, tnettrans.HealthChecker(&connpool.PoolConn{Conn: c}, true)) 271 272 c, err = tnet.DialTCP("tcp", addr, 0) 273 assert.Nil(t, err) 274 assert.True(t, tnettrans.HealthChecker(&connpool.PoolConn{Conn: c}, true)) 275 276 c.Close() 277 assert.False(t, tnettrans.HealthChecker(&connpool.PoolConn{Conn: c}, true)) 278 } 279 280 func TestNewConnectionPool(t *testing.T) { 281 p := tnettrans.NewConnectionPool() 282 assert.NotNil(t, p) 283 } 284 285 func startClientTest( 286 t *testing.T, 287 serverHandle func(ctx context.Context, req []byte) ([]byte, error), 288 svrCustomOpts []transport.ListenServeOption, 289 clientHandle func(addr string), 290 ) { 291 addr := getAddr() 292 s := transport.NewServerTransport() 293 handler := newUserDefineHandler(func(ctx context.Context, req []byte) ([]byte, error) { 294 return serverHandle(ctx, req) 295 }) 296 serveOpts := getListenServeOption( 297 transport.WithListenAddress(addr), 298 transport.WithHandler(handler), 299 ) 300 serveOpts = append(serveOpts, svrCustomOpts...) 301 err := s.ListenAndServe(context.Background(), serveOpts...) 302 assert.Nil(t, err) 303 304 clientHandle(addr) 305 } 306 307 type customPool struct{} 308 309 type customConn struct { 310 tnet.Conn 311 framer codec.Framer 312 } 313 314 func (c *customConn) ReadFrame() ([]byte, error) { 315 return c.framer.ReadFrame() 316 } 317 318 func (p *customPool) Get(network string, address string, opts connpool.GetOptions) (net.Conn, error) { 319 c, err := tnet.DialTCP(network, address, opts.DialTimeout) 320 if err != nil { 321 return nil, err 322 } 323 return &customConn{Conn: c, framer: opts.FramerBuilder.New(c)}, nil 324 } 325 326 func tnetRequest(ctx context.Context, req []byte, opts ...transport.RoundTripOption) ([]byte, error) { 327 ctx, _ = codec.EnsureMessage(ctx) 328 reqbytes, err := trpc.DefaultClientCodec.Encode( 329 codec.Message(ctx), 330 req, 331 ) 332 if err != nil { 333 return nil, err 334 } 335 336 cliOpts := getRoundTripOption(opts...) 337 clientTrans := tnettrans.NewClientTransport() 338 rspbytes, err := clientTrans.RoundTrip( 339 ctx, 340 reqbytes, 341 cliOpts..., 342 ) 343 if err != nil { 344 return nil, err 345 } 346 rsp, err := trpc.DefaultClientCodec.Decode( 347 codec.Message(ctx), 348 rspbytes, 349 ) 350 return rsp, err 351 }