go.uber.org/yarpc@v1.72.1/x/yarpctest/request_unary.go (about) 1 // Copyright (c) 2022 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package yarpctest 22 23 import ( 24 "bytes" 25 "context" 26 "errors" 27 "fmt" 28 "io/ioutil" 29 "testing" 30 "time" 31 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 "go.uber.org/yarpc" 35 "go.uber.org/yarpc/api/middleware" 36 "go.uber.org/yarpc/api/transport" 37 "go.uber.org/yarpc/transport/grpc" 38 "go.uber.org/yarpc/transport/http" 39 "go.uber.org/yarpc/transport/tchannel" 40 "go.uber.org/yarpc/x/yarpctest/api" 41 ) 42 43 // HTTPRequest creates a new YARPC http request. 44 func HTTPRequest(options ...api.RequestOption) api.Action { 45 return api.ActionFunc(func(t testing.TB) { 46 opts := api.NewRequestOpts() 47 for _, option := range options { 48 option.ApplyRequest(&opts) 49 } 50 51 trans := http.NewTransport() 52 httpOut := trans.NewSingleOutbound(fmt.Sprintf("http://127.0.0.1:%d/", opts.Port)) 53 out := middleware.ApplyUnaryOutbound(httpOut, yarpc.UnaryOutboundMiddleware(opts.UnaryMiddleware...)) 54 55 require.NoError(t, trans.Start()) 56 defer func() { assert.NoError(t, trans.Stop()) }() 57 58 require.NoError(t, out.Start()) 59 defer func() { assert.NoError(t, out.Stop()) }() 60 61 sendRequestAndValidateResp(t, out, opts) 62 }) 63 } 64 65 // TChannelRequest creates a new tchannel request. 66 func TChannelRequest(options ...api.RequestOption) api.Action { 67 return api.ActionFunc(func(t testing.TB) { 68 opts := api.NewRequestOpts() 69 for _, option := range options { 70 option.ApplyRequest(&opts) 71 } 72 73 trans, err := tchannel.NewTransport(tchannel.ServiceName(opts.GiveRequest.Caller)) 74 require.NoError(t, err) 75 tchannelOut := trans.NewSingleOutbound(fmt.Sprintf("127.0.0.1:%d", opts.Port)) 76 out := middleware.ApplyUnaryOutbound(tchannelOut, yarpc.UnaryOutboundMiddleware(opts.UnaryMiddleware...)) 77 78 require.NoError(t, trans.Start()) 79 defer func() { assert.NoError(t, trans.Stop()) }() 80 81 require.NoError(t, out.Start()) 82 defer func() { assert.NoError(t, out.Stop()) }() 83 84 sendRequestAndValidateResp(t, out, opts) 85 }) 86 } 87 88 // GRPCRequest creates a new grpc unary request. 89 func GRPCRequest(options ...api.RequestOption) api.Action { 90 return api.ActionFunc(func(t testing.TB) { 91 opts := api.NewRequestOpts() 92 for _, option := range options { 93 option.ApplyRequest(&opts) 94 } 95 96 trans := grpc.NewTransport() 97 grpcOut := trans.NewSingleOutbound(fmt.Sprintf("127.0.0.1:%d", opts.Port)) 98 out := middleware.ApplyUnaryOutbound(grpcOut, yarpc.UnaryOutboundMiddleware(opts.UnaryMiddleware...)) 99 100 require.NoError(t, trans.Start()) 101 defer func() { assert.NoError(t, trans.Stop()) }() 102 103 require.NoError(t, out.Start()) 104 defer func() { assert.NoError(t, out.Stop()) }() 105 106 sendRequestAndValidateResp(t, out, opts) 107 }) 108 } 109 110 func sendRequestAndValidateResp(t testing.TB, out transport.UnaryOutbound, opts api.RequestOpts) { 111 f := func(i int) bool { 112 resp, cancel, err := sendRequest(out, opts.GiveRequest, opts.GiveTimeout) 113 defer cancel() 114 115 if i == opts.RetryCount { 116 validateError(t, err, opts.WantError) 117 if opts.WantError == nil { 118 validateResponse(t, resp, opts.WantResponse) 119 } 120 return true 121 } 122 123 if err != nil || matchResponse(resp, opts.WantResponse) != nil { 124 return false 125 } 126 127 return true 128 } 129 130 for i := 0; i < opts.RetryCount+1; i++ { 131 if ok := f(i); ok { 132 return 133 } 134 time.Sleep(opts.RetryInterval) 135 } 136 } 137 138 func sendRequest(out transport.UnaryOutbound, request *transport.Request, timeout time.Duration) (*transport.Response, context.CancelFunc, error) { 139 ctx, cancel := context.WithTimeout(context.Background(), timeout) 140 resp, err := out.Call(ctx, request) 141 return resp, cancel, err 142 } 143 144 func validateError(t testing.TB, actualErr error, wantError error) { 145 if wantError != nil { 146 require.Error(t, actualErr) 147 require.Contains(t, actualErr.Error(), wantError.Error()) 148 return 149 } 150 require.NoError(t, actualErr) 151 } 152 153 func validateResponse(t testing.TB, actualResp *transport.Response, expectedResp *transport.Response) { 154 require.NoError(t, matchResponse(actualResp, expectedResp), "response mismatch") 155 } 156 157 func matchResponse(actualResp *transport.Response, expectedResp *transport.Response) error { 158 var actualBody []byte 159 var expectedBody []byte 160 var err error 161 if actualResp.Body != nil { 162 actualBody, err = ioutil.ReadAll(actualResp.Body) 163 if err != nil { 164 return fmt.Errorf("failed to read response body") 165 } 166 } 167 if expectedResp.Body != nil { 168 expectedBody, err = ioutil.ReadAll(expectedResp.Body) 169 if err != nil { 170 return fmt.Errorf("failed to read response body") 171 } 172 } 173 if string(actualBody) != string(expectedBody) { 174 return fmt.Errorf("response body mismatch, expect %s, got %s", 175 expectedBody, actualBody) 176 } 177 for k, v := range expectedResp.Headers.Items() { 178 actualValue, ok := actualResp.Headers.Get(k) 179 if !ok { 180 return fmt.Errorf("headler %q was not set on the response", k) 181 } 182 if actualValue != v { 183 return fmt.Errorf("headers mismatch for %q, expected %v, got %v", 184 k, v, actualValue) 185 } 186 } 187 return nil 188 } 189 190 // UNARY-SPECIFIC REQUEST OPTIONS 191 192 // Body sets the body on a request to the raw representation of the msg field. 193 func Body(msg string) api.RequestOption { 194 return api.RequestOptionFunc(func(opts *api.RequestOpts) { 195 opts.GiveRequest.Body = bytes.NewBufferString(msg) 196 }) 197 } 198 199 // GiveTimeout will set the timeout for the request. 200 func GiveTimeout(duration time.Duration) api.RequestOption { 201 return api.RequestOptionFunc(func(opts *api.RequestOpts) { 202 opts.GiveTimeout = duration 203 }) 204 } 205 206 // UnaryOutboundMiddleware sets unary outbound middleware for a request. 207 // 208 // Multiple invocations will append to existing middleware. 209 func UnaryOutboundMiddleware(mw ...middleware.UnaryOutbound) api.RequestOption { 210 return api.RequestOptionFunc(func(opts *api.RequestOpts) { 211 opts.UnaryMiddleware = append(opts.UnaryMiddleware, mw...) 212 }) 213 } 214 215 // WantError creates an assertion on the request response to validate the 216 // error. 217 func WantError(errMsg string) api.RequestOption { 218 return api.RequestOptionFunc(func(opts *api.RequestOpts) { 219 opts.WantError = errors.New(errMsg) 220 }) 221 } 222 223 // WantRespBody will assert that the response body matches at the end of the 224 // request. 225 func WantRespBody(body string) api.RequestOption { 226 return api.RequestOptionFunc(func(opts *api.RequestOpts) { 227 opts.WantResponse.Body = ioutil.NopCloser(bytes.NewBufferString(body)) 228 }) 229 } 230 231 // GiveAndWantLargeBodyIsEchoed creates an extremely large random byte buffer 232 // and validates that the body is echoed back to the response. 233 func GiveAndWantLargeBodyIsEchoed(numOfBytes int) api.RequestOption { 234 return api.RequestOptionFunc(func(opts *api.RequestOpts) { 235 body := bytes.Repeat([]byte("t"), numOfBytes) 236 opts.GiveRequest.Body = bytes.NewReader(body) 237 opts.WantResponse.Body = ioutil.NopCloser(bytes.NewReader(body)) 238 }) 239 } 240 241 // Retry retries the request for a given times, until the request succeeds 242 // and the response matches. 243 func Retry(count int, interval time.Duration) api.RequestOption { 244 return api.RequestOptionFunc(func(opts *api.RequestOpts) { 245 opts.RetryCount = count 246 opts.RetryInterval = interval 247 }) 248 }