github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/internal/transport/handler_server_test.go (about)

     1  /*
     2   *
     3   * Copyright 2016 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package transport
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net/url"
    27  	"reflect"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/hxx258456/ccgo/gmhttp/httptest"
    33  
    34  	http "github.com/hxx258456/ccgo/gmhttp"
    35  
    36  	"github.com/golang/protobuf/proto"
    37  	dpb "github.com/golang/protobuf/ptypes/duration"
    38  	"github.com/hxx258456/ccgo/grpc/codes"
    39  	"github.com/hxx258456/ccgo/grpc/metadata"
    40  	"github.com/hxx258456/ccgo/grpc/status"
    41  	epb "google.golang.org/genproto/googleapis/rpc/errdetails"
    42  )
    43  
    44  func (s) TestHandlerTransport_NewServerHandlerTransport(t *testing.T) {
    45  	type testCase struct {
    46  		name    string
    47  		req     *http.Request
    48  		wantErr string
    49  		modrw   func(http.ResponseWriter) http.ResponseWriter
    50  		check   func(*serverHandlerTransport, *testCase) error
    51  	}
    52  	tests := []testCase{
    53  		{
    54  			name: "http/1.1",
    55  			req: &http.Request{
    56  				ProtoMajor: 1,
    57  				ProtoMinor: 1,
    58  			},
    59  			wantErr: "gRPC requires HTTP/2",
    60  		},
    61  		{
    62  			name: "bad method",
    63  			req: &http.Request{
    64  				ProtoMajor: 2,
    65  				Method:     "GET",
    66  				Header:     http.Header{},
    67  			},
    68  			wantErr: "invalid gRPC request method",
    69  		},
    70  		{
    71  			name: "bad content type",
    72  			req: &http.Request{
    73  				ProtoMajor: 2,
    74  				Method:     "POST",
    75  				Header: http.Header{
    76  					"Content-Type": {"application/foo"},
    77  				},
    78  			},
    79  			wantErr: "invalid gRPC request content-type",
    80  		},
    81  		{
    82  			name: "not flusher",
    83  			req: &http.Request{
    84  				ProtoMajor: 2,
    85  				Method:     "POST",
    86  				Header: http.Header{
    87  					"Content-Type": {"application/grpc"},
    88  				},
    89  			},
    90  			modrw: func(w http.ResponseWriter) http.ResponseWriter {
    91  				// Return w without its Flush method
    92  				type onlyCloseNotifier interface {
    93  					http.ResponseWriter
    94  					http.CloseNotifier
    95  				}
    96  				return struct{ onlyCloseNotifier }{w.(onlyCloseNotifier)}
    97  			},
    98  			wantErr: "gRPC requires a ResponseWriter supporting http.Flusher",
    99  		},
   100  		{
   101  			name: "valid",
   102  			req: &http.Request{
   103  				ProtoMajor: 2,
   104  				Method:     "POST",
   105  				Header: http.Header{
   106  					"Content-Type": {"application/grpc"},
   107  				},
   108  				URL: &url.URL{
   109  					Path: "/service/foo.bar",
   110  				},
   111  			},
   112  			check: func(t *serverHandlerTransport, tt *testCase) error {
   113  				if t.req != tt.req {
   114  					return fmt.Errorf("t.req = %p; want %p", t.req, tt.req)
   115  				}
   116  				if t.rw == nil {
   117  					return errors.New("t.rw = nil; want non-nil")
   118  				}
   119  				return nil
   120  			},
   121  		},
   122  		{
   123  			name: "with timeout",
   124  			req: &http.Request{
   125  				ProtoMajor: 2,
   126  				Method:     "POST",
   127  				Header: http.Header{
   128  					"Content-Type": []string{"application/grpc"},
   129  					"Grpc-Timeout": {"200m"},
   130  				},
   131  				URL: &url.URL{
   132  					Path: "/service/foo.bar",
   133  				},
   134  			},
   135  			check: func(t *serverHandlerTransport, tt *testCase) error {
   136  				if !t.timeoutSet {
   137  					return errors.New("timeout not set")
   138  				}
   139  				if want := 200 * time.Millisecond; t.timeout != want {
   140  					return fmt.Errorf("timeout = %v; want %v", t.timeout, want)
   141  				}
   142  				return nil
   143  			},
   144  		},
   145  		{
   146  			name: "with bad timeout",
   147  			req: &http.Request{
   148  				ProtoMajor: 2,
   149  				Method:     "POST",
   150  				Header: http.Header{
   151  					"Content-Type": []string{"application/grpc"},
   152  					"Grpc-Timeout": {"tomorrow"},
   153  				},
   154  				URL: &url.URL{
   155  					Path: "/service/foo.bar",
   156  				},
   157  			},
   158  			wantErr: `rpc error: code = Internal desc = malformed time-out: transport: timeout unit is not recognized: "tomorrow"`,
   159  		},
   160  		{
   161  			name: "with metadata",
   162  			req: &http.Request{
   163  				ProtoMajor: 2,
   164  				Method:     "POST",
   165  				Header: http.Header{
   166  					"Content-Type": []string{"application/grpc"},
   167  					"meta-foo":     {"foo-val"},
   168  					"meta-bar":     {"bar-val1", "bar-val2"},
   169  					"user-agent":   {"x/y a/b"},
   170  				},
   171  				URL: &url.URL{
   172  					Path: "/service/foo.bar",
   173  				},
   174  			},
   175  			check: func(ht *serverHandlerTransport, tt *testCase) error {
   176  				want := metadata.MD{
   177  					"meta-bar":     {"bar-val1", "bar-val2"},
   178  					"user-agent":   {"x/y a/b"},
   179  					"meta-foo":     {"foo-val"},
   180  					"content-type": {"application/grpc"},
   181  				}
   182  
   183  				if !reflect.DeepEqual(ht.headerMD, want) {
   184  					return fmt.Errorf("metdata = %#v; want %#v", ht.headerMD, want)
   185  				}
   186  				return nil
   187  			},
   188  		},
   189  	}
   190  
   191  	for _, tt := range tests {
   192  		rw := newTestHandlerResponseWriter()
   193  		if tt.modrw != nil {
   194  			rw = tt.modrw(rw)
   195  		}
   196  		got, gotErr := NewServerHandlerTransport(rw, tt.req, nil)
   197  		if (gotErr != nil) != (tt.wantErr != "") || (gotErr != nil && gotErr.Error() != tt.wantErr) {
   198  			t.Errorf("%s: error = %q; want %q", tt.name, gotErr.Error(), tt.wantErr)
   199  			continue
   200  		}
   201  		if gotErr != nil {
   202  			continue
   203  		}
   204  		if tt.check != nil {
   205  			if err := tt.check(got.(*serverHandlerTransport), &tt); err != nil {
   206  				t.Errorf("%s: %v", tt.name, err)
   207  			}
   208  		}
   209  	}
   210  }
   211  
   212  type testHandlerResponseWriter struct {
   213  	*httptest.ResponseRecorder
   214  	closeNotify chan bool
   215  }
   216  
   217  func (w testHandlerResponseWriter) CloseNotify() <-chan bool { return w.closeNotify }
   218  func (w testHandlerResponseWriter) Flush()                   {}
   219  
   220  func newTestHandlerResponseWriter() http.ResponseWriter {
   221  	return testHandlerResponseWriter{
   222  		ResponseRecorder: httptest.NewRecorder(),
   223  		closeNotify:      make(chan bool, 1),
   224  	}
   225  }
   226  
   227  type handleStreamTest struct {
   228  	t     *testing.T
   229  	bodyw *io.PipeWriter
   230  	rw    testHandlerResponseWriter
   231  	ht    *serverHandlerTransport
   232  }
   233  
   234  func newHandleStreamTest(t *testing.T) *handleStreamTest {
   235  	bodyr, bodyw := io.Pipe()
   236  	req := &http.Request{
   237  		ProtoMajor: 2,
   238  		Method:     "POST",
   239  		Header: http.Header{
   240  			"Content-Type": {"application/grpc"},
   241  		},
   242  		URL: &url.URL{
   243  			Path: "/service/foo.bar",
   244  		},
   245  		Body: bodyr,
   246  	}
   247  	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
   248  	ht, err := NewServerHandlerTransport(rw, req, nil)
   249  	if err != nil {
   250  		t.Fatal(err)
   251  	}
   252  	return &handleStreamTest{
   253  		t:     t,
   254  		bodyw: bodyw,
   255  		ht:    ht.(*serverHandlerTransport),
   256  		rw:    rw,
   257  	}
   258  }
   259  
   260  func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
   261  	st := newHandleStreamTest(t)
   262  	handleStream := func(s *Stream) {
   263  		if want := "/service/foo.bar"; s.method != want {
   264  			t.Errorf("stream method = %q; want %q", s.method, want)
   265  		}
   266  
   267  		err := s.SetHeader(metadata.Pairs("custom-header", "Custom header value"))
   268  		if err != nil {
   269  			t.Error(err)
   270  		}
   271  		err = s.SetTrailer(metadata.Pairs("custom-trailer", "Custom trailer value"))
   272  		if err != nil {
   273  			t.Error(err)
   274  		}
   275  
   276  		md := metadata.Pairs("custom-header", "Another custom header value")
   277  		err = s.SendHeader(md)
   278  		delete(md, "custom-header")
   279  		if err != nil {
   280  			t.Error(err)
   281  		}
   282  
   283  		err = s.SetHeader(metadata.Pairs("too-late", "Header value that should be ignored"))
   284  		if err == nil {
   285  			t.Error("expected SetHeader call after SendHeader to fail")
   286  		}
   287  		err = s.SendHeader(metadata.Pairs("too-late", "This header value should be ignored as well"))
   288  		if err == nil {
   289  			t.Error("expected second SendHeader call to fail")
   290  		}
   291  
   292  		st.bodyw.Close() // no body
   293  		st.ht.WriteStatus(s, status.New(codes.OK, ""))
   294  	}
   295  	st.ht.HandleStreams(
   296  		func(s *Stream) { go handleStream(s) },
   297  		func(ctx context.Context, method string) context.Context { return ctx },
   298  	)
   299  	wantHeader := http.Header{
   300  		"Date":          {},
   301  		"Content-Type":  {"application/grpc"},
   302  		"Trailer":       {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   303  		"Custom-Header": {"Custom header value", "Another custom header value"},
   304  	}
   305  	wantTrailer := http.Header{
   306  		"Grpc-Status":    {"0"},
   307  		"Custom-Trailer": {"Custom trailer value"},
   308  	}
   309  	checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
   310  }
   311  
   312  // Tests that codes.Unimplemented will close the body, per comment in handler_server.go.
   313  func (s) TestHandlerTransport_HandleStreams_Unimplemented(t *testing.T) {
   314  	handleStreamCloseBodyTest(t, codes.Unimplemented, "thingy is unimplemented")
   315  }
   316  
   317  // Tests that codes.InvalidArgument will close the body, per comment in handler_server.go.
   318  func (s) TestHandlerTransport_HandleStreams_InvalidArgument(t *testing.T) {
   319  	handleStreamCloseBodyTest(t, codes.InvalidArgument, "bad arg")
   320  }
   321  
   322  func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string) {
   323  	st := newHandleStreamTest(t)
   324  
   325  	handleStream := func(s *Stream) {
   326  		st.ht.WriteStatus(s, status.New(statusCode, msg))
   327  	}
   328  	st.ht.HandleStreams(
   329  		func(s *Stream) { go handleStream(s) },
   330  		func(ctx context.Context, method string) context.Context { return ctx },
   331  	)
   332  	wantHeader := http.Header{
   333  		"Date":         {},
   334  		"Content-Type": {"application/grpc"},
   335  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   336  	}
   337  	wantTrailer := http.Header{
   338  		"Grpc-Status":  {fmt.Sprint(uint32(statusCode))},
   339  		"Grpc-Message": {encodeGrpcMessage(msg)},
   340  	}
   341  	checkHeaderAndTrailer(t, st.rw, wantHeader, wantTrailer)
   342  }
   343  
   344  func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
   345  	bodyr, bodyw := io.Pipe()
   346  	req := &http.Request{
   347  		ProtoMajor: 2,
   348  		Method:     "POST",
   349  		Header: http.Header{
   350  			"Content-Type": {"application/grpc"},
   351  			"Grpc-Timeout": {"200m"},
   352  		},
   353  		URL: &url.URL{
   354  			Path: "/service/foo.bar",
   355  		},
   356  		Body: bodyr,
   357  	}
   358  	rw := newTestHandlerResponseWriter().(testHandlerResponseWriter)
   359  	ht, err := NewServerHandlerTransport(rw, req, nil)
   360  	if err != nil {
   361  		t.Fatal(err)
   362  	}
   363  	runStream := func(s *Stream) {
   364  		defer bodyw.Close()
   365  		select {
   366  		case <-s.ctx.Done():
   367  		case <-time.After(5 * time.Second):
   368  			t.Errorf("timeout waiting for ctx.Done")
   369  			return
   370  		}
   371  		err := s.ctx.Err()
   372  		if err != context.DeadlineExceeded {
   373  			t.Errorf("ctx.Err = %v; want %v", err, context.DeadlineExceeded)
   374  			return
   375  		}
   376  		ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
   377  	}
   378  	ht.HandleStreams(
   379  		func(s *Stream) { go runStream(s) },
   380  		func(ctx context.Context, method string) context.Context { return ctx },
   381  	)
   382  	wantHeader := http.Header{
   383  		"Date":         {},
   384  		"Content-Type": {"application/grpc"},
   385  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   386  	}
   387  	wantTrailer := http.Header{
   388  		"Grpc-Status":  {"4"},
   389  		"Grpc-Message": {encodeGrpcMessage("too slow")},
   390  	}
   391  	checkHeaderAndTrailer(t, rw, wantHeader, wantTrailer)
   392  }
   393  
   394  // TestHandlerTransport_HandleStreams_MultiWriteStatus ensures that
   395  // concurrent "WriteStatus"s do not panic writing to closed "writes" channel.
   396  func (s) TestHandlerTransport_HandleStreams_MultiWriteStatus(t *testing.T) {
   397  	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
   398  		if want := "/service/foo.bar"; s.method != want {
   399  			t.Errorf("stream method = %q; want %q", s.method, want)
   400  		}
   401  		st.bodyw.Close() // no body
   402  
   403  		var wg sync.WaitGroup
   404  		wg.Add(5)
   405  		for i := 0; i < 5; i++ {
   406  			go func() {
   407  				defer wg.Done()
   408  				st.ht.WriteStatus(s, status.New(codes.OK, ""))
   409  			}()
   410  		}
   411  		wg.Wait()
   412  	})
   413  }
   414  
   415  // TestHandlerTransport_HandleStreams_WriteStatusWrite ensures that "Write"
   416  // following "WriteStatus" does not panic writing to closed "writes" channel.
   417  func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
   418  	testHandlerTransportHandleStreams(t, func(st *handleStreamTest, s *Stream) {
   419  		if want := "/service/foo.bar"; s.method != want {
   420  			t.Errorf("stream method = %q; want %q", s.method, want)
   421  		}
   422  		st.bodyw.Close() // no body
   423  
   424  		st.ht.WriteStatus(s, status.New(codes.OK, ""))
   425  		st.ht.Write(s, []byte("hdr"), []byte("data"), &Options{})
   426  	})
   427  }
   428  
   429  func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
   430  	st := newHandleStreamTest(t)
   431  	st.ht.HandleStreams(
   432  		func(s *Stream) { go handleStream(st, s) },
   433  		func(ctx context.Context, method string) context.Context { return ctx },
   434  	)
   435  }
   436  
   437  func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
   438  	errDetails := []proto.Message{
   439  		&epb.RetryInfo{
   440  			RetryDelay: &dpb.Duration{Seconds: 60},
   441  		},
   442  		&epb.ResourceInfo{
   443  			ResourceType: "foo bar",
   444  			ResourceName: "service.foo.bar",
   445  			Owner:        "User",
   446  		},
   447  	}
   448  
   449  	statusCode := codes.ResourceExhausted
   450  	msg := "you are being throttled"
   451  	st, err := status.New(statusCode, msg).WithDetails(errDetails...)
   452  	if err != nil {
   453  		t.Fatal(err)
   454  	}
   455  
   456  	stBytes, err := proto.Marshal(st.Proto())
   457  	if err != nil {
   458  		t.Fatal(err)
   459  	}
   460  
   461  	hst := newHandleStreamTest(t)
   462  	handleStream := func(s *Stream) {
   463  		hst.ht.WriteStatus(s, st)
   464  	}
   465  	hst.ht.HandleStreams(
   466  		func(s *Stream) { go handleStream(s) },
   467  		func(ctx context.Context, method string) context.Context { return ctx },
   468  	)
   469  	wantHeader := http.Header{
   470  		"Date":         {},
   471  		"Content-Type": {"application/grpc"},
   472  		"Trailer":      {"Grpc-Status", "Grpc-Message", "Grpc-Status-Details-Bin"},
   473  	}
   474  	wantTrailer := http.Header{
   475  		"Grpc-Status":             {fmt.Sprint(uint32(statusCode))},
   476  		"Grpc-Message":            {encodeGrpcMessage(msg)},
   477  		"Grpc-Status-Details-Bin": {encodeBinHeader(stBytes)},
   478  	}
   479  
   480  	checkHeaderAndTrailer(t, hst.rw, wantHeader, wantTrailer)
   481  }
   482  
   483  // checkHeaderAndTrailer checks that the resulting header and trailer matches the expectation.
   484  func checkHeaderAndTrailer(t *testing.T, rw testHandlerResponseWriter, wantHeader, wantTrailer http.Header) {
   485  	// For trailer-only responses, the trailer values might be reported as part of the Header. They will however
   486  	// be present in Trailer in either case. Hence, normalize the header by removing all trailer values.
   487  	actualHeader := cloneHeader(rw.Result().Header)
   488  	for _, trailerKey := range actualHeader["Trailer"] {
   489  		actualHeader.Del(trailerKey)
   490  	}
   491  
   492  	if !reflect.DeepEqual(actualHeader, wantHeader) {
   493  		t.Errorf("Header mismatch.\n got: %#v\n want: %#v", actualHeader, wantHeader)
   494  	}
   495  	if actualTrailer := rw.Result().Trailer; !reflect.DeepEqual(actualTrailer, wantTrailer) {
   496  		t.Errorf("Trailer mismatch.\n got: %#v\n want: %#v", actualTrailer, wantTrailer)
   497  	}
   498  }
   499  
   500  // cloneHeader performs a deep clone of an http.Header, since the (http.Header).Clone() method was only added in
   501  // Go 1.13.
   502  func cloneHeader(hdr http.Header) http.Header {
   503  	if hdr == nil {
   504  		return nil
   505  	}
   506  
   507  	hdrClone := make(http.Header, len(hdr))
   508  
   509  	for k, vv := range hdr {
   510  		vvClone := make([]string, len(vv))
   511  		copy(vvClone, vv)
   512  		hdrClone[k] = vvClone
   513  	}
   514  
   515  	return hdrClone
   516  }