go.uber.org/yarpc@v1.72.1/transport/http/inbound_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package http
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"fmt"
    27  	"io"
    28  	"io/ioutil"
    29  	"net"
    30  	"net/http"
    31  	"os"
    32  	"syscall"
    33  	"testing"
    34  
    35  	"github.com/golang/mock/gomock"
    36  	"github.com/stretchr/testify/assert"
    37  	"github.com/stretchr/testify/require"
    38  	"go.uber.org/yarpc"
    39  	"go.uber.org/yarpc/api/transport"
    40  	"go.uber.org/yarpc/api/transport/transporttest"
    41  	"go.uber.org/yarpc/encoding/raw"
    42  	"go.uber.org/yarpc/internal/routertest"
    43  	"go.uber.org/yarpc/internal/testtime"
    44  	"go.uber.org/yarpc/internal/yarpctest"
    45  	"go.uber.org/yarpc/yarpcerrors"
    46  )
    47  
    48  func TestStartAddrInUse(t *testing.T) {
    49  	t1 := NewTransport()
    50  	i1 := t1.NewInbound("127.0.0.1:0")
    51  
    52  	assert.Len(t, i1.Transports(), 1, "transports must contain the transport")
    53  	// we use == instead of assert.Equal because we want to do a pointer
    54  	// comparison
    55  	assert.True(t, t1 == i1.Transports()[0], "transports must match")
    56  
    57  	i1.SetRouter(newTestRouter(nil))
    58  	require.NoError(t, i1.Start(), "inbound 1 must start without an error")
    59  	t2 := NewTransport()
    60  	i2 := t2.NewInbound(i1.Addr().String())
    61  	i2.SetRouter(newTestRouter(nil))
    62  	err := i2.Start()
    63  
    64  	require.Error(t, err)
    65  	oe, ok := err.(*net.OpError)
    66  	assert.True(t, ok && oe.Op == "listen", "expected a listen error")
    67  	if ok {
    68  		se, ok := oe.Err.(*os.SyscallError)
    69  		assert.True(t, ok && se.Syscall == "bind" && se.Err == syscall.EADDRINUSE, "expected a EADDRINUSE bind error")
    70  	}
    71  
    72  	assert.NoError(t, i1.Stop())
    73  }
    74  
    75  func TestNilAddrAfterStop(t *testing.T) {
    76  	x := NewTransport()
    77  	i := x.NewInbound("127.0.0.1:0")
    78  	i.SetRouter(newTestRouter(nil))
    79  	require.NoError(t, i.Start())
    80  	assert.NotEqual(t, "127.0.0.1:0", i.Addr().String())
    81  	assert.NotNil(t, i.Addr())
    82  	assert.NoError(t, i.Stop())
    83  	assert.Nil(t, i.Addr())
    84  }
    85  
    86  func TestInboundStartAndStop(t *testing.T) {
    87  	x := NewTransport()
    88  	i := x.NewInbound("127.0.0.1:0")
    89  	i.SetRouter(newTestRouter(nil))
    90  	require.NoError(t, i.Start())
    91  	assert.NotEqual(t, "127.0.0.1:0", i.Addr().String())
    92  	assert.NoError(t, i.Stop())
    93  }
    94  
    95  func TestInboundStartError(t *testing.T) {
    96  	x := NewTransport()
    97  	i := x.NewInbound("invalid")
    98  	i.SetRouter(new(transporttest.MockRouter))
    99  	assert.Error(t, i.Start(), "expected failure")
   100  }
   101  
   102  func TestInboundStartErrorBadGrabHeader(t *testing.T) {
   103  	x := NewTransport()
   104  	i := x.NewInbound("127.0.0.1:0", GrabHeaders("x-valid", "y-invalid"))
   105  	i.SetRouter(new(transporttest.MockRouter))
   106  	assert.Equal(t, yarpcerrors.CodeInvalidArgument, yarpcerrors.FromError(i.Start()).Code())
   107  }
   108  
   109  func TestInboundStopWithoutStarting(t *testing.T) {
   110  	x := NewTransport()
   111  	i := x.NewInbound("127.0.0.1:8000")
   112  	assert.Nil(t, i.Addr())
   113  	assert.NoError(t, i.Stop())
   114  }
   115  
   116  func TestInboundMux(t *testing.T) {
   117  	mockCtrl := gomock.NewController(t)
   118  	defer mockCtrl.Finish()
   119  
   120  	httpTransport := NewTransport()
   121  	defer httpTransport.Stop()
   122  	// TODO transport lifecycle
   123  
   124  	mux := http.NewServeMux()
   125  	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
   126  		w.Write([]byte("healthy"))
   127  	})
   128  
   129  	i := httpTransport.NewInbound("127.0.0.1:0", Mux("/rpc/v1", mux))
   130  	h := transporttest.NewMockUnaryHandler(mockCtrl)
   131  	reg := transporttest.NewMockRouter(mockCtrl)
   132  	reg.EXPECT().Procedures()
   133  	i.SetRouter(reg)
   134  	require.NoError(t, i.Start())
   135  
   136  	defer i.Stop()
   137  
   138  	addr := fmt.Sprintf("http://%v/", yarpctest.ZeroAddrToHostPort(i.Addr()))
   139  	resp, err := http.Get(addr + "health")
   140  	if assert.NoError(t, err, "/health failed") {
   141  		defer resp.Body.Close()
   142  		body, err := ioutil.ReadAll(resp.Body)
   143  		if assert.NoError(t, err, "/health body read error") {
   144  			assert.Equal(t, "healthy", string(body), "/health body mismatch")
   145  		}
   146  	}
   147  
   148  	// this should fail
   149  	o := httpTransport.NewSingleOutbound(addr)
   150  	require.NoError(t, o.Start(), "failed to start outbound")
   151  	defer o.Stop()
   152  
   153  	ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   154  	defer cancel()
   155  	_, err = o.Call(ctx, &transport.Request{
   156  		Caller:    "foo",
   157  		Service:   "bar",
   158  		Procedure: "hello",
   159  		Encoding:  raw.Encoding,
   160  		Body:      bytes.NewReader([]byte("derp")),
   161  	})
   162  
   163  	if assert.Error(t, err, "RPC call to / should have failed") {
   164  		assert.Equal(t, yarpcerrors.CodeNotFound, yarpcerrors.FromError(err).Code())
   165  	}
   166  
   167  	o.setURLTemplate("http://host:12345/rpc/v1")
   168  	require.NoError(t, o.Start(), "failed to start outbound")
   169  	defer o.Stop()
   170  
   171  	spec := transport.NewUnaryHandlerSpec(h)
   172  	reg.EXPECT().Choose(gomock.Any(), routertest.NewMatcher().
   173  		WithCaller("foo").
   174  		WithService("bar").
   175  		WithProcedure("hello"),
   176  	).Return(spec, nil)
   177  
   178  	h.EXPECT().Handle(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
   179  
   180  	res, err := o.Call(ctx, &transport.Request{
   181  		Caller:    "foo",
   182  		Service:   "bar",
   183  		Procedure: "hello",
   184  		Encoding:  raw.Encoding,
   185  		Body:      bytes.NewReader([]byte("derp")),
   186  	})
   187  
   188  	if assert.NoError(t, err, "expected rpc request to succeed") {
   189  		defer res.Body.Close()
   190  		s, err := ioutil.ReadAll(res.Body)
   191  		if assert.NoError(t, err) {
   192  			assert.Empty(t, s)
   193  		}
   194  	}
   195  }
   196  
   197  func TestMuxWithInterceptor(t *testing.T) {
   198  	tests := []struct {
   199  		path string
   200  		want string
   201  	}{
   202  		{
   203  			path: "/health",
   204  			want: "OK",
   205  		},
   206  		{
   207  			path: "/",
   208  			want: "intercepted",
   209  		},
   210  	}
   211  
   212  	mux := http.NewServeMux()
   213  	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
   214  		io.WriteString(w, "OK")
   215  	})
   216  	intercept := func(transportHandler http.Handler) http.Handler {
   217  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   218  			io.WriteString(w, "intercepted")
   219  		})
   220  	}
   221  
   222  	transport := NewTransport()
   223  	inbound := transport.NewInbound("127.0.0.1:0", Mux("/", mux), Interceptor(intercept))
   224  	inbound.SetRouter(newTestRouter(nil))
   225  	require.NoError(t, inbound.Start(), "Failed to start inbound")
   226  	defer inbound.Stop()
   227  
   228  	dispatcher := yarpc.NewDispatcher(yarpc.Config{
   229  		Name:     "server",
   230  		Inbounds: yarpc.Inbounds{inbound},
   231  	})
   232  	require.NoError(t, dispatcher.Start(), "Failed to start dispatcher")
   233  	defer dispatcher.Stop()
   234  
   235  	for _, tt := range tests {
   236  		t.Run(tt.path, func(t *testing.T) {
   237  			url := fmt.Sprintf("http://%v%v", inbound.Addr(), tt.path)
   238  			_, body, err := httpGet(t, url)
   239  			require.NoError(t, err, "request failed")
   240  			assert.Equal(t, tt.want, string(body))
   241  		})
   242  	}
   243  }
   244  
   245  func TestMultipleInterceptors(t *testing.T) {
   246  	const (
   247  		yarpcResp  = "YARPC response"
   248  		viaResp    = "Via response"
   249  		healthResp = "health response"
   250  		userResp   = "user response"
   251  	)
   252  	// This should be the underlying yarpc handler but it can't be set directly
   253  	// For the ease of testing, use an interceptor and register it last.
   254  	baseHandler := Interceptor(func(http.Handler) http.Handler {
   255  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   256  			io.WriteString(w, yarpcResp)
   257  		})
   258  	})
   259  
   260  	viaInterceptor := Interceptor(func(h http.Handler) http.Handler {
   261  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   262  			io.WriteString(w, viaResp)
   263  			h.ServeHTTP(w, r)
   264  		})
   265  	})
   266  
   267  	healthInterceptor := Interceptor(func(h http.Handler) http.Handler {
   268  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   269  			if r.URL.Path == "/health" {
   270  				io.WriteString(w, healthResp)
   271  			} else {
   272  				h.ServeHTTP(w, r)
   273  			}
   274  		})
   275  	})
   276  
   277  	userInterceptor := Interceptor(func(h http.Handler) http.Handler {
   278  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   279  			if r.URL.Path == "/user" {
   280  				io.WriteString(w, userResp)
   281  			} else {
   282  				h.ServeHTTP(w, r)
   283  			}
   284  		})
   285  	})
   286  
   287  	tests := []struct {
   288  		msg          string
   289  		interceptors []InboundOption
   290  		url          string
   291  		want         string
   292  	}{
   293  		{
   294  			msg:          "no user interceptor, /yarpc",
   295  			interceptors: []InboundOption{healthInterceptor},
   296  			url:          "/yarpc",
   297  			want:         yarpcResp,
   298  		},
   299  		{
   300  			msg:          "no user interceptor, /health",
   301  			interceptors: []InboundOption{healthInterceptor},
   302  			url:          "/health",
   303  			want:         healthResp,
   304  		},
   305  		{
   306  			msg:          "no user interceptor, /user",
   307  			interceptors: []InboundOption{healthInterceptor},
   308  			url:          "/user",
   309  			want:         yarpcResp,
   310  		},
   311  		{
   312  			msg:          "user interceptor, /yarpc",
   313  			interceptors: []InboundOption{healthInterceptor, userInterceptor},
   314  			url:          "/yarpc",
   315  			want:         yarpcResp,
   316  		},
   317  		{
   318  			msg:          "user interceptor, /health",
   319  			interceptors: []InboundOption{healthInterceptor, userInterceptor},
   320  			url:          "/health",
   321  			want:         healthResp,
   322  		},
   323  		{
   324  			msg:          "user interceptor, /user",
   325  			interceptors: []InboundOption{healthInterceptor, userInterceptor},
   326  			url:          "/user",
   327  			want:         userResp,
   328  		},
   329  		{
   330  			msg:          "ordering guaranteed",
   331  			interceptors: []InboundOption{viaInterceptor},
   332  			url:          "/yarpc",
   333  			want:         viaResp + yarpcResp,
   334  		},
   335  	}
   336  
   337  	for _, tt := range tests {
   338  		t.Run(tt.msg, func(t *testing.T) {
   339  			transport := NewTransport()
   340  			inbound := transport.NewInbound("127.0.0.1:0", append(tt.interceptors, baseHandler)...)
   341  			inbound.SetRouter(newTestRouter(nil))
   342  			require.NoError(t, inbound.Start(), "Failed to start inbound")
   343  			defer inbound.Stop()
   344  
   345  			dispatcher := yarpc.NewDispatcher(yarpc.Config{
   346  				Name:     "server",
   347  				Inbounds: yarpc.Inbounds{inbound},
   348  			})
   349  			require.NoError(t, dispatcher.Start(), "Failed to start dispatcher")
   350  			defer dispatcher.Stop()
   351  
   352  			url := fmt.Sprintf("http://%v%v", inbound.Addr(), tt.url)
   353  			_, body, err := httpGet(t, url)
   354  			require.NoError(t, err, "request failed")
   355  			assert.Equal(t, tt.want, string(body))
   356  		})
   357  	}
   358  }
   359  
   360  func TestRequestAfterStop(t *testing.T) {
   361  	mux := http.NewServeMux()
   362  	mux.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
   363  		io.WriteString(w, "OK")
   364  	})
   365  
   366  	transport := NewTransport()
   367  	inbound := transport.NewInbound("127.0.0.1:0", Mux("/", mux))
   368  	inbound.SetRouter(newTestRouter(nil))
   369  	require.NoError(t, inbound.Start(), "Failed to start inbound")
   370  
   371  	url := fmt.Sprintf("http://%v/health", inbound.Addr())
   372  	_, body, err := httpGet(t, url)
   373  	require.NoError(t, err, "expect successful response")
   374  	assert.Equal(t, "OK", body, "response mismatch")
   375  
   376  	require.NoError(t, inbound.Stop(), "Failed to stop inbound")
   377  
   378  	_, _, err = httpGet(t, url)
   379  	assert.Error(t, err, "requests should fail once inbound is stopped")
   380  }
   381  
   382  func httpGet(t *testing.T, url string) (*http.Response, string, error) {
   383  	resp, err := http.Get(url)
   384  	if err != nil {
   385  		return nil, "", fmt.Errorf("GET %v failed: %v", url, err)
   386  	}
   387  	defer resp.Body.Close()
   388  
   389  	body, err := ioutil.ReadAll(resp.Body)
   390  	if err != nil {
   391  		return nil, "", fmt.Errorf("Failed to read reponse from %v: %v", url, err)
   392  	}
   393  
   394  	return resp, string(body), nil
   395  }