go.uber.org/yarpc@v1.72.1/transport/http/roundtrip_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package http
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"io/ioutil"
    27  	"net/http"
    28  	"net/http/httptest"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"go.uber.org/yarpc/api/transport"
    35  	"go.uber.org/yarpc/internal/testtime"
    36  	"go.uber.org/yarpc/yarpcerrors"
    37  )
    38  
    39  func TestRoundTripSuccess(t *testing.T) {
    40  	headerKey, headerVal := "foo", "bar"
    41  	giveBody := "successful response"
    42  
    43  	echoServer := httptest.NewServer(http.HandlerFunc(
    44  		func(w http.ResponseWriter, req *http.Request) {
    45  			defer req.Body.Close()
    46  
    47  			// copy header
    48  			header := req.Header.Get(headerKey)
    49  			w.Header().Set(headerKey, header)
    50  
    51  			// copy body
    52  			body, err := ioutil.ReadAll(req.Body)
    53  			if err != nil {
    54  				t.Error("error reading body")
    55  			}
    56  			_, err = w.Write(body)
    57  			if err != nil {
    58  				t.Error("error writing body")
    59  			}
    60  		},
    61  	))
    62  	defer echoServer.Close()
    63  
    64  	// start outbound
    65  	httpTransport := NewTransport()
    66  	defer httpTransport.Stop()
    67  	var out transport.UnaryOutbound = httpTransport.NewSingleOutbound(echoServer.URL)
    68  	require.NoError(t, out.Start(), "failed to start outbound")
    69  	defer out.Stop()
    70  
    71  	// create request
    72  	hreq := httptest.NewRequest("GET", echoServer.URL, bytes.NewReader([]byte(giveBody)))
    73  	hreq.Header.Add(headerKey, headerVal)
    74  
    75  	// add deadline
    76  	ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
    77  	defer cancel()
    78  	hreq = hreq.WithContext(ctx)
    79  
    80  	// make call
    81  	rt, ok := out.(http.RoundTripper)
    82  	assert.True(t, ok, "unable to convert an outbound to a http.RoundTripper")
    83  
    84  	res, err := rt.RoundTrip(hreq)
    85  	require.NoError(t, err, "could not make call")
    86  	defer res.Body.Close()
    87  
    88  	// validate header
    89  	gotHeaderVal := res.Header.Get(headerKey)
    90  	assert.Equal(t, headerVal, gotHeaderVal, "header did not match")
    91  
    92  	// validate body
    93  	gotBody, err := ioutil.ReadAll(res.Body)
    94  	require.NoError(t, err)
    95  	assert.Equal(t, giveBody, string(gotBody), "body did not match")
    96  }
    97  
    98  func TestRoundTripTimeout(t *testing.T) {
    99  	server := httptest.NewServer(http.HandlerFunc(
   100  		func(w http.ResponseWriter, r *http.Request) {
   101  			<-r.Context().Done() // never respond
   102  		}))
   103  	defer server.Close()
   104  
   105  	tran := NewTransport()
   106  	defer tran.Stop()
   107  	// start outbound
   108  	out := tran.NewSingleOutbound(server.URL)
   109  	require.NoError(t, out.Start(), "failed to start outbound")
   110  	defer out.Stop()
   111  
   112  	// create request
   113  	req, err := http.NewRequest("POST", server.URL, nil /* body */)
   114  	require.NoError(t, err)
   115  
   116  	// set a small deadline so the the call times out quickly
   117  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
   118  	defer cancel()
   119  	req = req.WithContext(ctx)
   120  
   121  	// make call
   122  	client := http.Client{Transport: out}
   123  	res, err := client.Do(req)
   124  
   125  	// validate response
   126  	if assert.Error(t, err) {
   127  		// we use a Contains here since the returned error is really a
   128  		// url.Error wrapping a yarpcerror
   129  		assert.Contains(t, err.Error(), yarpcerrors.CodeDeadlineExceeded.String())
   130  	}
   131  	assert.Equal(t, context.DeadlineExceeded, ctx.Err())
   132  	assert.Nil(t, res)
   133  }
   134  
   135  func TestRoundTripNoDeadline(t *testing.T) {
   136  	URL := "http://foo-host"
   137  
   138  	tran := NewTransport()
   139  	defer tran.Stop()
   140  	out := tran.NewSingleOutbound(URL)
   141  	require.NoError(t, out.Start(), "could not start outbound")
   142  	defer out.Stop()
   143  
   144  	hreq, err := http.NewRequest("GET", URL, nil /* body */)
   145  	require.NoError(t, err)
   146  
   147  	resp, err := out.RoundTrip(hreq)
   148  	assert.Equal(t, yarpcerrors.Newf(yarpcerrors.CodeInvalidArgument, "missing context deadline"), err)
   149  	assert.Nil(t, resp)
   150  }
   151  
   152  func TestRoundTripNotRunning(t *testing.T) {
   153  	URL := "http://foo-host"
   154  	out := NewTransport().NewSingleOutbound(URL)
   155  
   156  	req, err := http.NewRequest("POST", URL, nil /* body */)
   157  	require.NoError(t, err)
   158  
   159  	ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
   160  	defer cancel()
   161  	req = req.WithContext(ctx)
   162  
   163  	client := http.Client{Transport: out}
   164  	res, err := client.Do(req)
   165  
   166  	if assert.Error(t, err) {
   167  		assert.Contains(t, err.Error(), "waiting for HTTP outbound to start")
   168  	}
   169  	assert.Nil(t, res)
   170  }