// go.uber.org/yarpc@v1.72.1/encoding/protobuf/v2/outbound_test.go
//
// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package v2_test 22 23 import ( 24 "context" 25 "errors" 26 "io/ioutil" 27 "testing" 28 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 "go.uber.org/yarpc/api/transport" 32 "go.uber.org/yarpc/encoding/protobuf/internal/testpb/v2" 33 "go.uber.org/yarpc/encoding/protobuf/v2" 34 "go.uber.org/yarpc/yarpctest" 35 "google.golang.org/protobuf/proto" 36 "google.golang.org/protobuf/reflect/protoreflect" 37 "google.golang.org/protobuf/runtime/protoimpl" 38 "google.golang.org/protobuf/types/known/anypb" 39 ) 40 41 func TestOutboundWithAnyResolver(t *testing.T) { 42 const testValue = "foo-bar-baz" 43 newReq := func() proto.Message { return &testpb.TestMessage{} } 44 customAnyResolver := &testAnyResolver{NewMessage: &testpb.TestMessage{}} 45 tests := []struct { 46 name string 47 anyURL string 48 resolver v2.AnyResolver 49 wantErr bool 50 }{ 51 { 52 name: "nothing custom", 53 anyURL: "uber.yarpc.encoding.protobuf.TestMessage", 54 }, 55 { 56 name: "custom resolver", 57 anyURL: "uber.yarpc.encoding.protobuf.TestMessage", 58 resolver: customAnyResolver, 59 }, 60 { 61 name: "custom resolver, custom URL", 62 anyURL: "foo.bar.baz", 63 resolver: customAnyResolver, 64 }, 65 { 66 name: "custom URL, no resolver", 67 anyURL: "foo.bar.baz", 68 wantErr: true, 69 }, 70 } 71 72 for _, tt := range tests { 73 t.Run(tt.name, func(t *testing.T) { 74 trans := yarpctest.NewFakeTransport() 75 // outbound that echos the body back 76 out := trans.NewOutbound(nil, yarpctest.OutboundCallOverride( 77 yarpctest.OutboundCallable(func(ctx context.Context, req *transport.Request) (*transport.Response, error) { 78 return &transport.Response{Body: ioutil.NopCloser(req.Body)}, nil 79 }), 80 )) 81 82 client := v2.NewClient(v2.ClientParams{ 83 ClientConfig: &transport.OutboundConfig{ 84 Outbounds: transport.Outbounds{ 85 Unary: out, 86 }, 87 }, 88 AnyResolver: tt.resolver, 89 Options: []v2.ClientOption{v2.UseJSON}, 90 }) 91 92 testMessage := &testpb.TestMessage{Value: 
testValue} 93 94 // convert to an Any so that the marshaller will use the custom resolver 95 anyMsg, err := anypb.New(testMessage) 96 require.NoError(t, err) 97 anyMsg.TypeUrl = tt.anyURL // update to custom URL 98 99 gotMessage, err := client.Call(context.Background(), "", anyMsg, newReq) 100 if tt.wantErr { 101 require.Error(t, err) 102 } else { 103 require.NoError(t, err) 104 assert.True(t, proto.Equal(testMessage, gotMessage)) // we expect the actual type behind the Any 105 } 106 }) 107 } 108 } 109 110 func TestOutboundWithKnownProtoMsg(t *testing.T) { 111 t.Run("known proto message", func(t *testing.T) { 112 newReq := func() proto.Message { return &testpb.TestMessage{} } 113 trans := yarpctest.NewFakeTransport() 114 // outbound that echos the body back 115 out := trans.NewOutbound(nil, yarpctest.OutboundCallOverride( 116 yarpctest.OutboundCallable(func(ctx context.Context, req *transport.Request) (*transport.Response, error) { 117 return &transport.Response{Body: ioutil.NopCloser(req.Body)}, nil 118 }), 119 )) 120 121 client := v2.NewClient(v2.ClientParams{ 122 ClientConfig: &transport.OutboundConfig{ 123 Outbounds: transport.Outbounds{ 124 Unary: out, 125 }, 126 }, 127 Options: []v2.ClientOption{}, 128 }) 129 130 testMessage := &testpb.TestMessage{Value: "foo-bar-baz"} 131 gotMessage, err := client.Call(context.Background(), "", testMessage, newReq) 132 require.NoError(t, err) 133 assert.True(t, proto.Equal(testMessage, gotMessage)) 134 135 }) 136 } 137 138 func TestOutboundWithAnyProtobufMsg(t *testing.T) { 139 t.Run("any message without resolver", func(t *testing.T) { 140 newReq := func() proto.Message { return &anypb.Any{} } 141 trans := yarpctest.NewFakeTransport() 142 // outbound that echos the body back 143 out := trans.NewOutbound(nil, yarpctest.OutboundCallOverride( 144 yarpctest.OutboundCallable(func(ctx context.Context, req *transport.Request) (*transport.Response, error) { 145 return &transport.Response{Body: ioutil.NopCloser(req.Body)}, nil 146 
}), 147 )) 148 149 client := v2.NewClient(v2.ClientParams{ 150 ClientConfig: &transport.OutboundConfig{ 151 Outbounds: transport.Outbounds{ 152 Unary: out, 153 }, 154 }, 155 Options: []v2.ClientOption{}, 156 }) 157 158 testMessage := &testpb.TestMessage{Value: "foo-bar-baz"} 159 anyMsg, err := anypb.New(testMessage) 160 require.NoError(t, err) 161 162 gotMessage, err := client.Call(context.Background(), "", anyMsg, newReq) 163 require.NoError(t, err) 164 returnMsg := &testpb.TestMessage{} 165 anypb.UnmarshalTo(gotMessage.(*anypb.Any), returnMsg, proto.UnmarshalOptions{}) 166 assert.True(t, proto.Equal(testMessage, returnMsg)) 167 168 }) 169 } 170 171 type testAnyResolver struct { 172 NewMessage proto.Message 173 } 174 175 func (r testAnyResolver) FindMessageByName(message protoreflect.FullName) (protoreflect.MessageType, error) { 176 return r.FindMessageByURL(string(message)) 177 } 178 179 func (r testAnyResolver) FindMessageByURL(url string) (protoreflect.MessageType, error) { 180 // Custom resolver for TestMessage resolve with both global registered or custom URL 181 if r.NewMessage != nil { 182 if url == "uber.yarpc.encoding.protobuf.TestMessage" || url == "foo.bar.baz" { 183 return protoimpl.X.MessageTypeOf(r.NewMessage), nil 184 } 185 } 186 return nil, errors.New("test resolver is not initialized") 187 } 188 189 func (r testAnyResolver) FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error) { 190 return nil, nil 191 } 192 193 func (r testAnyResolver) FindExtensionByNumber(message protoreflect.FullName, field protoreflect.FieldNumber) (protoreflect.ExtensionType, error) { 194 return nil, nil 195 }