github.com/grpc-ecosystem/grpc-gateway/v2@v2.19.1/runtime/handler_test.go (about)

     1  package runtime_test
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"reflect"
     9  	"testing"
    10  
    11  	"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
    12  	pb "github.com/grpc-ecosystem/grpc-gateway/v2/runtime/internal/examplepb"
    13  	"google.golang.org/grpc/codes"
    14  	"google.golang.org/grpc/metadata"
    15  	"google.golang.org/grpc/status"
    16  	"google.golang.org/protobuf/proto"
    17  )
    18  
    19  type fakeReponseBodyWrapper struct {
    20  	proto.Message
    21  }
    22  
    23  // XXX_ResponseBody returns id of SimpleMessage
    24  func (r fakeReponseBodyWrapper) XXX_ResponseBody() interface{} {
    25  	resp := r.Message.(*pb.SimpleMessage)
    26  	return resp.Id
    27  }
    28  
    29  func TestForwardResponseStream(t *testing.T) {
    30  	type msg struct {
    31  		pb  proto.Message
    32  		err error
    33  	}
    34  	tests := []struct {
    35  		name         string
    36  		msgs         []msg
    37  		statusCode   int
    38  		responseBody bool
    39  	}{{
    40  		name: "encoding",
    41  		msgs: []msg{
    42  			{&pb.SimpleMessage{Id: "One"}, nil},
    43  			{&pb.SimpleMessage{Id: "Two"}, nil},
    44  		},
    45  		statusCode: http.StatusOK,
    46  	}, {
    47  		name:       "empty",
    48  		statusCode: http.StatusOK,
    49  	}, {
    50  		name:       "error",
    51  		msgs:       []msg{{nil, status.Errorf(codes.OutOfRange, "400")}},
    52  		statusCode: http.StatusBadRequest,
    53  	}, {
    54  		name: "stream_error",
    55  		msgs: []msg{
    56  			{&pb.SimpleMessage{Id: "One"}, nil},
    57  			{nil, status.Errorf(codes.OutOfRange, "400")},
    58  		},
    59  		statusCode: http.StatusOK,
    60  	}, {
    61  		name: "response body stream case",
    62  		msgs: []msg{
    63  			{fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "One"}}, nil},
    64  			{fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "Two"}}, nil},
    65  		},
    66  		responseBody: true,
    67  		statusCode:   http.StatusOK,
    68  	}, {
    69  		name: "response body stream error case",
    70  		msgs: []msg{
    71  			{fakeReponseBodyWrapper{&pb.SimpleMessage{Id: "One"}}, nil},
    72  			{nil, status.Errorf(codes.OutOfRange, "400")},
    73  		},
    74  		responseBody: true,
    75  		statusCode:   http.StatusOK,
    76  	}}
    77  
    78  	newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
    79  		var count int
    80  		return func() (proto.Message, error) {
    81  			if count == len(msgs) {
    82  				return nil, io.EOF
    83  			} else if count > len(msgs) {
    84  				t.Errorf("recv() called %d times for %d messages", count, len(msgs))
    85  			}
    86  			count++
    87  			msg := msgs[count-1]
    88  			return msg.pb, msg.err
    89  		}
    90  	}
    91  	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
    92  	marshaler := &runtime.JSONPb{}
    93  	for _, tt := range tests {
    94  		t.Run(tt.name, func(t *testing.T) {
    95  			recv := newTestRecv(t, tt.msgs)
    96  			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
    97  			resp := httptest.NewRecorder()
    98  
    99  			runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)
   100  
   101  			w := resp.Result()
   102  			if w.StatusCode != tt.statusCode {
   103  				t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
   104  			}
   105  			if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
   106  				t.Errorf("ForwardResponseStream missing header chunked")
   107  			}
   108  			body, err := io.ReadAll(w.Body)
   109  			if err != nil {
   110  				t.Errorf("Failed to read response body with %v", err)
   111  			}
   112  			w.Body.Close()
   113  			if len(body) > 0 && w.Header.Get("Content-Type") != "application/json" {
   114  				t.Errorf("Content-Type %s want application/json", w.Header.Get("Content-Type"))
   115  			}
   116  
   117  			var want []byte
   118  			for i, msg := range tt.msgs {
   119  				if msg.err != nil {
   120  					if i == 0 {
   121  						// Skip non-stream errors
   122  						t.Skip("checking error encodings")
   123  					}
   124  					delimiter := marshaler.Delimiter()
   125  					st := status.Convert(msg.err)
   126  					b, err := marshaler.Marshal(map[string]proto.Message{
   127  						"error": st.Proto(),
   128  					})
   129  					if err != nil {
   130  						t.Errorf("marshaler.Marshal() failed %v", err)
   131  					}
   132  					errBytes := body[len(want):]
   133  					if string(errBytes) != string(b)+string(delimiter) {
   134  						t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", errBytes, b)
   135  					}
   136  
   137  					return
   138  				}
   139  
   140  				var b []byte
   141  
   142  				if tt.responseBody {
   143  					// responseBody interface is in runtime package and test is in runtime_test package. hence can't use responseBody directly
   144  					// So type casting to fakeReponseBodyWrapper struct to verify the data.
   145  					rb, ok := msg.pb.(fakeReponseBodyWrapper)
   146  					if !ok {
   147  						t.Errorf("stream responseBody failed %v", err)
   148  					}
   149  
   150  					b, err = marshaler.Marshal(map[string]interface{}{"result": rb.XXX_ResponseBody()})
   151  				} else {
   152  					b, err = marshaler.Marshal(map[string]interface{}{"result": msg.pb})
   153  				}
   154  
   155  				if err != nil {
   156  					t.Errorf("marshaler.Marshal() failed %v", err)
   157  				}
   158  				want = append(want, b...)
   159  				want = append(want, marshaler.Delimiter()...)
   160  			}
   161  
   162  			if string(body) != string(want) {
   163  				t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
   164  			}
   165  		})
   166  	}
   167  }
   168  
   169  // A custom marshaler implementation, that doesn't implement the delimited interface
   170  type CustomMarshaler struct {
   171  	m *runtime.JSONPb
   172  }
   173  
   174  func (c *CustomMarshaler) Marshal(v interface{}) ([]byte, error)      { return c.m.Marshal(v) }
   175  func (c *CustomMarshaler) Unmarshal(data []byte, v interface{}) error { return c.m.Unmarshal(data, v) }
   176  func (c *CustomMarshaler) NewDecoder(r io.Reader) runtime.Decoder     { return c.m.NewDecoder(r) }
   177  func (c *CustomMarshaler) NewEncoder(w io.Writer) runtime.Encoder     { return c.m.NewEncoder(w) }
   178  func (c *CustomMarshaler) ContentType(v interface{}) string           { return "Custom-Content-Type" }
   179  
   180  func TestForwardResponseStreamCustomMarshaler(t *testing.T) {
   181  	type msg struct {
   182  		pb  proto.Message
   183  		err error
   184  	}
   185  	tests := []struct {
   186  		name       string
   187  		msgs       []msg
   188  		statusCode int
   189  	}{{
   190  		name: "encoding",
   191  		msgs: []msg{
   192  			{&pb.SimpleMessage{Id: "One"}, nil},
   193  			{&pb.SimpleMessage{Id: "Two"}, nil},
   194  		},
   195  		statusCode: http.StatusOK,
   196  	}, {
   197  		name:       "empty",
   198  		statusCode: http.StatusOK,
   199  	}, {
   200  		name:       "error",
   201  		msgs:       []msg{{nil, status.Errorf(codes.OutOfRange, "400")}},
   202  		statusCode: http.StatusBadRequest,
   203  	}, {
   204  		name: "stream_error",
   205  		msgs: []msg{
   206  			{&pb.SimpleMessage{Id: "One"}, nil},
   207  			{nil, status.Errorf(codes.OutOfRange, "400")},
   208  		},
   209  		statusCode: http.StatusOK,
   210  	}}
   211  
   212  	newTestRecv := func(t *testing.T, msgs []msg) func() (proto.Message, error) {
   213  		var count int
   214  		return func() (proto.Message, error) {
   215  			if count == len(msgs) {
   216  				return nil, io.EOF
   217  			} else if count > len(msgs) {
   218  				t.Errorf("recv() called %d times for %d messages", count, len(msgs))
   219  			}
   220  			count++
   221  			msg := msgs[count-1]
   222  			return msg.pb, msg.err
   223  		}
   224  	}
   225  	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
   226  	marshaler := &CustomMarshaler{&runtime.JSONPb{}}
   227  	for _, tt := range tests {
   228  		t.Run(tt.name, func(t *testing.T) {
   229  			recv := newTestRecv(t, tt.msgs)
   230  			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
   231  			resp := httptest.NewRecorder()
   232  
   233  			runtime.ForwardResponseStream(ctx, runtime.NewServeMux(), marshaler, resp, req, recv)
   234  
   235  			w := resp.Result()
   236  			if w.StatusCode != tt.statusCode {
   237  				t.Errorf("StatusCode %d want %d", w.StatusCode, tt.statusCode)
   238  			}
   239  			if h := w.Header.Get("Transfer-Encoding"); h != "chunked" {
   240  				t.Errorf("ForwardResponseStream missing header chunked")
   241  			}
   242  			body, err := io.ReadAll(w.Body)
   243  			if err != nil {
   244  				t.Errorf("Failed to read response body with %v", err)
   245  			}
   246  			w.Body.Close()
   247  			if len(body) > 0 && w.Header.Get("Content-Type") != "Custom-Content-Type" {
   248  				t.Errorf("Content-Type %s want Custom-Content-Type", w.Header.Get("Content-Type"))
   249  			}
   250  
   251  			var want []byte
   252  			for _, msg := range tt.msgs {
   253  				if msg.err != nil {
   254  					t.Skip("checking erorr encodings")
   255  				}
   256  				b, err := marshaler.Marshal(map[string]proto.Message{"result": msg.pb})
   257  				if err != nil {
   258  					t.Errorf("marshaler.Marshal() failed %v", err)
   259  				}
   260  				want = append(want, b...)
   261  				want = append(want, "\n"...)
   262  			}
   263  
   264  			if string(body) != string(want) {
   265  				t.Errorf("ForwardResponseStream() = \"%s\" want \"%s\"", body, want)
   266  			}
   267  		})
   268  	}
   269  }
   270  
   271  func TestForwardResponseMessage(t *testing.T) {
   272  	msg := &pb.SimpleMessage{Id: "One"}
   273  	tests := []struct {
   274  		name        string
   275  		marshaler   runtime.Marshaler
   276  		contentType string
   277  	}{{
   278  		name:        "standard marshaler",
   279  		marshaler:   &runtime.JSONPb{},
   280  		contentType: "application/json",
   281  	}, {
   282  		name:        "httpbody marshaler",
   283  		marshaler:   &runtime.HTTPBodyMarshaler{&runtime.JSONPb{}},
   284  		contentType: "application/json",
   285  	}, {
   286  		name:        "custom marshaler",
   287  		marshaler:   &CustomMarshaler{&runtime.JSONPb{}},
   288  		contentType: "Custom-Content-Type",
   289  	}}
   290  
   291  	ctx := runtime.NewServerMetadataContext(context.Background(), runtime.ServerMetadata{})
   292  	for _, tt := range tests {
   293  		t.Run(tt.name, func(t *testing.T) {
   294  			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
   295  			resp := httptest.NewRecorder()
   296  
   297  			runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(), tt.marshaler, resp, req, msg)
   298  
   299  			w := resp.Result()
   300  			if w.StatusCode != http.StatusOK {
   301  				t.Errorf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
   302  			}
   303  			if h := w.Header.Get("Content-Type"); h != tt.contentType {
   304  				t.Errorf("Content-Type %v want %v", h, tt.contentType)
   305  			}
   306  			body, err := io.ReadAll(w.Body)
   307  			if err != nil {
   308  				t.Errorf("Failed to read response body with %v", err)
   309  			}
   310  			w.Body.Close()
   311  
   312  			want, err := tt.marshaler.Marshal(msg)
   313  			if err != nil {
   314  				t.Errorf("marshaler.Marshal() failed %v", err)
   315  			}
   316  
   317  			if string(body) != string(want) {
   318  				t.Errorf("ForwardResponseMessage() = \"%s\" want \"%s\"", body, want)
   319  			}
   320  		})
   321  	}
   322  }
   323  
   324  func TestOutgoingHeaderMatcher(t *testing.T) {
   325  	t.Parallel()
   326  	msg := &pb.SimpleMessage{Id: "foo"}
   327  	for _, tc := range []struct {
   328  		name    string
   329  		md      runtime.ServerMetadata
   330  		headers http.Header
   331  		matcher runtime.HeaderMatcherFunc
   332  	}{
   333  		{
   334  			name: "default matcher",
   335  			md: runtime.ServerMetadata{
   336  				HeaderMD: metadata.Pairs(
   337  					"foo", "bar",
   338  					"baz", "qux",
   339  				),
   340  			},
   341  			headers: http.Header{
   342  				"Content-Type":      []string{"application/json"},
   343  				"Grpc-Metadata-Foo": []string{"bar"},
   344  				"Grpc-Metadata-Baz": []string{"qux"},
   345  			},
   346  		},
   347  		{
   348  			name: "custom matcher",
   349  			md: runtime.ServerMetadata{
   350  				HeaderMD: metadata.Pairs(
   351  					"foo", "bar",
   352  					"baz", "qux",
   353  				),
   354  			},
   355  			headers: http.Header{
   356  				"Content-Type": []string{"application/json"},
   357  				"Custom-Foo":   []string{"bar"},
   358  			},
   359  			matcher: func(key string) (string, bool) {
   360  				switch key {
   361  				case "foo":
   362  					return "custom-foo", true
   363  				default:
   364  					return "", false
   365  				}
   366  			},
   367  		},
   368  	} {
   369  		tc := tc
   370  		t.Run(tc.name, func(t *testing.T) {
   371  			t.Parallel()
   372  			ctx := runtime.NewServerMetadataContext(context.Background(), tc.md)
   373  
   374  			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
   375  			resp := httptest.NewRecorder()
   376  
   377  			runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingHeaderMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)
   378  
   379  			w := resp.Result()
   380  			defer w.Body.Close()
   381  			if w.StatusCode != http.StatusOK {
   382  				t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
   383  			}
   384  
   385  			if !reflect.DeepEqual(w.Header, tc.headers) {
   386  				t.Fatalf("Header %v want %v", w.Header, tc.headers)
   387  			}
   388  		})
   389  	}
   390  }
   391  
   392  func TestOutgoingTrailerMatcher(t *testing.T) {
   393  	t.Parallel()
   394  	msg := &pb.SimpleMessage{Id: "foo"}
   395  	for _, tc := range []struct {
   396  		name    string
   397  		md      runtime.ServerMetadata
   398  		caller  http.Header
   399  		headers http.Header
   400  		trailer http.Header
   401  		matcher runtime.HeaderMatcherFunc
   402  	}{
   403  		{
   404  			name: "default matcher, caller accepts",
   405  			md: runtime.ServerMetadata{
   406  				TrailerMD: metadata.Pairs(
   407  					"foo", "bar",
   408  					"baz", "qux",
   409  				),
   410  			},
   411  			caller: http.Header{
   412  				"Te": []string{"trailers"},
   413  			},
   414  			headers: http.Header{
   415  				"Content-Type": []string{"application/json"},
   416  				"Trailer":      []string{"Grpc-Trailer-Foo,Grpc-Trailer-Baz"},
   417  			},
   418  			trailer: http.Header{
   419  				"Grpc-Trailer-Foo": []string{"bar"},
   420  				"Grpc-Trailer-Baz": []string{"qux"},
   421  			},
   422  		},
   423  		{
   424  			name: "default matcher, caller rejects",
   425  			md: runtime.ServerMetadata{
   426  				TrailerMD: metadata.Pairs(
   427  					"foo", "bar",
   428  					"baz", "qux",
   429  				),
   430  			},
   431  			headers: http.Header{
   432  				"Content-Type": []string{"application/json"},
   433  			},
   434  		},
   435  		{
   436  			name: "custom matcher",
   437  			md: runtime.ServerMetadata{
   438  				TrailerMD: metadata.Pairs(
   439  					"foo", "bar",
   440  					"baz", "qux",
   441  				),
   442  			},
   443  			caller: http.Header{
   444  				"Te": []string{"trailers"},
   445  			},
   446  			headers: http.Header{
   447  				"Content-Type": []string{"application/json"},
   448  				"Trailer":      []string{"Custom-Trailer-Foo"},
   449  			},
   450  			trailer: http.Header{
   451  				"Custom-Trailer-Foo": []string{"bar"},
   452  			},
   453  			matcher: func(key string) (string, bool) {
   454  				switch key {
   455  				case "foo":
   456  					return "custom-trailer-foo", true
   457  				default:
   458  					return "", false
   459  				}
   460  			},
   461  		},
   462  	} {
   463  		tc := tc
   464  		t.Run(tc.name, func(t *testing.T) {
   465  			t.Parallel()
   466  			ctx := runtime.NewServerMetadataContext(context.Background(), tc.md)
   467  
   468  			req := httptest.NewRequest("GET", "http://example.com/foo", nil)
   469  			req.Header = tc.caller
   470  			resp := httptest.NewRecorder()
   471  
   472  			runtime.ForwardResponseMessage(ctx, runtime.NewServeMux(runtime.WithOutgoingTrailerMatcher(tc.matcher)), &runtime.JSONPb{}, resp, req, msg)
   473  
   474  			w := resp.Result()
   475  			_, _ = io.Copy(io.Discard, w.Body)
   476  			defer w.Body.Close()
   477  			if w.StatusCode != http.StatusOK {
   478  				t.Fatalf("StatusCode %d want %d", w.StatusCode, http.StatusOK)
   479  			}
   480  
   481  			if !reflect.DeepEqual(w.Trailer, tc.trailer) {
   482  				t.Fatalf("Trailer %v want %v", w.Trailer, tc.trailer)
   483  			}
   484  		})
   485  	}
   486  }