google.golang.org/grpc@v1.62.1/internal/transport/handler_server_test.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package transport
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"net/url"
    29  	"reflect"
    30  	"sync"
    31  	"testing"
    32  	"time"
    33  
    34  	epb "google.golang.org/genproto/googleapis/rpc/errdetails"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/metadata"
    37  	"google.golang.org/grpc/status"
    38  	"google.golang.org/protobuf/proto"
    39  	"google.golang.org/protobuf/protoadapt"
    40  	"google.golang.org/protobuf/types/known/durationpb"
    41  )
    42  
    43  func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
    44  	type testCase struct {
    45  		name        string
    46  		req         *http.Request
    47  		wantErr     string
    48  		wantErrCode int
    49  		modrw       func(http.ResponseWriter) http.ResponseWriter
    50  		check       func(*serverHandlerTransport, *testCase) error
    51  	}
    52  	tests := []testCase{
    53  		{
    54  			name: "http/1.1",
    55  			req: &http.Request{
    56  				ProtoMajor: 1,
    57  				ProtoMinor: 1,
    58  			},
    59  			wantErr:     "gRPC requires HTTP/2",
    60  			wantErrCode: http.StatusBadRequest,
    61  		},
    62  		{
    63  			name: "bad method",
    64  			req: &http.Request{
    65  				ProtoMajor: 2,
    66  				Method:     "GET",
    67  				Header:     http.Header{},
    68  			},
    69  			wantErr:     `invalid gRPC request method "GET"`,
    70  			wantErrCode: http.StatusBadRequest,
    71  		},
    72  		{
    73  			name: "bad content type",
    74  			req: &http.Request{
    75  				ProtoMajor: 2,
    76  				Method:     "POST",
    77  				Header: http.Header{
    78  					"Content-Type": {"application/foo"},
    79  				},
    80  			},
    81  			wantErr:     `invalid gRPC request content-type "application/foo"`,
    82  			wantErrCode: http.StatusUnsupportedMediaType,
    83  		},
    84  		{
    85  			name: "not flusher",
    86  			req: &http.Request{
    87  				ProtoMajor: 2,
    88  				Method:     "POST",
    89  				Header: http.Header{
    90  					"Content-Type": {"application/grpc"},
    91  				},
    92  			},
    93  			modrw: func(w http.ResponseWriter) http.ResponseWriter {
    94  				// Return w without its Flush method
    95  				type onlyCloseNotifier interface {
    96  					http.ResponseWriter
    97  				}
    98  				return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
    99  			},
   100  			wantErr:     "gRPC requires a ResponseWriter supporting http.Flusher",
   101  			wantErrCode: http.StatusInternalServerError,
   102  		},
   103  		{
   104  			name: "valid",
   105  			req: &http.Request{
   106  				ProtoMajor: 2,
   107  				Method:     "POST",
   108  				Header: http.Header{
   109  					"Content-Type": {"application/grpc"},
   110  				},
   111  				URL: &url.URL{
   112  					Path: "/service/foo.bar",
   113  				},
   114  			},
   115  			check: func(t *serverHandlerTransport, tt *testCase) error {
   116  				if t.req != tt.req {
   117  					return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
   118  				}
   119  				if t.rw == nil {
   120  					return errors.New("t.rw = nil; want non-nil")
   121  				}
   122  				return nil
   123  			},
   124  		},
   125  		{
   126  			name: "with timeout",
   127  			req: &http.Request{
   128  				ProtoMajor: 2,
   129  				Method:     "POST",
   130  				Header: http.Header{
   131  					"Content-Type": []string{"application/grpc"},
   132  					"Grpc-Timeout": {"200m"},
   133  				},
   134  				URL: &url.URL{
   135  					Path: "/service/foo.bar",
   136  				},
   137  			},
   138  			check: func(t *serverHandlerTransport, tt *testCase) error {
   139  				if !t.timeoutSet {
   140  					return errors.New("timeout not set")
   141  				}
   142  				if want := 200 * time.Millisecond; t.timeout != want {
   143  					return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
   144  				}
   145  				return nil
   146  			},
   147  		},
   148  		{
   149  			name: "with bad timeout",
   150  			req: &http.Request{
   151  				ProtoMajor: 2,
   152  				Method:     "POST",
   153  				Header: http.Header{
   154  					"Content-Type": []string{"application/grpc"},
   155  					"Grpc-Timeout": {"tomorrow"},
   156  				},
   157  				URL: &url.URL{
   158  					Path: "/service/foo.bar",
   159  				},
   160  			},
   161  			wantErr:     `rpc error: code = Internal desc = malformed grpc-timeout: transport: timeout unit is not recognized: "tomorrow"`,
   162  			wantErrCode: http.StatusBadRequest,
   163  		},
   164  		{
   165  			name: "with metadata",
   166  			req: &http.Request{
   167  				ProtoMajor: 2,
   168  				Method:     "POST",
   169  				Header: http.Header{
   170  					"Content-Type": []string{"application/grpc"},
   171  					"meta-foo":     {"foo-val"},
   172  					"meta-bar":     {"bar-val1", "bar-val2"},
   173  					"user-agent":   {"x/y a/b"},
   174  				},
   175  				URL: &url.URL{
   176  					Path: "/service/foo.bar",
   177  				},
   178  			},
   179  			check: func(ht *serverHandlerTransport, tt *testCase) error {
   180  				want := metadata.MD{
   181  					"meta-bar":     {"bar-val1", "bar-val2"},
   182  					"user-agent":   {"x/y a/b"},
   183  					"meta-foo":     {"foo-val"},
   184  					"content-type": {"application/grpc"},
   185  				}
   186  
   187  				if !reflect.DeepEqual(ht.headerMD, want) {
   188  					return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want)
   189  				}
   190  				return nil
   191  			},
   192  		},
   193  	}
   194  
   195  	for _, tt := range tests {
   196  		rrec := httptest.NewRecorder()
   197  		rw := http.ResponseWriter(testHandlerResponseWriter{
   198  			ResponseRecorder: rrec,
   199  		})
   200  
   201  		if tt.modrw != nil {
   202  			rw = tt.modrw(rw)
   203  		}
   204  		got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
   205  		if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
   206  			t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
   207  			continue
   208  		}
   209  		if tt.wantErrCode == 0 {
   210  			tt.wantErrCode = http.StatusOK
   211  		}
   212  		if rrec.Code != tt.wantErrCode {
   213  			t.Errorf("%s: code = %d; want %d", tt.name, rrec.Code, tt.wantErrCode)
   214  			continue
   215  		}
   216  		if gotErr != nil {
   217  			continue
   218  		}
   219  		if tt.check != nil {
   220  			if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
   221  				t.Errorf("%s: %v", tt.name, err)
   222  			}
   223  		}
   224  	}
   225  }
   226  
   227  type testHandlerResponseWriter struct {
   228  	*httptest.ResponseRecorder
   229  }
   230  
   231  func (w testHandlerResponseWriter) Flush() {}
   232  
   233  func newTestHandlerResponseWriter() http.ResponseWriter {
   234  	return testHandlerResponseWriter{
   235  		ResponseRecorder: httptest.NewRecorder(),
   236  	}
   237  }
   238  
   239  type handleStreamTest struct {
   240  	t     *testing.T
   241  	bodyw *io.PipeWriter
   242  	rw    testHandlerResponseWriter
   243  	ht    *serverHandlerTransport
   244  }
   245  
   246  func newHandleStreamTest(t *testing.T) *handleStreamTest {
   247  	bodyr, bodyw := io.Pipe()
   248  	req := &http.Request{
   249  		ProtoMajor: 2,
   250  		Method:     "POST",
   251  		Header: http.Header{
   252  			"Content-Type": {"application/grpc"},
   253  		},
   254  		URL: &url.URL{
   255  			Path: "/service/foo.bar",
   256  		},
   257  		Body: bodyr,
   258  	}
   259  	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
   260  	ht, err := NewServerHandlerTransport(rw, req, nil)
   261  	if err != nil {
   262  		t.Fatal(err)
   263  	}
   264  	return &handleStreamTest{
   265  		t:     t,
   266  		bodyw: bodyw,
   267  		ht:    ht.(*serverHandlerTransport),
   268  		rw:    rw,
   269  	}
   270  }
   271  
   272  func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
   273  	st := newHandleStreamTest(t)
   274  	handleStream := func(s *Stream) {
   275  		if want := "/service/foo.bar"; s.method != want {
   276  			t.Errorf("stream method = %q; want %q", s.method, want)
   277  		}
   278  
   279  		if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil {
   280  			t.Error(err)
   281  		}
   282  
   283  		if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
   284  			t.Error(err)
   285  		}
   286  
   287  		if err := s.SetSendCompress("gzip"); err != nil {
   288  			t.Error(err)
   289  		}
   290  
   291  		md := metadata.Pairs("custom-header", "Another custom header value")
   292  		if err := s.SendHeader(md); err != nil {
   293  			t.Error(err)
   294  		}
   295  		delete(md, "custom-header")
   296  
   297  		if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil {
   298  			t.Error("expected SetHeader call after SendHeader to fail")
   299  		}
   300  
   301  		if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil {
   302  			t.Error("expected second SendHeader call to fail")
   303  		}
   304  
   305  		if err := s.SetSendCompress("snappy"); err == nil {
   306  			t.Error("expected second SetSendCompress call to fail")
   307  		}
   308  
   309  		st.bodyw.Close() // no body
   310  		st.ht.WriteStatus(s, status.New(codes.OK, ""))
   311  	}
   312  	st.ht.HandleStreams(
   313  		context.Background(), func(s *Stream) { go handleStream(s) },
   314  	)
   315  	wantHeader := http.Header{
   316  		"Date":          nil,
   317  		"Content-Type":  {"application/grpc"},
   318  		"Trailer":       {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   319  		"Custom-Header": {"Custom header value", "Another custom header value"},
   320  		"Grpc-Encoding": {"gzip"},
   321  	}
   322  	wantTrailer := http.Header{
   323  		"Grpc-Status":    {"0"},
   324  		"Custom-Trailer": {"Custom trailer value"},
   325  	}
   326  	checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
   327  }
   328  
   329  // Tests that codes.Unimplemented will close the body, per comment in handler_server.go.
   330  func (s) TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
   331  	handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
   332  }
   333  
   334  // Tests that codes.InvalidArgument will close the body, per comment in handler_server.go.
   335  func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
   336  	handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
   337  }
   338  
   339  func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
   340  	st := newHandleStreamTest(t)
   341  
   342  	handleStream := func(s *Stream) {
   343  		st.ht.WriteStatus(s, status.New(statusCode, msg))
   344  	}
   345  	st.ht.HandleStreams(
   346  		context.Background(), func(s *Stream) { go handleStream(s) },
   347  	)
   348  	wantHeader := http.Header{
   349  		"Date":         nil,
   350  		"Content-Type": {"application/grpc"},
   351  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   352  	}
   353  	wantTrailer := http.Header{
   354  		"Grpc-Status":  {fmt.Sprint(uint32(statusCode))},
   355  		"Grpc-Message": {encodeGrpcMessage(msg)},
   356  	}
   357  	checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
   358  }
   359  
   360  func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
   361  	bodyr, bodyw := io.Pipe()
   362  	req := &http.Request{
   363  		ProtoMajor: 2,
   364  		Method:     "POST",
   365  		Header: http.Header{
   366  			"Content-Type": {"application/grpc"},
   367  			"Grpc-Timeout": {"200m"},
   368  		},
   369  		URL: &url.URL{
   370  			Path: "/service/foo.bar",
   371  		},
   372  		Body: bodyr,
   373  	}
   374  	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
   375  	ht, err := NewServerHandlerTransport(rw, req, nil)
   376  	if err != nil {
   377  		t.Fatal(err)
   378  	}
   379  	runStream := func(s *Stream) {
   380  		defer bodyw.Close()
   381  		select {
   382  		case <-s.ctx.Done():
   383  		case <-time.After(5 * time.Second):
   384  			t.Errorf("timeout waiting for ctx.Done")
   385  			return
   386  		}
   387  		err := s.ctx.Err()
   388  		if err != context.DeadlineExceeded {
   389  			t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
   390  			return
   391  		}
   392  		ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
   393  	}
   394  	ht.HandleStreams(
   395  		context.Background(), func(s *Stream) { go runStream(s) },
   396  	)
   397  	wantHeader := http.Header{
   398  		"Date":         nil,
   399  		"Content-Type": {"application/grpc"},
   400  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   401  	}
   402  	wantTrailer := http.Header{
   403  		"Grpc-Status":  {"4"},
   404  		"Grpc-Message": {encodeGrpcMessage("too slow")},
   405  	}
   406  	checkHeaderAndTrailer(t, rw, wantHeader, wantTrailer)
   407  }
   408  
   409  // TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
   410  // concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
   411  func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
   412  	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
   413  		if want := "/service/foo.bar"; s.method != want {
   414  			t.Errorf("stream method = %q; want %q", s.method, want)
   415  		}
   416  		st.bodyw.Close() // no body
   417  
   418  		var wg sync.WaitGroup
   419  		wg.Add(5)
   420  		for i := 0; i < 5; i++ {
   421  			go func() {
   422  				defer wg.Done()
   423  				st.ht.WriteStatus(s, status.New(codes.OK, ""))
   424  			}()
   425  		}
   426  		wg.Wait()
   427  	})
   428  }
   429  
   430  // TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
   431  // following "WriteStatus" does not panic writing to closed "writes" channel.
   432  func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
   433  	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
   434  		if want := "/service/foo.bar"; s.method != want {
   435  			t.Errorf("stream method = %q; want %q", s.method, want)
   436  		}
   437  		st.bodyw.Close() // no body
   438  
   439  		st.ht.WriteStatus(s, status.New(codes.OK, ""))
   440  		st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
   441  	})
   442  }
   443  
   444  func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
   445  	st := newHandleStreamTest(t)
   446  	st.ht.HandleStreams(
   447  		context.Background(), func(s *Stream) { go handleStream(st, s) },
   448  	)
   449  }
   450  
   451  func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
   452  	errDetails := []protoadapt.MessageV1{
   453  		&epb.RetryInfo{
   454  			RetryDelay: &durationpb.Duration{Seconds: 60},
   455  		},
   456  		&epb.ResourceInfo{
   457  			ResourceType: "foo bar",
   458  			ResourceName: "service.foo.bar",
   459  			Owner:        "User",
   460  		},
   461  	}
   462  
   463  	statusCode := codes.ResourceExhausted
   464  	msg := "you are being throttled"
   465  	st, err := status.New(statusCode, msg).WithDetails(errDetails...)
   466  	if err != nil {
   467  		t.Fatal(err)
   468  	}
   469  
   470  	stBytes, err := proto.Marshal(st.Proto())
   471  	if err != nil {
   472  		t.Fatal(err)
   473  	}
   474  
   475  	hst := newHandleStreamTest(t)
   476  	handleStream := func(s *Stream) {
   477  		hst.ht.WriteStatus(s, st)
   478  	}
   479  	hst.ht.HandleStreams(
   480  		context.Background(), func(s *Stream) { go handleStream(s) },
   481  	)
   482  	wantHeader := http.Header{
   483  		"Date":         nil,
   484  		"Content-Type": {"application/grpc"},
   485  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   486  	}
   487  	wantTrailer := http.Header{
   488  		"Grpc-Status":             {fmt.Sprint(uint32(statusCode))},
   489  		"Grpc-Message":            {encodeGrpcMessage(msg)},
   490  		"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
   491  	}
   492  
   493  	checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
   494  }
   495  
   496  // TestHandlerTransport_Drain verifies that Drain() is not implemented
   497  // by `serverHandlerTransport`.
   498  func (s) TestHandlerTransport_Drain(t *testing.T) {
   499  	defer func() { recover() }()
   500  	st := newHandleStreamTest(t)
   501  	st.ht.Drain("whatever")
   502  	t.Errorf("serverHandlerTransport.Drain() should have panicked")
   503  }
   504  
   505  // checkHeaderAndTrailer checks that the resulting header and trailer matches the expectation.
   506  func checkHeaderAndTrailer(t *testing.T, rw testHandlerResponseWriter, wantHeader, wantTrailer http.Header) {
   507  	// For trailer-only responses, the trailer values might be reported as part of the Header. They will however
   508  	// be present in Trailer in either case. Hence, normalize the header by removing all trailer values.
   509  	actualHeader := rw.Result().Header.Clone()
   510  	for _, trailerKey := range actualHeader["Trailer"] {
   511  		actualHeader.Del(trailerKey)
   512  	}
   513  
   514  	if !reflect.DeepEqual(actualHeader, wantHeader) {
   515  		t.Errorf("Header mismatch.\n got: %#v\n want: %#v", actualHeader, wantHeader)
   516  	}
   517  	if actualTrailer := rw.Result().Trailer; !reflect.DeepEqual(actualTrailer, wantTrailer) {
   518  		t.Errorf("Trailer mismatch.\n got: %#v\n want: %#v", actualTrailer, wantTrailer)
   519  	}
   520  }