github.com/emcfarlane/larking@v0.0.0-20220605172417-1704b45ee6c3/rules_test.go (about)

     1  // Copyright 2021 Edward McFarlane. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package larking
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"fmt"
    11  	"io/ioutil"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"strings"
    16  	"testing"
    17  
    18  	"github.com/google/go-cmp/cmp"
    19  	"golang.org/x/sync/errgroup"
    20  	"google.golang.org/genproto/googleapis/api/httpbody"
    21  	"google.golang.org/genproto/googleapis/rpc/status"
    22  	"google.golang.org/grpc"
    23  	"google.golang.org/grpc/credentials/insecure"
    24  	"google.golang.org/grpc/metadata"
    25  	"google.golang.org/grpc/reflection"
    26  	"google.golang.org/protobuf/encoding/protojson"
    27  	"google.golang.org/protobuf/proto"
    28  	"google.golang.org/protobuf/testing/protocmp"
    29  	"google.golang.org/protobuf/types/known/durationpb"
    30  	"google.golang.org/protobuf/types/known/emptypb"
    31  	"google.golang.org/protobuf/types/known/fieldmaskpb"
    32  	"google.golang.org/protobuf/types/known/timestamppb"
    33  	"google.golang.org/protobuf/types/known/wrapperspb"
    34  
    35  	"github.com/emcfarlane/larking/testpb"
    36  )
    37  
// in is the request half of an override: the message a test expects the
// gRPC server to receive.
//
// method is the full gRPC method name, for example
// "/larking.testpb.Messaging/GetMessageOne"; the unary interceptor skips the
// method check when it is empty. msg is the expected request and is compared
// against the received message with cmp.Diff and protocmp.Transform.
type in struct {
	method string
	msg    proto.Message
	// TODO: headers?
}
    43  
// out is the response half of an override: what an interceptor replies with
// instead of invoking the real handler.
//
// The unary interceptor returns msg and err together. The stream interceptor
// returns err when it is non-nil and otherwise sends msg on the stream.
type out struct {
	msg proto.Message
	err error
	// TODO: trailers?
}
    49  
