go.uber.org/yarpc@v1.72.1/encoding/protobuf/v2/outbound_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package v2_test
    22  
    23  import (
    24  	"context"
    25  	"errors"
    26  	"io/ioutil"
    27  	"testing"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  	"go.uber.org/yarpc/api/transport"
    32  	"go.uber.org/yarpc/encoding/protobuf/internal/testpb/v2"
    33  	"go.uber.org/yarpc/encoding/protobuf/v2"
    34  	"go.uber.org/yarpc/yarpctest"
    35  	"google.golang.org/protobuf/proto"
    36  	"google.golang.org/protobuf/reflect/protoreflect"
    37  	"google.golang.org/protobuf/runtime/protoimpl"
    38  	"google.golang.org/protobuf/types/known/anypb"
    39  )
    40  
    41  func TestOutboundWithAnyResolver(t *testing.T) {
    42  	const testValue = "foo-bar-baz"
    43  	newReq := func() proto.Message { return &testpb.TestMessage{} }
    44  	customAnyResolver := &testAnyResolver{NewMessage: &testpb.TestMessage{}}
    45  	tests := []struct {
    46  		name     string
    47  		anyURL   string
    48  		resolver v2.AnyResolver
    49  		wantErr  bool
    50  	}{
    51  		{
    52  			name:   "nothing custom",
    53  			anyURL: "uber.yarpc.encoding.protobuf.TestMessage",
    54  		},
    55  		{
    56  			name:     "custom resolver",
    57  			anyURL:   "uber.yarpc.encoding.protobuf.TestMessage",
    58  			resolver: customAnyResolver,
    59  		},
    60  		{
    61  			name:     "custom resolver, custom URL",
    62  			anyURL:   "foo.bar.baz",
    63  			resolver: customAnyResolver,
    64  		},
    65  		{
    66  			name:    "custom URL, no resolver",
    67  			anyURL:  "foo.bar.baz",
    68  			wantErr: true,
    69  		},
    70  	}
    71  
    72  	for _, tt := range tests {
    73  		t.Run(tt.name, func(t *testing.T) {
    74  			trans := yarpctest.NewFakeTransport()
    75  			// outbound that echos the body back
    76  			out := trans.NewOutbound(nil, yarpctest.OutboundCallOverride(
    77  				yarpctest.OutboundCallable(func(ctx context.Context, req *transport.Request) (*transport.Response, error) {
    78  					return &transport.Response{Body: ioutil.NopCloser(req.Body)}, nil
    79  				}),
    80  			))
    81  
    82  			client := v2.NewClient(v2.ClientParams{
    83  				ClientConfig: &transport.OutboundConfig{
    84  					Outbounds: transport.Outbounds{
    85  						Unary: out,
    86  					},
    87  				},
    88  				AnyResolver: tt.resolver,
    89  				Options:     []v2.ClientOption{v2.UseJSON},
    90  			})
    91  
    92  			testMessage := &testpb.TestMessage{Value: testValue}
    93  
    94  			// convert to an Any so that the marshaller will use the custom resolver
    95  			anyMsg, err := anypb.New(testMessage)
    96  			require.NoError(t, err)
    97  			anyMsg.TypeUrl = tt.anyURL // update to custom URL
    98  
    99  			gotMessage, err := client.Call(context.Background(), "", anyMsg, newReq)
   100  			if tt.wantErr {
   101  				require.Error(t, err)
   102  			} else {
   103  				require.NoError(t, err)
   104  				assert.True(t, proto.Equal(testMessage, gotMessage)) // we expect the actual type behind the Any
   105  			}
   106  		})
   107  	}
   108  }
   109  
   110  func TestOutboundWithKnownProtoMsg(t *testing.T) {
   111  	t.Run("known proto message", func(t *testing.T) {
   112  		newReq := func() proto.Message { return &testpb.TestMessage{} }
   113  		trans := yarpctest.NewFakeTransport()
   114  		// outbound that echos the body back
   115  		out := trans.NewOutbound(nil, yarpctest.OutboundCallOverride(
   116  			yarpctest.OutboundCallable(func(ctx context.Context, req *transport.Request) (*transport.Response, error) {
   117  				return &transport.Response{Body: ioutil.NopCloser(req.Body)}, nil
   118  			}),
   119  		))
   120  
   121  		client := v2.NewClient(v2.ClientParams{
   122  			ClientConfig: &transport.OutboundConfig{
   123  				Outbounds: transport.Outbounds{
   124  					Unary: out,
   125  				},
   126  			},
   127  			Options: []v2.ClientOption{},
   128  		})
   129  
   130  		testMessage := &testpb.TestMessage{Value: "foo-bar-baz"}
   131  		gotMessage, err := client.Call(context.Background(), "", testMessage, newReq)
   132  		require.NoError(t, err)
   133  		assert.True(t, proto.Equal(testMessage, gotMessage))
   134  
   135  	})
   136  }
   137  
   138  func TestOutboundWithAnyProtobufMsg(t *testing.T) {
   139  	t.Run("any message without resolver", func(t *testing.T) {
   140  		newReq := func() proto.Message { return &anypb.Any{} }
   141  		trans := yarpctest.NewFakeTransport()
   142  		// outbound that echos the body back
   143  		out := trans.NewOutbound(nil, yarpctest.OutboundCallOverride(
   144  			yarpctest.OutboundCallable(func(ctx context.Context, req *transport.Request) (*transport.Response, error) {
   145  				return &transport.Response{Body: ioutil.NopCloser(req.Body)}, nil
   146  			}),
   147  		))
   148  
   149  		client := v2.NewClient(v2.ClientParams{
   150  			ClientConfig: &transport.OutboundConfig{
   151  				Outbounds: transport.Outbounds{
   152  					Unary: out,
   153  				},
   154  			},
   155  			Options: []v2.ClientOption{},
   156  		})
   157  
   158  		testMessage := &testpb.TestMessage{Value: "foo-bar-baz"}
   159  		anyMsg, err := anypb.New(testMessage)
   160  		require.NoError(t, err)
   161  
   162  		gotMessage, err := client.Call(context.Background(), "", anyMsg, newReq)
   163  		require.NoError(t, err)
   164  		returnMsg := &testpb.TestMessage{}
   165  		anypb.UnmarshalTo(gotMessage.(*anypb.Any), returnMsg, proto.UnmarshalOptions{})
   166  		assert.True(t, proto.Equal(testMessage, returnMsg))
   167  
   168  	})
   169  }
   170  
   171  type testAnyResolver struct {
   172  	NewMessage proto.Message
   173  }
   174  
   175  func (r testAnyResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) {
   176  	return r.FindMessageByURL(string(message))
   177  }
   178  
   179  func (r testAnyResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) {
   180  	// Custom resolver for TestMessage resolve with both global registered or custom URL
   181  	if r.NewMessage != nil {
   182  		if url == "uber.yarpc.encoding.protobuf.TestMessage" || url == "foo.bar.baz" {
   183  			return protoimpl.X.MessageTypeOf(r.NewMessage), nil
   184  		}
   185  	}
   186  	return nil, errors.New("test resolver is not initialized")
   187  }
   188  
   189  func (r testAnyResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) {
   190  	return nil, nil
   191  }
   192  
   193  func (r testAnyResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) {
   194  	return nil, nil
   195  }