go.uber.org/yarpc@v1.72.1/transport/http/handler_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package http
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"fmt"
    27  	"io/ioutil"
    28  	"net/http"
    29  	"net/http/httptest"
    30  	"strconv"
    31  	"strings"
    32  	"testing"
    33  	"time"
    34  
    35  	"github.com/golang/mock/gomock"
    36  	"github.com/opentracing/opentracing-go"
    37  	"github.com/stretchr/testify/assert"
    38  	"github.com/stretchr/testify/require"
    39  	yarpc "go.uber.org/yarpc"
    40  	"go.uber.org/yarpc/api/transport"
    41  	"go.uber.org/yarpc/api/transport/transporttest"
    42  	"go.uber.org/yarpc/encoding/raw"
    43  	"go.uber.org/yarpc/internal/routertest"
    44  	"go.uber.org/yarpc/yarpcerrors"
    45  )
    46  
    47  func TestHandlerSuccess(t *testing.T) {
    48  	mockCtrl := gomock.NewController(t)
    49  	defer mockCtrl.Finish()
    50  
    51  	headers := make(http.Header)
    52  	headers.Set(CallerHeader, "moe")
    53  	headers.Set(EncodingHeader, "raw")
    54  	headers.Set(TTLMSHeader, "1000")
    55  	headers.Set(ProcedureHeader, "nyuck")
    56  	headers.Set(ServiceHeader, "curly")
    57  	headers.Set(ShardKeyHeader, "shard")
    58  	headers.Set(RoutingKeyHeader, "routekey")
    59  	headers.Set(RoutingDelegateHeader, "routedelegate")
    60  	headers.Set(CallerProcedureHeader, "callerprocedure")
    61  
    62  	router := transporttest.NewMockRouter(mockCtrl)
    63  	rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl)
    64  	spec := transport.NewUnaryHandlerSpec(rpcHandler)
    65  
    66  	router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher().
    67  		WithService("curly").
    68  		WithProcedure("nyuck"),
    69  	).Return(spec, nil)
    70  
    71  	rpcHandler.EXPECT().Handle(
    72  		transporttest.NewContextMatcher(t,
    73  			transporttest.ContextTTL(time.Second),
    74  		),
    75  		transporttest.NewRequestMatcher(
    76  			t, &transport.Request{
    77  				Caller:          "moe",
    78  				Service:         "curly",
    79  				Transport:       "http",
    80  				Encoding:        raw.Encoding,
    81  				Procedure:       "nyuck",
    82  				ShardKey:        "shard",
    83  				RoutingKey:      "routekey",
    84  				RoutingDelegate: "routedelegate",
    85  				CallerProcedure: "callerprocedure",
    86  				Body:            bytes.NewReader([]byte("Nyuck Nyuck")),
    87  			},
    88  		),
    89  		gomock.Any(),
    90  	).Return(nil)
    91  
    92  	httpHandler := handler{router: router, tracer: &opentracing.NoopTracer{}, bothResponseError: true}
    93  	req := &http.Request{
    94  		Method: "POST",
    95  		Header: headers,
    96  		Body:   ioutil.NopCloser(bytes.NewReader([]byte("Nyuck Nyuck"))),
    97  	}
    98  	rw := httptest.NewRecorder()
    99  	httpHandler.ServeHTTP(rw, req)
   100  	code := rw.Code
   101  	assert.Equal(t, code, 200, "expected 200 code")
   102  	assert.Equal(t, rw.Body.String(), "")
   103  }
   104  
   105  func TestHandlerHeaders(t *testing.T) {
   106  	mockCtrl := gomock.NewController(t)
   107  	defer mockCtrl.Finish()
   108  
   109  	tests := []struct {
   110  		giveEncoding string
   111  		giveHeaders  http.Header
   112  		grabHeaders  map[string]struct{}
   113  
   114  		wantTTL     time.Duration
   115  		wantHeaders map[string]string
   116  	}{
   117  		{
   118  			giveEncoding: "json",
   119  			giveHeaders: http.Header{
   120  				TTLMSHeader:      {"1000"},
   121  				"Rpc-Header-Foo": {"bar"},
   122  				"X-Baz":          {"bat"},
   123  			},
   124  			grabHeaders: map[string]struct{}{"x-baz": {}},
   125  			wantTTL:     time.Second,
   126  			wantHeaders: map[string]string{
   127  				"foo":   "bar",
   128  				"x-baz": "bat",
   129  			},
   130  		},
   131  		{
   132  			giveEncoding: "raw",
   133  			giveHeaders: http.Header{
   134  				TTLMSHeader: {"100"},
   135  				"Rpc-Foo":   {"ignored"},
   136  			},
   137  			wantTTL:     100 * time.Millisecond,
   138  			wantHeaders: map[string]string{},
   139  		},
   140  		{
   141  			giveEncoding: "thrift",
   142  			giveHeaders: http.Header{
   143  				TTLMSHeader: {"1000"},
   144  			},
   145  			wantTTL:     time.Second,
   146  			wantHeaders: map[string]string{},
   147  		},
   148  		{
   149  			giveEncoding: "proto",
   150  			giveHeaders: http.Header{
   151  				TTLMSHeader: {"1000"},
   152  			},
   153  			wantTTL:     time.Second,
   154  			wantHeaders: map[string]string{},
   155  		},
   156  	}
   157  
   158  	for _, tt := range tests {
   159  		router := transporttest.NewMockRouter(mockCtrl)
   160  		rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl)
   161  		spec := transport.NewUnaryHandlerSpec(rpcHandler)
   162  
   163  		router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher().
   164  			WithService("service").
   165  			WithProcedure("hello"),
   166  		).Return(spec, nil)
   167  
   168  		httpHandler := handler{router: router, tracer: &opentracing.NoopTracer{}, grabHeaders: tt.grabHeaders, bothResponseError: true}
   169  
   170  		rpcHandler.EXPECT().Handle(
   171  			transporttest.NewContextMatcher(t,
   172  				transporttest.ContextTTL(tt.wantTTL),
   173  			),
   174  			transporttest.NewRequestMatcher(t,
   175  				&transport.Request{
   176  					Caller:    "caller",
   177  					Service:   "service",
   178  					Transport: "http",
   179  					Encoding:  transport.Encoding(tt.giveEncoding),
   180  					Procedure: "hello",
   181  					Headers:   transport.HeadersFromMap(tt.wantHeaders),
   182  					Body:      bytes.NewReader([]byte("world")),
   183  				}),
   184  			gomock.Any(),
   185  		).Return(nil)
   186  
   187  		headers := http.Header{}
   188  		for k, vs := range tt.giveHeaders {
   189  			for _, v := range vs {
   190  				headers.Add(k, v)
   191  			}
   192  		}
   193  		headers.Set(CallerHeader, "caller")
   194  		headers.Set(ServiceHeader, "service")
   195  		headers.Set(EncodingHeader, tt.giveEncoding)
   196  		headers.Set(ProcedureHeader, "hello")
   197  
   198  		req := &http.Request{
   199  			Method: "POST",
   200  			Header: headers,
   201  			Body:   ioutil.NopCloser(bytes.NewReader([]byte("world"))),
   202  		}
   203  		rw := httptest.NewRecorder()
   204  		httpHandler.ServeHTTP(rw, req)
   205  		assert.Equal(t, 200, rw.Code, "expected 200 status code")
   206  		assert.Equal(t, getContentType(transport.Encoding(tt.giveEncoding)), rw.Header().Get("Content-Type"))
   207  	}
   208  }
   209  
   210  func TestHandlerFailures(t *testing.T) {
   211  	mockCtrl := gomock.NewController(t)
   212  	defer mockCtrl.Finish()
   213  
   214  	service, procedure := "fake", "hello"
   215  
   216  	baseHeaders := make(http.Header)
   217  	baseHeaders.Set(CallerHeader, "somecaller")
   218  	baseHeaders.Set(EncodingHeader, "raw")
   219  	baseHeaders.Set(TTLMSHeader, "1000")
   220  	baseHeaders.Set(ProcedureHeader, procedure)
   221  	baseHeaders.Set(ServiceHeader, service)
   222  
   223  	headersWithBadTTL := headerCopyWithout(baseHeaders, TTLMSHeader)
   224  	headersWithBadTTL.Set(TTLMSHeader, "not a number")
   225  
   226  	tests := []struct {
   227  		req *http.Request
   228  
   229  		// if we expect an error as a result of the TTL
   230  		errTTL   bool
   231  		wantCode yarpcerrors.Code
   232  	}{
   233  		{
   234  			req:      &http.Request{Method: "GET"},
   235  			wantCode: yarpcerrors.CodeNotFound,
   236  		},
   237  		{
   238  			req: &http.Request{
   239  				Method: "POST",
   240  				Header: headerCopyWithout(baseHeaders, CallerHeader),
   241  			},
   242  			wantCode: yarpcerrors.CodeInvalidArgument,
   243  		},
   244  		{
   245  			req: &http.Request{
   246  				Method: "POST",
   247  				Header: headerCopyWithout(baseHeaders, ServiceHeader),
   248  			},
   249  			wantCode: yarpcerrors.CodeInvalidArgument,
   250  		},
   251  		{
   252  			req: &http.Request{
   253  				Method: "POST",
   254  				Header: headerCopyWithout(baseHeaders, ProcedureHeader),
   255  			},
   256  			wantCode: yarpcerrors.CodeInvalidArgument,
   257  		},
   258  		{
   259  			req: &http.Request{
   260  				Method: "POST",
   261  				Header: headerCopyWithout(baseHeaders, TTLMSHeader),
   262  			},
   263  			wantCode: yarpcerrors.CodeInvalidArgument,
   264  			errTTL:   true,
   265  		},
   266  		{
   267  			req: &http.Request{
   268  				Method: "POST",
   269  			},
   270  			wantCode: yarpcerrors.CodeInvalidArgument,
   271  		},
   272  		{
   273  			req: &http.Request{
   274  				Method: "POST",
   275  				Header: headersWithBadTTL,
   276  			},
   277  			wantCode: yarpcerrors.CodeInvalidArgument,
   278  			errTTL:   true,
   279  		},
   280  	}
   281  
   282  	for _, tt := range tests {
   283  		req := tt.req
   284  		if req.Body == nil {
   285  			req.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
   286  		}
   287  
   288  		reg := transporttest.NewMockRouter(mockCtrl)
   289  
   290  		if tt.errTTL {
   291  			// since TTL is checked after we've determined the transport type, if we have an
   292  			// error with TTL it will be discovered after we read from the router
   293  			spec := transport.NewUnaryHandlerSpec(panickedHandler{})
   294  			reg.EXPECT().Choose(gomock.Any(), routertest.NewMatcher().
   295  				WithService(service).
   296  				WithProcedure(procedure),
   297  			).Return(spec, nil)
   298  		}
   299  
   300  		h := handler{router: reg, tracer: &opentracing.NoopTracer{}, bothResponseError: true}
   301  
   302  		rw := httptest.NewRecorder()
   303  		h.ServeHTTP(rw, tt.req)
   304  
   305  		httpStatusCode := rw.Code
   306  		assert.True(t, httpStatusCode >= 400 && httpStatusCode < 500, "expected 400 level code")
   307  		code := statusCodeToBestCode(httpStatusCode)
   308  		assert.Equal(t, tt.wantCode, code)
   309  		assert.Equal(t, "text/plain; charset=utf8", rw.Header().Get("Content-Type"))
   310  	}
   311  }
   312  
   313  func TestHandlerInternalFailure(t *testing.T) {
   314  	mockCtrl := gomock.NewController(t)
   315  	defer mockCtrl.Finish()
   316  
   317  	headers := make(http.Header)
   318  	headers.Set(CallerHeader, "somecaller")
   319  	headers.Set(EncodingHeader, "raw")
   320  	headers.Set(TTLMSHeader, "1000")
   321  	headers.Set(ProcedureHeader, "hello")
   322  	headers.Set(ServiceHeader, "fake")
   323  
   324  	request := http.Request{
   325  		Method: "POST",
   326  		Header: headers,
   327  		Body:   ioutil.NopCloser(bytes.NewReader([]byte{})),
   328  	}
   329  
   330  	rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl)
   331  	rpcHandler.EXPECT().Handle(
   332  		transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second)),
   333  		transporttest.NewRequestMatcher(
   334  			t, &transport.Request{
   335  				Caller:    "somecaller",
   336  				Service:   "fake",
   337  				Transport: "http",
   338  				Encoding:  raw.Encoding,
   339  				Procedure: "hello",
   340  				Body:      bytes.NewReader([]byte{}),
   341  			},
   342  		),
   343  		gomock.Any(),
   344  	).Return(fmt.Errorf("great sadness"))
   345  
   346  	router := transporttest.NewMockRouter(mockCtrl)
   347  	spec := transport.NewUnaryHandlerSpec(rpcHandler)
   348  
   349  	router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher().
   350  		WithService("fake").
   351  		WithProcedure("hello"),
   352  	).Return(spec, nil)
   353  
   354  	httpHandler := handler{router: router, tracer: &opentracing.NoopTracer{}, bothResponseError: true}
   355  	httpResponse := httptest.NewRecorder()
   356  	httpHandler.ServeHTTP(httpResponse, &request)
   357  
   358  	code := httpResponse.Code
   359  	assert.True(t, code >= 500 && code < 600, "expected 500 level response")
   360  	assert.Equal(t,
   361  		`error for service "fake" and procedure "hello": great sadness`+"\n",
   362  		httpResponse.Body.String())
   363  }
   364  
   365  type panickedHandler struct{}
   366  
   367  func (th panickedHandler) Handle(context.Context, *transport.Request, transport.ResponseWriter) error {
   368  	panic("oops I panicked!")
   369  }
   370  
   371  func TestHandlerPanic(t *testing.T) {
   372  	httpTransport := NewTransport()
   373  	inbound := httpTransport.NewInbound("localhost:0")
   374  	serverDispatcher := yarpc.NewDispatcher(yarpc.Config{
   375  		Name:     "yarpc-test",
   376  		Inbounds: yarpc.Inbounds{inbound},
   377  	})
   378  	serverDispatcher.Register([]transport.Procedure{
   379  		{
   380  			Name:        "panic",
   381  			HandlerSpec: transport.NewUnaryHandlerSpec(panickedHandler{}),
   382  		},
   383  	})
   384  
   385  	require.NoError(t, serverDispatcher.Start())
   386  	defer serverDispatcher.Stop()
   387  
   388  	clientDispatcher := yarpc.NewDispatcher(yarpc.Config{
   389  		Name: "yarpc-test-client",
   390  		Outbounds: yarpc.Outbounds{
   391  			"yarpc-test": {
   392  				Unary: httpTransport.NewSingleOutbound(fmt.Sprintf("http://%s", inbound.Addr().String())),
   393  			},
   394  		},
   395  	})
   396  	require.NoError(t, clientDispatcher.Start())
   397  	defer clientDispatcher.Stop()
   398  
   399  	client := raw.New(clientDispatcher.ClientConfig("yarpc-test"))
   400  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   401  	defer cancel()
   402  	_, err := client.Call(ctx, "panic", []byte{})
   403  
   404  	assert.Equal(t, yarpcerrors.CodeUnknown, yarpcerrors.FromError(err).Code())
   405  }
   406  
   407  func headerCopyWithout(headers http.Header, names ...string) http.Header {
   408  	newHeaders := make(http.Header)
   409  	for k, vs := range headers {
   410  		for _, v := range vs {
   411  			newHeaders.Add(k, v)
   412  		}
   413  	}
   414  
   415  	for _, k := range names {
   416  		newHeaders.Del(k)
   417  	}
   418  
   419  	return newHeaders
   420  }
   421  
   422  func TestResponseWriter(t *testing.T) {
   423  	const (
   424  		appErrDetails = "thrift ex message"
   425  		appErrName    = "thrift ex name"
   426  	)
   427  	appErrCode := yarpcerrors.CodeAborted
   428  
   429  	recorder := httptest.NewRecorder()
   430  	writer := newResponseWriter(recorder)
   431  
   432  	headers := transport.HeadersFromMap(map[string]string{
   433  		"foo":       "bar",
   434  		"shard-key": "123",
   435  	})
   436  	writer.AddHeaders(headers)
   437  
   438  	writer.SetApplicationErrorMeta(&transport.ApplicationErrorMeta{
   439  		Details: appErrDetails,
   440  		Name:    appErrName,
   441  		Code:    &appErrCode,
   442  	})
   443  
   444  	_, err := writer.Write([]byte("hello"))
   445  	require.NoError(t, err)
   446  	writer.Close(http.StatusOK)
   447  
   448  	assert.Equal(t, "bar", recorder.Header().Get("rpc-header-foo"))
   449  	assert.Equal(t, "123", recorder.Header().Get("rpc-header-shard-key"))
   450  	assert.Equal(t, "hello", recorder.Body.String())
   451  
   452  	assert.Equal(t, appErrDetails, recorder.Header().Get(_applicationErrorDetailsHeader))
   453  	assert.Equal(t, appErrName, recorder.Header().Get(_applicationErrorNameHeader))
   454  	assert.Equal(t, strconv.Itoa(int(appErrCode)), recorder.Header().Get(_applicationErrorCodeHeader))
   455  }
   456  
   457  func TestTruncatedHeader(t *testing.T) {
   458  	tests := []struct {
   459  		name         string
   460  		value        string
   461  		wantTruncate bool
   462  	}{
   463  		{
   464  			name:  "no-op",
   465  			value: "foo bar",
   466  		},
   467  		{
   468  			name:  "max",
   469  			value: strings.Repeat("a", _maxAppErrDetailsHeaderLen),
   470  		},
   471  		{
   472  			name:         "truncate",
   473  			value:        strings.Repeat("b", _maxAppErrDetailsHeaderLen*2),
   474  			wantTruncate: true,
   475  		},
   476  	}
   477  
   478  	for _, tt := range tests {
   479  		t.Run(tt.name, func(t *testing.T) {
   480  			got := truncateAppErrDetails(tt.value)
   481  
   482  			if !tt.wantTruncate {
   483  				assert.Equal(t, tt.value, got, "expected no-op")
   484  				return
   485  			}
   486  
   487  			assert.True(t, strings.HasSuffix(got, _truncatedHeaderMessage), "unexpected truncate suffix")
   488  			assert.Len(t, got, _maxAppErrDetailsHeaderLen, "did not truncate")
   489  		})
   490  	}
   491  }