// overrides holds the scripted expectations that the test interceptors replay
// in place of the real gRPC handlers.
//
// The embedded testing.TB is used to log mismatches. header names the
// incoming-metadata key whose first value must equal the full method name for
// a call to be intercepted; any other call goes to the real handler. inouts is
// the ordered script of in and out values: the unary interceptor reads
// inouts[0] as an in and inouts[1] as an out, while the stream interceptor
// walks the whole slice in order.
type overrides struct {
	testing.TB
	header string
	inouts []interface{}
}
    56  
    57  // unary context is used to check if this request should be overriden.
    58  func (o *overrides) unary() grpc.UnaryServerInterceptor {
    59  	return func(
    60  		ctx context.Context,
    61  		req interface{},
    62  		info *grpc.UnaryServerInfo,
    63  		handler grpc.UnaryHandler,
    64  	) (interface{}, error) {
    65  		md, _ := metadata.FromIncomingContext(ctx)
    66  		if hdr := md[o.header]; len(hdr) == 0 || info.FullMethod != hdr[0] {
    67  			return handler(ctx, req)
    68  		}
    69  		in, out := o.inouts[0].(in), o.inouts[1].(out)
    70  
    71  		msg := req.(proto.Message)
    72  		if in.method != "" && info.FullMethod != in.method {
    73  			err := fmt.Errorf("grpc expected %s, got %s", in.method, info.FullMethod)
    74  			o.Log(err)
    75  			return nil, err
    76  		}
    77  
    78  		diff := cmp.Diff(msg, in.msg, protocmp.Transform())
    79  		if diff != "" {
    80  			o.Log(diff)
    81  			return nil, fmt.Errorf("message didn't match")
    82  		}
    83  		return out.msg, out.err
    84  	}
    85  }
    86  
    87  func (o *overrides) unaryOption() grpc.ServerOption {
    88  	return grpc.UnaryInterceptor(o.unary())
    89  }
    90  
    91  // stream context is used to check if this request should be overriden.
    92  func (o *overrides) stream() grpc.StreamServerInterceptor {
    93  	return func(
    94  		srv interface{},
    95  		stream grpc.ServerStream,
    96  		info *grpc.StreamServerInfo,
    97  		handler grpc.StreamHandler,
    98  	) (err error) {
    99  		md, _ := metadata.FromIncomingContext(stream.Context())
   100  		if hdr := md[o.header]; len(hdr) == 0 || info.FullMethod != hdr[0] {
   101  			return handler(srv, stream)
   102  		}
   103  
   104  		for i, v := range o.inouts {
   105  			switch v := v.(type) {
   106  			case in:
   107  				//if v.method != "" && info.FullMethod != v.method {
   108  				//	return fmt.Errorf("grpc expected %s, got %s", v.method, info.FullMethod)
   109  				//}
   110  
   111  				msg := v.msg.ProtoReflect().New().Interface()
   112  				if err := stream.RecvMsg(msg); err != nil {
   113  					o.Log(err)
   114  					return err
   115  				}
   116  				diff := cmp.Diff(msg, v.msg, protocmp.Transform())
   117  				if diff != "" {
   118  					o.Log(diff)
   119  					return fmt.Errorf("message didn't match")
   120  				}
   121  
   122  			case out:
   123  				if i == 0 {
   124  					return fmt.Errorf("unexpected first message type: %T", v)
   125  				}
   126  
   127  				if err := v.err; err != nil {
   128  					o.Log(err)
   129  					return err // application
   130  				}
   131  				if err := stream.SendMsg(v.msg); err != nil {
   132  					o.Log(err)
   133  					return err
   134  				}
   135  			default:
   136  				return fmt.Errorf("unknown override type: %T", v)
   137  			}
   138  		}
   139  		return nil
   140  	}
   141  }
   142  
   143  func (o *overrides) streamOption() grpc.ServerOption {
   144  	return grpc.StreamInterceptor(o.stream())
   145  }
   146  
   147  func (o *overrides) reset(t testing.TB, header string, msgs []interface{}) {
   148  	o.TB = t
   149  	o.header = header
   150  	o.inouts = append(o.inouts[:0], msgs...)
   151  }
   152  
   153  func TestMessageServer(t *testing.T) {
   154  
   155  	// Create test server.
   156  	ms := &testpb.UnimplementedMessagingServer{}
   157  	fs := &testpb.UnimplementedFilesServer{}
   158  	js := &testpb.UnimplementedWellKnownServer{}
   159  
   160  	o := new(overrides)
   161  	gs := grpc.NewServer(o.unaryOption(), o.streamOption())
   162  
   163  	testpb.RegisterMessagingServer(gs, ms)
   164  	testpb.RegisterFilesServer(gs, fs)
   165  	testpb.RegisterWellKnownServer(gs, js)
   166  	reflection.Register(gs)
   167  
   168  	lis, err := net.Listen("tcp", "localhost:0")
   169  	if err != nil {
   170  		t.Fatalf("failed to listen: %v", err)
   171  	}
   172  	defer lis.Close()
   173  
   174  	var g errgroup.Group
   175  	defer func() {
   176  		if err := g.Wait(); err != nil {
   177  			t.Fatal(err)
   178  		}
   179  	}()
   180  
   181  	g.Go(func() error {
   182  		return gs.Serve(lis)
   183  	})
   184  	defer gs.Stop()
   185  
   186  	// Create client.
   187  	conn, err := grpc.Dial(
   188  		lis.Addr().String(),
   189  		grpc.WithTransportCredentials(insecure.NewCredentials()),
   190  	)
   191  	if err != nil {
   192  		t.Fatalf("cannot connect to server: %v", err)
   193  	}
   194  	defer conn.Close()
   195  
   196  	h, err := NewMux()
   197  	if err != nil {
   198  		t.Fatal(err)
   199  	}
   200  	if err := h.RegisterConn(context.Background(), conn); err != nil {
   201  		t.Fatal(err)
   202  	}
   203  
   204  	type want struct {
   205  		statusCode int
   206  		body       []byte        // either
   207  		msg        proto.Message // or
   208  		// TODO: headers
   209  	}
   210  
   211  	// TODO: compare http.Response output
   212  	tests := []struct {
   213  		name string
   214  		req  *http.Request
   215  		in   in
   216  		out  out
   217  		want want
   218  	}{{
   219  		name: "first",
   220  		req:  httptest.NewRequest(http.MethodGet, "/v1/messages/name/hello", nil),
   221  		in: in{
   222  			method: "/larking.testpb.Messaging/GetMessageOne",
   223  			msg:    &testpb.GetMessageRequestOne{Name: "name/hello"},
   224  		},
   225  		out: out{
   226  			msg: &testpb.Message{Text: "hello, world!"},
   227  		},
   228  		want: want{
   229  			statusCode: 200,
   230  			msg:        &testpb.Message{Text: "hello, world!"},
   231  		},
   232  	}, {
   233  		name: "sub.subfield",
   234  		req:  httptest.NewRequest(http.MethodGet, "/v1/messages/123456?revision=2&sub.subfield=foo", nil),
   235  		in: in{
   236  			method: "/larking.testpb.Messaging/GetMessageTwo",
   237  			msg: &testpb.GetMessageRequestTwo{
   238  				MessageId: "123456",
   239  				Revision:  2,
   240  				Sub: &testpb.GetMessageRequestTwo_SubMessage{
   241  					Subfield: "foo",
   242  				},
   243  			},
   244  		},
   245  		out: out{
   246  			msg: &testpb.Message{Text: "hello, query params!"},
   247  		},
   248  		want: want{
   249  			statusCode: 200,
   250  			msg:        &testpb.Message{Text: "hello, query params!"},
   251  		},
   252  	}, {
   253  		name: "additional_bindings1",
   254  		req:  httptest.NewRequest(http.MethodGet, "/v1/users/usr_123/messages?message_id=msg_123&revision=2", nil),
   255  		in: in{
   256  			method: "/larking.testpb.Messaging/GetMessageTwo",
   257  			msg: &testpb.GetMessageRequestTwo{
   258  				MessageId: "msg_123",
   259  				Revision:  2,
   260  				UserId:    "usr_123",
   261  			},
   262  		},
   263  		out: out{
   264  			msg: &testpb.Message{Text: "hello, additional bindings!"},
   265  		},
   266  		want: want{
   267  			statusCode: 200,
   268  			msg:        &testpb.Message{Text: "hello, additional bindings!"},
   269  		},
   270  	}, {
   271  		name: "additional_bindings2",
   272  		req:  httptest.NewRequest(http.MethodGet, "/v1/users/usr_123/messages/msg_123?revision=2", nil),
   273  		in: in{
   274  			method: "/larking.testpb.Messaging/GetMessageTwo",
   275  			msg: &testpb.GetMessageRequestTwo{
   276  				MessageId: "msg_123",
   277  				Revision:  2,
   278  				UserId:    "usr_123",
   279  			},
   280  		},
   281  		out: out{
   282  			msg: &testpb.Message{Text: "hello, additional bindings!"},
   283  		},
   284  		want: want{
   285  			statusCode: 200,
   286  			msg:        &testpb.Message{Text: "hello, additional bindings!"},
   287  		},
   288  	}, {
   289  		name: "patch",
   290  		req: httptest.NewRequest(http.MethodPatch, "/v1/messages/msg_123", strings.NewReader(
   291  			`{ "text": "Hi!" }`,
   292  		)),
   293  		in: in{
   294  			method: "/larking.testpb.Messaging/UpdateMessage",
   295  			msg: &testpb.UpdateMessageRequestOne{
   296  				MessageId: "msg_123",
   297  				Message: &testpb.Message{
   298  					Text: "Hi!",
   299  				},
   300  			},
   301  		},
   302  		out: out{
   303  			msg: &testpb.Message{Text: "hello, patch!"},
   304  		},
   305  		want: want{
   306  			statusCode: 200,
   307  			msg:        &testpb.Message{Text: "hello, patch!"},
   308  		},
   309  	}, {
   310  		name: "action",
   311  		req: httptest.NewRequest(http.MethodPost, "/v1/action:cancel", strings.NewReader(
   312  			`{ "message_id": "123" }`,
   313  		)),
   314  		in: in{
   315  			method: "/larking.testpb.Messaging/Action",
   316  			msg:    &testpb.Message{MessageId: "123", Text: "action"},
   317  		},
   318  		out: out{
   319  			msg: &emptypb.Empty{},
   320  		},
   321  		want: want{
   322  			statusCode: 200,
   323  			msg:        &emptypb.Empty{},
   324  		},
   325  	}, {
   326  		name: "actionSegment",
   327  		req: httptest.NewRequest(http.MethodPost, "/v1/name:clear", strings.NewReader(
   328  			`{ "message_id": "123" }`,
   329  		)),
   330  		in: in{
   331  			method: "/larking.testpb.Messaging/ActionSegment",
   332  			msg:    &testpb.Message{MessageId: "123", Text: "name"},
   333  		},
   334  		out: out{
   335  			msg: &emptypb.Empty{},
   336  		},
   337  		want: want{
   338  			statusCode: 200,
   339  			msg:        &emptypb.Empty{},
   340  		},
   341  	}, {
   342  		name: "actionResource",
   343  		req:  httptest.NewRequest(http.MethodGet, "/v1/actions/123:fetch", nil),
   344  		in: in{
   345  			method: "/larking.testpb.Messaging/ActionResource",
   346  			msg:    &testpb.Message{Text: "actions/123"},
   347  		},
   348  		out: out{
   349  			msg: &emptypb.Empty{},
   350  		},
   351  		want: want{
   352  			statusCode: 200,
   353  			msg:        &emptypb.Empty{},
   354  		},
   355  	}, {
   356  		name: "actionSegments",
   357  		req: httptest.NewRequest(http.MethodPost, "/v1/name/id:watch", strings.NewReader(
   358  			`{ "message_id": "123" }`,
   359  		)),
   360  		in: in{
   361  			method: "/larking.testpb.Messaging/ActionSegments",
   362  			msg:    &testpb.Message{MessageId: "123", Text: "name/id"},
   363  		},
   364  		out: out{
   365  			msg: &emptypb.Empty{},
   366  		},
   367  		want: want{
   368  			statusCode: 200,
   369  			msg:        &emptypb.Empty{},
   370  		},
   371  	}, {
   372  		name: "batchGet",
   373  		req: httptest.NewRequest(http.MethodGet, "/v3/events:batchGet", strings.NewReader(
   374  			`{}`,
   375  		)),
   376  		in: in{
   377  			method: "/larking.testpb.Messaging/BatchGet",
   378  			msg:    &emptypb.Empty{},
   379  		},
   380  		out: out{
   381  			msg: &emptypb.Empty{},
   382  		},
   383  		want: want{
   384  			statusCode: 200,
   385  			msg:        &emptypb.Empty{},
   386  		},
   387  	}, {
   388  		name: "404",
   389  		req:  httptest.NewRequest(http.MethodGet, "/error404", nil),
   390  		want: want{
   391  			statusCode: 404,
   392  			msg:        &status.Status{Code: 5, Message: "not found"},
   393  		},
   394  	}, {
   395  		name: "cat.jpg",
   396  		req: func() *http.Request {
   397  			r := httptest.NewRequest(
   398  				http.MethodPost, "/files/cat.jpg",
   399  				strings.NewReader("cat"),
   400  			)
   401  			r.Header.Set("Content-Type", "image/jpeg")
   402  			return r
   403  		}(),
   404  		in: in{
   405  			method: "/larking.testpb.Files/UploadDownload",
   406  			msg: &testpb.UploadFileRequest{
   407  				Filename: "cat.jpg",
   408  				File: &httpbody.HttpBody{
   409  					ContentType: "image/jpeg",
   410  					Data:        []byte("cat"),
   411  				},
   412  			},
   413  		},
   414  		out: out{
   415  			msg: &httpbody.HttpBody{
   416  				ContentType: "image/jpeg",
   417  				Data:        []byte("cat"),
   418  			},
   419  		},
   420  		want: want{
   421  			statusCode: 200,
   422  			body:       []byte("cat"),
   423  		},
   424  
   425  		/*}, {
   426  		name: "large_cat.jpg",
   427  		req: func() *http.Request {
   428  			r := httptest.NewRequest(
   429  				http.MethodPost, "/files/large/cat.jpg",
   430  				strings.NewReader("cat"),
   431  			)
   432  			r.Header.Set("Content-Type", "image/jpeg")
   433  			return r
   434  		}(),
   435  		in: in{
   436  			method: "/larking.testpb.Files/UploadDownload",
   437  			msg: &testpb.UploadFileRequest{
   438  				Filename: "cat.jpg",
   439  				File: &httpbody.HttpBody{
   440  					ContentType: "image/jpeg",
   441  					Data:        []byte("cat"),
   442  				},
   443  			},
   444  		},
   445  		out: out{
   446  			msg: &httpbody.HttpBody{
   447  				ContentType: "image/jpeg",
   448  				Data:        []byte("cat"),
   449  			},
   450  		},
   451  		want: want{
   452  			statusCode: 200,
   453  			body:       []byte("cat"),
   454  		},*/
   455  	}, {
   456  		name: "wellknown_scalars",
   457  		req: httptest.NewRequest(
   458  			http.MethodGet,
   459  			"/v1/wellknown?"+
   460  				"timestamp=\"2017-01-15T01:30:15.01Z\"&"+
   461  				"duration=\"3.000001s\"&"+
   462  				"bool_value=true&"+
   463  				"int32_value=1&"+
   464  				"int64_value=2&"+
   465  				"uint32_value=3&"+
   466  				"uint64_value=4&"+
   467  				"float_value=5.5&"+
   468  				"double_value=6.6&"+
   469  				"bytes_value=aGVsbG8&"+ // base64URL
   470  				"string_value=hello&"+
   471  				"field_mask=\"user.displayName,photo\"",
   472  			nil,
   473  		),
   474  		in: in{
   475  			method: "/larking.testpb.WellKnown/Check",
   476  			msg: &testpb.Scalars{
   477  				Timestamp: &timestamppb.Timestamp{
   478  					Seconds: 1484443815,
   479  					Nanos:   10000000,
   480  				},
   481  				Duration: &durationpb.Duration{
   482  					Seconds: 3,
   483  					Nanos:   1000,
   484  				},
   485  				BoolValue:   &wrapperspb.BoolValue{Value: true},
   486  				Int32Value:  &wrapperspb.Int32Value{Value: 1},
   487  				Int64Value:  &wrapperspb.Int64Value{Value: 2},
   488  				Uint32Value: &wrapperspb.UInt32Value{Value: 3},
   489  				Uint64Value: &wrapperspb.UInt64Value{Value: 4},
   490  				FloatValue:  &wrapperspb.FloatValue{Value: 5.5},
   491  				DoubleValue: &wrapperspb.DoubleValue{Value: 6.6},
   492  				BytesValue:  &wrapperspb.BytesValue{Value: []byte("hello")},
   493  				StringValue: &wrapperspb.StringValue{Value: "hello"},
   494  				FieldMask: &fieldmaskpb.FieldMask{
   495  					Paths: []string{
   496  						"user.display_name", // JSON name converted to field name
   497  						"photo",
   498  					},
   499  				},
   500  			},
   501  		},
   502  		out: out{
   503  			msg: &emptypb.Empty{},
   504  		},
   505  		want: want{
   506  			statusCode: 200,
   507  			msg:        &emptypb.Empty{},
   508  		},
   509  	}, {
   510  		name: "variable_one",
   511  		req:  httptest.NewRequest(http.MethodGet, "/version/one", nil),
   512  		in: in{
   513  			method: "/larking.testpb.Messaging/VariableOne",
   514  			msg:    &testpb.Message{Text: "version"},
   515  		},
   516  		out: out{
   517  			msg: &emptypb.Empty{},
   518  		},
   519  		want: want{
   520  			statusCode: 200,
   521  			msg:        &emptypb.Empty{},
   522  		},
   523  	}, {
   524  		name: "variable_two",
   525  		req:  httptest.NewRequest(http.MethodGet, "/version/two", nil),
   526  		in: in{
   527  			method: "/larking.testpb.Messaging/VariableTwo",
   528  			msg:    &testpb.Message{Text: "version"},
   529  		},
   530  		out: out{
   531  			msg: &emptypb.Empty{},
   532  		},
   533  		want: want{
   534  			statusCode: 200,
   535  			msg:        &emptypb.Empty{},
   536  		},
   537  	}, {
   538  		name: "shelf_name_get",
   539  		req:  httptest.NewRequest(http.MethodGet, "/v1/shelves/shelf1", nil),
   540  		in: in{
   541  			method: "/larking.testpb.Messaging/GetShelf",
   542  			msg:    &testpb.GetShelfRequest{Name: "shelves/shelf1"},
   543  		},
   544  		out: out{
   545  			msg: &testpb.Shelf{Name: "shelves/shelf1"},
   546  		},
   547  		want: want{
   548  			statusCode: 200,
   549  			msg:        &testpb.Shelf{Name: "shelves/shelf1"},
   550  		},
   551  	}, {
   552  		name: "book_name_get",
   553  		req:  httptest.NewRequest(http.MethodGet, "/v1/shelves/shelf1/books/book2", nil),
   554  		in: in{
   555  			method: "/larking.testpb.Messaging/GetBook",
   556  			msg:    &testpb.GetBookRequest{Name: "shelves/shelf1/books/book2"},
   557  		},
   558  		out: out{
   559  			msg: &testpb.Book{Name: "shelves/shelf1/books/book2"},
   560  		},
   561  		want: want{
   562  			statusCode: 200,
   563  			msg:        &testpb.Book{Name: "shelves/shelf1/books/book2"},
   564  		},
   565  	}, {
   566  		name: "book_name_create",
   567  		req: httptest.NewRequest(http.MethodPost, "/v1/shelves/shelf1/books", strings.NewReader(
   568  			`{ "name": "book3" }`,
   569  		)),
   570  		in: in{
   571  			method: "/larking.testpb.Messaging/CreateBook",
   572  			msg: &testpb.CreateBookRequest{
   573  				Parent: "shelves/shelf1",
   574  				Book: &testpb.Book{
   575  					Name: "book3",
   576  				},
   577  			},
   578  		},
   579  		out: out{
   580  			msg: &testpb.Book{Name: "book3"},
   581  		},
   582  		want: want{
   583  			statusCode: 200,
   584  			msg:        &testpb.Book{Name: "book3"},
   585  		},
   586  	}, {
   587  		name: "book_name_update",
   588  		req: httptest.NewRequest(http.MethodPatch, `/v1/shelves/shelf1/books/book2?update_mask="name,title"`, strings.NewReader(
   589  			`{ "title": "Lord of the Rings" }`,
   590  		)),
   591  		in: in{
   592  			method: "/larking.testpb.Messaging/UpdateBook",
   593  			msg: &testpb.UpdateBookRequest{
   594  				Book: &testpb.Book{
   595  					Name:  "shelves/shelf1/books/book2",
   596  					Title: "Lord of the Rings",
   597  				},
   598  				UpdateMask: &fieldmaskpb.FieldMask{
   599  					Paths: []string{
   600  						"name",
   601  						"title",
   602  					},
   603  				},
   604  			},
   605  		},
   606  		out: out{
   607  			msg: &testpb.Book{
   608  				Name:  "shelves/shelf1/books/book2",
   609  				Title: "Lord of the Rings",
   610  			},
   611  		},
   612  		want: want{
   613  			statusCode: 200,
   614  			msg: &testpb.Book{
   615  				Name:  "shelves/shelf1/books/book2",
   616  				Title: "Lord of the Rings",
   617  			},
   618  		},
   619  	}}
   620  
   621  	opts := cmp.Options{protocmp.Transform()}
   622  
   623  	for _, tt := range tests {
   624  		t.Run(tt.name, func(t *testing.T) {
   625  			o.reset(t, "http-test", []interface{}{tt.in, tt.out})
   626  
   627  			req := tt.req
   628  			req.Header["test"] = []string{tt.in.method}
   629  
   630  			w := httptest.NewRecorder()
   631  			h.ServeHTTP(w, req)
   632  			resp := w.Result()
   633  
   634  			b, err := ioutil.ReadAll(resp.Body)
   635  			if err != nil {
   636  				t.Fatal(err)
   637  			}
   638  
   639  			if sc := tt.want.statusCode; sc != resp.StatusCode {
   640  				t.Errorf("expected %d got %d", tt.want.statusCode, resp.StatusCode)
   641  				var msg status.Status
   642  				if err := protojson.Unmarshal(b, &msg); err != nil {
   643  					t.Error(err, string(b))
   644  					return
   645  				}
   646  				t.Error("status.code", msg.Code)
   647  				t.Error("status.message", msg.Message)
   648  				return
   649  			}
   650  
   651  			if tt.want.body != nil {
   652  				if !bytes.Equal(b, tt.want.body) {
   653  					t.Errorf("body %s != %s", tt.want.body, b)
   654  				}
   655  			}
   656  
   657  			if tt.want.msg != nil {
   658  				msg := proto.Clone(tt.want.msg)
   659  				if err := protojson.Unmarshal(b, msg); err != nil {
   660  					t.Error(err, string(b))
   661  					return
   662  				}
   663  
   664  				diff := cmp.Diff(msg, tt.want.msg, opts...)
   665  				if diff != "" {
   666  					t.Error(diff)
   667  				}
   668  			}
   669  		})
   670  	}
   671  }