go.uber.org/yarpc@v1.72.1/x/yarpctest/request_unary.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package yarpctest
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"errors"
    27  	"fmt"
    28  	"io/ioutil"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"go.uber.org/yarpc"
    35  	"go.uber.org/yarpc/api/middleware"
    36  	"go.uber.org/yarpc/api/transport"
    37  	"go.uber.org/yarpc/transport/grpc"
    38  	"go.uber.org/yarpc/transport/http"
    39  	"go.uber.org/yarpc/transport/tchannel"
    40  	"go.uber.org/yarpc/x/yarpctest/api"
    41  )
    42  
    43  // HTTPRequest creates a new YARPC http request.
    44  func HTTPRequest(options ...api.RequestOption) api.Action {
    45  	return api.ActionFunc(func(t testing.TB) {
    46  		opts := api.NewRequestOpts()
    47  		for _, option := range options {
    48  			option.ApplyRequest(&opts)
    49  		}
    50  
    51  		trans := http.NewTransport()
    52  		httpOut := trans.NewSingleOutbound(fmt.Sprintf("http://127.0.0.1:%d/", opts.Port))
    53  		out := middleware.ApplyUnaryOutbound(httpOut, yarpc.UnaryOutboundMiddleware(opts.UnaryMiddleware...))
    54  
    55  		require.NoError(t, trans.Start())
    56  		defer func() { assert.NoError(t, trans.Stop()) }()
    57  
    58  		require.NoError(t, out.Start())
    59  		defer func() { assert.NoError(t, out.Stop()) }()
    60  
    61  		sendRequestAndValidateResp(t, out, opts)
    62  	})
    63  }
    64  
    65  // TChannelRequest creates a new tchannel request.
    66  func TChannelRequest(options ...api.RequestOption) api.Action {
    67  	return api.ActionFunc(func(t testing.TB) {
    68  		opts := api.NewRequestOpts()
    69  		for _, option := range options {
    70  			option.ApplyRequest(&opts)
    71  		}
    72  
    73  		trans, err := tchannel.NewTransport(tchannel.ServiceName(opts.GiveRequest.Caller))
    74  		require.NoError(t, err)
    75  		tchannelOut := trans.NewSingleOutbound(fmt.Sprintf("127.0.0.1:%d", opts.Port))
    76  		out := middleware.ApplyUnaryOutbound(tchannelOut, yarpc.UnaryOutboundMiddleware(opts.UnaryMiddleware...))
    77  
    78  		require.NoError(t, trans.Start())
    79  		defer func() { assert.NoError(t, trans.Stop()) }()
    80  
    81  		require.NoError(t, out.Start())
    82  		defer func() { assert.NoError(t, out.Stop()) }()
    83  
    84  		sendRequestAndValidateResp(t, out, opts)
    85  	})
    86  }
    87  
    88  // GRPCRequest creates a new grpc unary request.
    89  func GRPCRequest(options ...api.RequestOption) api.Action {
    90  	return api.ActionFunc(func(t testing.TB) {
    91  		opts := api.NewRequestOpts()
    92  		for _, option := range options {
    93  			option.ApplyRequest(&opts)
    94  		}
    95  
    96  		trans := grpc.NewTransport()
    97  		grpcOut := trans.NewSingleOutbound(fmt.Sprintf("127.0.0.1:%d", opts.Port))
    98  		out := middleware.ApplyUnaryOutbound(grpcOut, yarpc.UnaryOutboundMiddleware(opts.UnaryMiddleware...))
    99  
   100  		require.NoError(t, trans.Start())
   101  		defer func() { assert.NoError(t, trans.Stop()) }()
   102  
   103  		require.NoError(t, out.Start())
   104  		defer func() { assert.NoError(t, out.Stop()) }()
   105  
   106  		sendRequestAndValidateResp(t, out, opts)
   107  	})
   108  }
   109  
   110  func sendRequestAndValidateResp(t testing.TB, out transport.UnaryOutbound, opts api.RequestOpts) {
   111  	f := func(i int) bool {
   112  		resp, cancel, err := sendRequest(out, opts.GiveRequest, opts.GiveTimeout)
   113  		defer cancel()
   114  
   115  		if i == opts.RetryCount {
   116  			validateError(t, err, opts.WantError)
   117  			if opts.WantError == nil {
   118  				validateResponse(t, resp, opts.WantResponse)
   119  			}
   120  			return true
   121  		}
   122  
   123  		if err != nil || matchResponse(resp, opts.WantResponse) != nil {
   124  			return false
   125  		}
   126  
   127  		return true
   128  	}
   129  
   130  	for i := 0; i < opts.RetryCount+1; i++ {
   131  		if ok := f(i); ok {
   132  			return
   133  		}
   134  		time.Sleep(opts.RetryInterval)
   135  	}
   136  }
   137  
   138  func sendRequest(out transport.UnaryOutbound, request *transport.Request, timeout time.Duration) (*transport.Response, context.CancelFunc, error) {
   139  	ctx, cancel := context.WithTimeout(context.Background(), timeout)
   140  	resp, err := out.Call(ctx, request)
   141  	return resp, cancel, err
   142  }
   143  
   144  func validateError(t testing.TB, actualErr error, wantError error) {
   145  	if wantError != nil {
   146  		require.Error(t, actualErr)
   147  		require.Contains(t, actualErr.Error(), wantError.Error())
   148  		return
   149  	}
   150  	require.NoError(t, actualErr)
   151  }
   152  
   153  func validateResponse(t testing.TB, actualResp *transport.Response, expectedResp *transport.Response) {
   154  	require.NoError(t, matchResponse(actualResp, expectedResp), "response mismatch")
   155  }
   156  
   157  func matchResponse(actualResp *transport.Response, expectedResp *transport.Response) error {
   158  	var actualBody []byte
   159  	var expectedBody []byte
   160  	var err error
   161  	if actualResp.Body != nil {
   162  		actualBody, err = ioutil.ReadAll(actualResp.Body)
   163  		if err != nil {
   164  			return fmt.Errorf("failed to read response body")
   165  		}
   166  	}
   167  	if expectedResp.Body != nil {
   168  		expectedBody, err = ioutil.ReadAll(expectedResp.Body)
   169  		if err != nil {
   170  			return fmt.Errorf("failed to read response body")
   171  		}
   172  	}
   173  	if string(actualBody) != string(expectedBody) {
   174  		return fmt.Errorf("response body mismatch, expect %s, got %s",
   175  			expectedBody, actualBody)
   176  	}
   177  	for k, v := range expectedResp.Headers.Items() {
   178  		actualValue, ok := actualResp.Headers.Get(k)
   179  		if !ok {
   180  			return fmt.Errorf("headler %q was not set on the response", k)
   181  		}
   182  		if actualValue != v {
   183  			return fmt.Errorf("headers mismatch for %q, expected %v, got %v",
   184  				k, v, actualValue)
   185  		}
   186  	}
   187  	return nil
   188  }
   189  
   190  // UNARY-SPECIFIC REQUEST OPTIONS
   191  
   192  // Body sets the body on a request to the raw representation of the msg field.
   193  func Body(msg string) api.RequestOption {
   194  	return api.RequestOptionFunc(func(opts *api.RequestOpts) {
   195  		opts.GiveRequest.Body = bytes.NewBufferString(msg)
   196  	})
   197  }
   198  
   199  // GiveTimeout will set the timeout for the request.
   200  func GiveTimeout(duration time.Duration) api.RequestOption {
   201  	return api.RequestOptionFunc(func(opts *api.RequestOpts) {
   202  		opts.GiveTimeout = duration
   203  	})
   204  }
   205  
   206  // UnaryOutboundMiddleware sets unary outbound middleware for a request.
   207  //
   208  // Multiple invocations will append to existing middleware.
   209  func UnaryOutboundMiddleware(mw ...middleware.UnaryOutbound) api.RequestOption {
   210  	return api.RequestOptionFunc(func(opts *api.RequestOpts) {
   211  		opts.UnaryMiddleware = append(opts.UnaryMiddleware, mw...)
   212  	})
   213  }
   214  
   215  // WantError creates an assertion on the request response to validate the
   216  // error.
   217  func WantError(errMsg string) api.RequestOption {
   218  	return api.RequestOptionFunc(func(opts *api.RequestOpts) {
   219  		opts.WantError = errors.New(errMsg)
   220  	})
   221  }
   222  
   223  // WantRespBody will assert that the response body matches at the end of the
   224  // request.
   225  func WantRespBody(body string) api.RequestOption {
   226  	return api.RequestOptionFunc(func(opts *api.RequestOpts) {
   227  		opts.WantResponse.Body = ioutil.NopCloser(bytes.NewBufferString(body))
   228  	})
   229  }
   230  
   231  // GiveAndWantLargeBodyIsEchoed creates an extremely large random byte buffer
   232  // and validates that the body is echoed back to the response.
   233  func GiveAndWantLargeBodyIsEchoed(numOfBytes int) api.RequestOption {
   234  	return api.RequestOptionFunc(func(opts *api.RequestOpts) {
   235  		body := bytes.Repeat([]byte("t"), numOfBytes)
   236  		opts.GiveRequest.Body = bytes.NewReader(body)
   237  		opts.WantResponse.Body = ioutil.NopCloser(bytes.NewReader(body))
   238  	})
   239  }
   240  
   241  // Retry retries the request for a given times, until the request succeeds
   242  // and the response matches.
   243  func Retry(count int, interval time.Duration) api.RequestOption {
   244  	return api.RequestOptionFunc(func(opts *api.RequestOpts) {
   245  		opts.RetryCount = count
   246  		opts.RetryInterval = interval
   247  	})
   248  }