go.uber.org/yarpc@v1.72.1/internal/testutils/testutils.go (about) 1 // Copyright (c) 2022 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package testutils 22 23 import ( 24 "fmt" 25 "net" 26 "strconv" 27 28 "go.uber.org/multierr" 29 "go.uber.org/yarpc" 30 "go.uber.org/yarpc/api/transport" 31 "go.uber.org/yarpc/encoding/protobuf" 32 "go.uber.org/yarpc/internal/grpcctx" 33 "go.uber.org/yarpc/transport/grpc" 34 "go.uber.org/yarpc/transport/http" 35 "go.uber.org/yarpc/transport/tchannel" 36 "go.uber.org/zap" 37 ggrpc "google.golang.org/grpc" 38 ) 39 40 const ( 41 // TransportTypeHTTP represents using HTTP. 42 TransportTypeHTTP TransportType = iota 43 // TransportTypeTChannel represents using TChannel. 44 TransportTypeTChannel 45 // TransportTypeGRPC represents using GRPC. 46 TransportTypeGRPC 47 ) 48 49 var ( 50 // AllTransportTypes are all TransportTypes, 51 AllTransportTypes = []TransportType{ 52 TransportTypeHTTP, 53 TransportTypeTChannel, 54 TransportTypeGRPC, 55 } 56 ) 57 58 // TransportType is a transport type. 59 type TransportType int 60 61 // String returns a string representation of t. 62 func (t TransportType) String() string { 63 switch t { 64 case TransportTypeHTTP: 65 return "http" 66 case TransportTypeTChannel: 67 return "tchannel" 68 case TransportTypeGRPC: 69 return "grpc" 70 default: 71 return strconv.Itoa(int(t)) 72 } 73 } 74 75 // ParseTransportType parses a transport type from a string. 76 func ParseTransportType(s string) (TransportType, error) { 77 switch s { 78 case "http": 79 return TransportTypeHTTP, nil 80 case "tchannel": 81 return TransportTypeTChannel, nil 82 case "grpc": 83 return TransportTypeGRPC, nil 84 default: 85 return 0, fmt.Errorf("invalid TransportType: %s", s) 86 } 87 } 88 89 // ClientInfo holds the client info for testing. 90 type ClientInfo struct { 91 ClientConfig transport.ClientConfig 92 GRPCClientConn *ggrpc.ClientConn 93 ContextWrapper *grpcctx.ContextWrapper 94 } 95 96 // WithClientInfo wraps a function by setting up a client and server dispatcher and giving 97 // the function the client configuration to use in tests for the given TransportType. 98 // 99 // The server dispatcher will be brought up using all TransportTypes and with the serviceName. 100 // The client dispatcher will be brought up using the given TransportType for Unary, HTTP for 101 // Oneway, and the serviceName with a "-client" suffix. 102 func WithClientInfo(serviceName string, procedures []transport.Procedure, transportType TransportType, logger *zap.Logger, f func(*ClientInfo) error) (err error) { 103 if logger == nil { 104 logger = zap.NewNop() 105 } 106 dispatcherConfig, err := NewDispatcherConfig(serviceName) 107 if err != nil { 108 return err 109 } 110 serverDispatcher, err := NewServerDispatcher(procedures, dispatcherConfig, logger) 111 if err != nil { 112 return err 113 } 114 115 clientDispatcher, err := NewClientDispatcher(transportType, dispatcherConfig, logger) 116 if err != nil { 117 return err 118 } 119 120 if err := serverDispatcher.Start(); err != nil { 121 return err 122 } 123 defer func() { err = multierr.Append(err, serverDispatcher.Stop()) }() 124 125 if err := clientDispatcher.Start(); err != nil { 126 return err 127 } 128 defer func() { err = multierr.Append(err, clientDispatcher.Stop()) }() 129 grpcPort, err := dispatcherConfig.GetPort(TransportTypeGRPC) 130 if err != nil { 131 return err 132 } 133 grpcClientConn, err := ggrpc.Dial(fmt.Sprintf("127.0.0.1:%d", grpcPort), ggrpc.WithInsecure()) 134 if err != nil { 135 return err 136 } 137 return f( 138 &ClientInfo{ 139 clientDispatcher.ClientConfig(serviceName), 140 grpcClientConn, 141 grpcctx.NewContextWrapper(). 142 WithCaller(serviceName + "-client"). 143 WithService(serviceName). 144 WithEncoding(string(protobuf.Encoding)), 145 }, 146 ) 147 } 148 149 // NewClientDispatcher returns a new client Dispatcher. 150 // 151 // HTTP always will be configured as an outbound for Oneway. 152 // gRPC always will be configured as an outbound for Stream. 153 func NewClientDispatcher(transportType TransportType, config *DispatcherConfig, logger *zap.Logger) (*yarpc.Dispatcher, error) { 154 port, err := config.GetPort(transportType) 155 if err != nil { 156 return nil, err 157 } 158 httpPort, err := config.GetPort(TransportTypeHTTP) 159 if err != nil { 160 return nil, err 161 } 162 grpcPort, err := config.GetPort(TransportTypeGRPC) 163 if err != nil { 164 return nil, err 165 } 166 onewayOutbound := http.NewTransport(http.Logger(logger)).NewSingleOutbound(fmt.Sprintf("http://127.0.0.1:%d", httpPort)) 167 streamOutbound := grpc.NewTransport(grpc.Logger(logger)).NewSingleOutbound(fmt.Sprintf("127.0.0.1:%d", grpcPort)) 168 var unaryOutbound transport.UnaryOutbound 169 switch transportType { 170 case TransportTypeTChannel: 171 tchannelTransport, err := tchannel.NewChannelTransport(tchannel.ServiceName(config.GetServiceName()), tchannel.Logger(logger)) 172 if err != nil { 173 return nil, err 174 } 175 unaryOutbound = tchannelTransport.NewSingleOutbound(fmt.Sprintf("127.0.0.1:%d", port)) 176 case TransportTypeHTTP: 177 unaryOutbound = onewayOutbound 178 case TransportTypeGRPC: 179 unaryOutbound = streamOutbound 180 default: 181 return nil, fmt.Errorf("invalid TransportType: %v", transportType) 182 } 183 return yarpc.NewDispatcher( 184 yarpc.Config{ 185 Name: fmt.Sprintf("%s-client", config.GetServiceName()), 186 Outbounds: yarpc.Outbounds{ 187 config.GetServiceName(): { 188 Oneway: onewayOutbound, 189 Unary: unaryOutbound, 190 Stream: streamOutbound, 191 }, 192 }, 193 }, 194 ), nil 195 } 196 197 // NewServerDispatcher returns a new server Dispatcher. 198 func NewServerDispatcher(procedures []transport.Procedure, config *DispatcherConfig, logger *zap.Logger) (*yarpc.Dispatcher, error) { 199 tchannelPort, err := config.GetPort(TransportTypeTChannel) 200 if err != nil { 201 return nil, err 202 } 203 httpPort, err := config.GetPort(TransportTypeHTTP) 204 if err != nil { 205 return nil, err 206 } 207 grpcPort, err := config.GetPort(TransportTypeGRPC) 208 if err != nil { 209 return nil, err 210 } 211 tchannelTransport, err := tchannel.NewChannelTransport( 212 tchannel.ServiceName(config.GetServiceName()), 213 tchannel.ListenAddr(fmt.Sprintf("127.0.0.1:%d", tchannelPort)), 214 tchannel.Logger(logger), 215 ) 216 if err != nil { 217 return nil, err 218 } 219 grpcListener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", grpcPort)) 220 if err != nil { 221 return nil, err 222 } 223 dispatcher := yarpc.NewDispatcher( 224 yarpc.Config{ 225 Name: config.GetServiceName(), 226 Inbounds: yarpc.Inbounds{ 227 tchannelTransport.NewInbound(), 228 http.NewTransport(http.Logger(logger)).NewInbound(fmt.Sprintf("127.0.0.1:%d", httpPort)), 229 grpc.NewTransport(grpc.Logger(logger)).NewInbound(grpcListener), 230 }, 231 }, 232 ) 233 dispatcher.Register(procedures) 234 return dispatcher, nil 235 } 236 237 // DispatcherConfig is the configuration for a Dispatcher. 238 type DispatcherConfig struct { 239 serviceName string 240 transportTypeToPort map[TransportType]uint16 241 } 242 243 // NewDispatcherConfig returns a new DispatcherConfig with assigned ports. 244 func NewDispatcherConfig(serviceName string) (*DispatcherConfig, error) { 245 transportTypeToPort, err := getTransportTypeToPort() 246 if err != nil { 247 return nil, err 248 } 249 return &DispatcherConfig{ 250 serviceName, 251 transportTypeToPort, 252 }, nil 253 } 254 255 // GetServiceName gets the service name. 256 func (d *DispatcherConfig) GetServiceName() string { 257 return d.serviceName 258 } 259 260 // GetPort gets the port for the TransportType. 261 func (d *DispatcherConfig) GetPort(transportType TransportType) (uint16, error) { 262 port, ok := d.transportTypeToPort[transportType] 263 if !ok { 264 return 0, fmt.Errorf("no port for TransportType %v", transportType) 265 } 266 return port, nil 267 } 268 269 func getTransportTypeToPort() (map[TransportType]uint16, error) { 270 m := make(map[TransportType]uint16, len(AllTransportTypes)) 271 for _, transportType := range AllTransportTypes { 272 port, err := getFreePort() 273 if err != nil { 274 return nil, err 275 } 276 m[transportType] = port 277 } 278 return m, nil 279 } 280 281 func getFreePort() (uint16, error) { 282 address, err := net.ResolveTCPAddr("tcp", "127.0.0.1:0") 283 if err != nil { 284 return 0, err 285 } 286 287 listener, err := net.ListenTCP("tcp", address) 288 if err != nil { 289 return 0, err 290 } 291 port := uint16(listener.Addr().(*net.TCPAddr).Port) 292 if err := listener.Close(); err != nil { 293 return 0, err 294 } 295 return port, nil 296 }