google.golang.org/grpc@v1.72.2/internal/transport/handler_server_test.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package transport
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"net/url"
    29  	"reflect"
    30  	"sync"
    31  	"testing"
    32  	"time"
    33  
    34  	epb "google.golang.org/genproto/googleapis/rpc/errdetails"
    35  	"google.golang.org/grpc/codes"
    36  	"google.golang.org/grpc/mem"
    37  	"google.golang.org/grpc/metadata"
    38  	"google.golang.org/grpc/status"
    39  	"google.golang.org/protobuf/proto"
    40  	"google.golang.org/protobuf/protoadapt"
    41  	"google.golang.org/protobuf/types/known/durationpb"
    42  )
    43  
    44  func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
    45  	type testCase struct {
    46  		name        string
    47  		req         *http.Request
    48  		wantErr     string
    49  		wantErrCode int
    50  		modrw       func(http.ResponseWriter) http.ResponseWriter
    51  		check       func(*serverHandlerTransport, *testCase) error
    52  	}
    53  	tests := []testCase{
    54  		{
    55  			name: "bad method",
    56  			req: &http.Request{
    57  				ProtoMajor: 2,
    58  				Method:     "GET",
    59  				Header:     http.Header{},
    60  			},
    61  			wantErr:     `invalid gRPC request method "GET"`,
    62  			wantErrCode: http.StatusMethodNotAllowed,
    63  		},
    64  		{
    65  			name: "bad content type",
    66  			req: &http.Request{
    67  				ProtoMajor: 2,
    68  				Method:     "POST",
    69  				Header: http.Header{
    70  					"Content-Type": {"application/foo"},
    71  				},
    72  			},
    73  			wantErr:     `invalid gRPC request content-type "application/foo"`,
    74  			wantErrCode: http.StatusUnsupportedMediaType,
    75  		},
    76  		{
    77  			name: "http/1.1",
    78  			req: &http.Request{
    79  				ProtoMajor: 1,
    80  				ProtoMinor: 1,
    81  				Method:     "POST",
    82  				Header:     http.Header{"Content-Type": []string{"application/grpc"}},
    83  			},
    84  			wantErr:     "gRPC requires HTTP/2",
    85  			wantErrCode: http.StatusHTTPVersionNotSupported,
    86  		},
    87  		{
    88  			name: "not flusher",
    89  			req: &http.Request{
    90  				ProtoMajor: 2,
    91  				Method:     "POST",
    92  				Header: http.Header{
    93  					"Content-Type": {"application/grpc"},
    94  				},
    95  			},
    96  			modrw: func(w http.ResponseWriter) http.ResponseWriter {
    97  				// Return w without its Flush method
    98  				type onlyCloseNotifier interface {
    99  					http.ResponseWriter
   100  				}
   101  				return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
   102  			},
   103  			wantErr:     "gRPC requires a ResponseWriter supporting http.Flusher",
   104  			wantErrCode: http.StatusInternalServerError,
   105  		},
   106  		{
   107  			name: "valid",
   108  			req: &http.Request{
   109  				ProtoMajor: 2,
   110  				Method:     "POST",
   111  				Header: http.Header{
   112  					"Content-Type": {"application/grpc"},
   113  				},
   114  				URL: &url.URL{
   115  					Path: "/service/foo.bar",
   116  				},
   117  			},
   118  			check: func(t *serverHandlerTransport, tt *testCase) error {
   119  				if t.req != tt.req {
   120  					return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
   121  				}
   122  				if t.rw == nil {
   123  					return errors.New("t.rw = nil; want non-nil")
   124  				}
   125  				return nil
   126  			},
   127  		},
   128  		{
   129  			name: "with timeout",
   130  			req: &http.Request{
   131  				ProtoMajor: 2,
   132  				Method:     "POST",
   133  				Header: http.Header{
   134  					"Content-Type": []string{"application/grpc"},
   135  					"Grpc-Timeout": {"200m"},
   136  				},
   137  				URL: &url.URL{
   138  					Path: "/service/foo.bar",
   139  				},
   140  			},
   141  			check: func(t *serverHandlerTransport, _ *testCase) error {
   142  				if !t.timeoutSet {
   143  					return errors.New("timeout not set")
   144  				}
   145  				if want := 200 * time.Millisecond; t.timeout != want {
   146  					return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
   147  				}
   148  				return nil
   149  			},
   150  		},
   151  		{
   152  			name: "with bad timeout",
   153  			req: &http.Request{
   154  				ProtoMajor: 2,
   155  				Method:     "POST",
   156  				Header: http.Header{
   157  					"Content-Type": []string{"application/grpc"},
   158  					"Grpc-Timeout": {"tomorrow"},
   159  				},
   160  				URL: &url.URL{
   161  					Path: "/service/foo.bar",
   162  				},
   163  			},
   164  			wantErr:     `rpc error: code = Internal desc = malformed grpc-timeout: transport: timeout unit is not recognized: "tomorrow"`,
   165  			wantErrCode: http.StatusBadRequest,
   166  		},
   167  		{
   168  			name: "with metadata",
   169  			req: &http.Request{
   170  				ProtoMajor: 2,
   171  				Method:     "POST",
   172  				Header: http.Header{
   173  					"Content-Type": []string{"application/grpc"},
   174  					"meta-foo":     {"foo-val"},
   175  					"meta-bar":     {"bar-val1", "bar-val2"},
   176  					"user-agent":   {"x/y a/b"},
   177  				},
   178  				URL: &url.URL{
   179  					Path: "/service/foo.bar",
   180  				},
   181  			},
   182  			check: func(ht *serverHandlerTransport, _ *testCase) error {
   183  				want := metadata.MD{
   184  					"meta-bar":     {"bar-val1", "bar-val2"},
   185  					"user-agent":   {"x/y a/b"},
   186  					"meta-foo":     {"foo-val"},
   187  					"content-type": {"application/grpc"},
   188  				}
   189  
   190  				if !reflect.DeepEqual(ht.headerMD, want) {
   191  					return fmt.Errorf("metadata = %#v; want %#v", ht.headerMD, want)
   192  				}
   193  				return nil
   194  			},
   195  		},
   196  	}
   197  
   198  	for _, tt := range tests {
   199  		rrec := httptest.NewRecorder()
   200  		rw := http.ResponseWriter(testHandlerResponseWriter{
   201  			ResponseRecorder: rrec,
   202  		})
   203  
   204  		if tt.modrw != nil {
   205  			rw = tt.modrw(rw)
   206  		}
   207  		got, gotErr := NewServerHandlerTransport(rw, tt.req, nil, mem.DefaultBufferPool())
   208  		if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
   209  			t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
   210  			continue
   211  		}
   212  		if tt.wantErrCode == 0 {
   213  			tt.wantErrCode = http.StatusOK
   214  		}
   215  		if rrec.Code != tt.wantErrCode {
   216  			t.Errorf("%s: code = %d; want %d", tt.name, rrec.Code, tt.wantErrCode)
   217  			continue
   218  		}
   219  		if gotErr != nil {
   220  			continue
   221  		}
   222  		if tt.check != nil {
   223  			if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
   224  				t.Errorf("%s: %v", tt.name, err)
   225  			}
   226  		}
   227  	}
   228  }
   229  
   230  type testHandlerResponseWriter struct {
   231  	*httptest.ResponseRecorder
   232  }
   233  
   234  func (w testHandlerResponseWriter) Flush() {}
   235  
   236  func newTestHandlerResponseWriter() http.ResponseWriter {
   237  	return testHandlerResponseWriter{
   238  		ResponseRecorder: httptest.NewRecorder(),
   239  	}
   240  }
   241  
   242  type handleStreamTest struct {
   243  	t     *testing.T
   244  	bodyw *io.PipeWriter
   245  	rw    testHandlerResponseWriter
   246  	ht    *serverHandlerTransport
   247  }
   248  
   249  func newHandleStreamTest(t *testing.T) *handleStreamTest {
   250  	bodyr, bodyw := io.Pipe()
   251  	req := &http.Request{
   252  		ProtoMajor: 2,
   253  		Method:     "POST",
   254  		Header: http.Header{
   255  			"Content-Type": {"application/grpc"},
   256  		},
   257  		URL: &url.URL{
   258  			Path: "/service/foo.bar",
   259  		},
   260  		Body: bodyr,
   261  	}
   262  	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
   263  	ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
   264  	if err != nil {
   265  		t.Fatal(err)
   266  	}
   267  	return &handleStreamTest{
   268  		t:     t,
   269  		bodyw: bodyw,
   270  		ht:    ht.(*serverHandlerTransport),
   271  		rw:    rw,
   272  	}
   273  }
   274  
   275  func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
   276  	st := newHandleStreamTest(t)
   277  	handleStream := func(s *ServerStream) {
   278  		if want := "/service/foo.bar"; s.method != want {
   279  			t.Errorf("stream method = %q; want %q", s.method, want)
   280  		}
   281  
   282  		if err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value")); err != nil {
   283  			t.Error(err)
   284  		}
   285  
   286  		if err := s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value")); err != nil {
   287  			t.Error(err)
   288  		}
   289  
   290  		if err := s.SetSendCompress("gzip"); err != nil {
   291  			t.Error(err)
   292  		}
   293  
   294  		md := metadata.Pairs("custom-header", "Another custom header value")
   295  		if err := s.SendHeader(md); err != nil {
   296  			t.Error(err)
   297  		}
   298  		delete(md, "custom-header")
   299  
   300  		if err := s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored")); err == nil {
   301  			t.Error("expected SetHeader call after SendHeader to fail")
   302  		}
   303  
   304  		if err := s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well")); err == nil {
   305  			t.Error("expected second SendHeader call to fail")
   306  		}
   307  
   308  		if err := s.SetSendCompress("snappy"); err == nil {
   309  			t.Error("expected second SetSendCompress call to fail")
   310  		}
   311  
   312  		st.bodyw.Close() // no body
   313  		s.WriteStatus(status.New(codes.OK, ""))
   314  	}
   315  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   316  	defer cancel()
   317  	st.ht.HandleStreams(
   318  		ctx, func(s *ServerStream) { go handleStream(s) },
   319  	)
   320  	wantHeader := http.Header{
   321  		"Date":          nil,
   322  		"Content-Type":  {"application/grpc"},
   323  		"Trailer":       {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   324  		"Custom-Header": {"Custom header value", "Another custom header value"},
   325  		"Grpc-Encoding": {"gzip"},
   326  	}
   327  	wantTrailer := http.Header{
   328  		"Grpc-Status":    {"0"},
   329  		"Custom-Trailer": {"Custom trailer value"},
   330  	}
   331  	checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
   332  }
   333  
   334  // Tests that codes.Unimplemented will close the body, per comment in handler_server.go.
   335  func (s) TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
   336  	handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
   337  }
   338  
   339  // Tests that codes.InvalidArgument will close the body, per comment in handler_server.go.
   340  func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
   341  	handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
   342  }
   343  
   344  func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
   345  	st := newHandleStreamTest(t)
   346  
   347  	handleStream := func(s *ServerStream) {
   348  		s.WriteStatus(status.New(statusCode, msg))
   349  	}
   350  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   351  	defer cancel()
   352  	st.ht.HandleStreams(
   353  		ctx, func(s *ServerStream) { go handleStream(s) },
   354  	)
   355  	wantHeader := http.Header{
   356  		"Date":         nil,
   357  		"Content-Type": {"application/grpc"},
   358  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   359  	}
   360  	wantTrailer := http.Header{
   361  		"Grpc-Status":  {fmt.Sprint(uint32(statusCode))},
   362  		"Grpc-Message": {encodeGrpcMessage(msg)},
   363  	}
   364  	checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
   365  }
   366  
   367  func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
   368  	bodyr, bodyw := io.Pipe()
   369  	req := &http.Request{
   370  		ProtoMajor: 2,
   371  		Method:     "POST",
   372  		Header: http.Header{
   373  			"Content-Type": {"application/grpc"},
   374  			"Grpc-Timeout": {"200m"},
   375  		},
   376  		URL: &url.URL{
   377  			Path: "/service/foo.bar",
   378  		},
   379  		Body: bodyr,
   380  	}
   381  	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
   382  	ht, err := NewServerHandlerTransport(rw, req, nil, mem.DefaultBufferPool())
   383  	if err != nil {
   384  		t.Fatal(err)
   385  	}
   386  	runStream := func(s *ServerStream) {
   387  		defer bodyw.Close()
   388  		select {
   389  		case <-s.ctx.Done():
   390  		case <-time.After(5 * time.Second):
   391  			t.Errorf("timeout waiting for ctx.Done")
   392  			return
   393  		}
   394  		err := s.ctx.Err()
   395  		if err != context.DeadlineExceeded {
   396  			t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
   397  			return
   398  		}
   399  		s.WriteStatus(status.New(codes.DeadlineExceeded, "too slow"))
   400  	}
   401  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   402  	defer cancel()
   403  	ht.HandleStreams(
   404  		ctx, func(s *ServerStream) { go runStream(s) },
   405  	)
   406  	wantHeader := http.Header{
   407  		"Date":         nil,
   408  		"Content-Type": {"application/grpc"},
   409  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   410  	}
   411  	wantTrailer := http.Header{
   412  		"Grpc-Status":  {"4"},
   413  		"Grpc-Message": {encodeGrpcMessage("too slow")},
   414  	}
   415  	checkHeaderAndTrailer(t, rw, wantHeader, wantTrailer)
   416  }
   417  
   418  // TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
   419  // concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
   420  func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
   421  	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) {
   422  		if want := "/service/foo.bar"; s.method != want {
   423  			t.Errorf("stream method = %q; want %q", s.method, want)
   424  		}
   425  		st.bodyw.Close() // no body
   426  
   427  		var wg sync.WaitGroup
   428  		wg.Add(5)
   429  		for i := 0; i < 5; i++ {
   430  			go func() {
   431  				defer wg.Done()
   432  				s.WriteStatus(status.New(codes.OK, ""))
   433  			}()
   434  		}
   435  		wg.Wait()
   436  	})
   437  }
   438  
   439  // TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
   440  // following "WriteStatus" does not panic writing to closed "writes" channel.
   441  func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
   442  	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *ServerStream) {
   443  		if want := "/service/foo.bar"; s.method != want {
   444  			t.Errorf("stream method = %q; want %q", s.method, want)
   445  		}
   446  		st.bodyw.Close() // no body
   447  
   448  		s.WriteStatus(status.New(codes.OK, ""))
   449  		s.Write([]byte("hdr"), newBufferSlice([]byte("data")), &WriteOptions{})
   450  	})
   451  }
   452  
   453  func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *ServerStream)) {
   454  	st := newHandleStreamTest(t)
   455  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   456  	t.Cleanup(cancel)
   457  	st.ht.HandleStreams(
   458  		ctx, func(s *ServerStream) { go handleStream(st, s) },
   459  	)
   460  }
   461  
   462  func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
   463  	errDetails := []protoadapt.MessageV1{
   464  		&epb.RetryInfo{
   465  			RetryDelay: &durationpb.Duration{Seconds: 60},
   466  		},
   467  		&epb.ResourceInfo{
   468  			ResourceType: "foo bar",
   469  			ResourceName: "service.foo.bar",
   470  			Owner:        "User",
   471  		},
   472  	}
   473  
   474  	statusCode := codes.ResourceExhausted
   475  	msg := "you are being throttled"
   476  	st, err := status.New(statusCode, msg).WithDetails(errDetails...)
   477  	if err != nil {
   478  		t.Fatal(err)
   479  	}
   480  
   481  	stBytes, err := proto.Marshal(st.Proto())
   482  	if err != nil {
   483  		t.Fatal(err)
   484  	}
   485  
   486  	hst := newHandleStreamTest(t)
   487  	handleStream := func(s *ServerStream) {
   488  		s.WriteStatus(st)
   489  	}
   490  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   491  	defer cancel()
   492  	hst.ht.HandleStreams(
   493  		ctx, func(s *ServerStream) { go handleStream(s) },
   494  	)
   495  	wantHeader := http.Header{
   496  		"Date":         nil,
   497  		"Content-Type": {"application/grpc"},
   498  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   499  	}
   500  	wantTrailer := http.Header{
   501  		"Grpc-Status":             {fmt.Sprint(uint32(statusCode))},
   502  		"Grpc-Message":            {encodeGrpcMessage(msg)},
   503  		"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
   504  	}
   505  
   506  	checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
   507  }
   508  
   509  // TestHandlerTransport_Drain verifies that Drain() is not implemented
   510  // by `serverHandlerTransport`.
   511  func (s) TestHandlerTransport_Drain(t *testing.T) {
   512  	defer func() { recover() }()
   513  	st := newHandleStreamTest(t)
   514  	st.ht.Drain("whatever")
   515  	t.Errorf("serverHandlerTransport.Drain() should have panicked")
   516  }
   517  
   518  // checkHeaderAndTrailer checks that the resulting header and trailer matches the expectation.
   519  func checkHeaderAndTrailer(t *testing.T, rw testHandlerResponseWriter, wantHeader, wantTrailer http.Header) {
   520  	// For trailer-only responses, the trailer values might be reported as part of the Header. They will however
   521  	// be present in Trailer in either case. Hence, normalize the header by removing all trailer values.
   522  	actualHeader := rw.Result().Header.Clone()
   523  	for _, trailerKey := range actualHeader["Trailer"] {
   524  		actualHeader.Del(trailerKey)
   525  	}
   526  
   527  	if !reflect.DeepEqual(actualHeader, wantHeader) {
   528  		t.Errorf("Header mismatch.\n got: %#v\n want: %#v", actualHeader, wantHeader)
   529  	}
   530  	if actualTrailer := rw.Result().Trailer; !reflect.DeepEqual(actualTrailer, wantTrailer) {
   531  		t.Errorf("Trailer mismatch.\n got: %#v\n want: %#v", actualTrailer, wantTrailer)
   532  	}
   533  }