go.uber.org/yarpc@v1.72.1/encoding/thrift/outbound_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package thrift
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"errors"
    27  	"io"
    28  	"io/ioutil"
    29  	"testing"
    30  
    31  	"github.com/golang/mock/gomock"
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"go.uber.org/thriftrw/envelope"
    35  	"go.uber.org/thriftrw/thrifttest"
    36  	"go.uber.org/thriftrw/wire"
    37  	"go.uber.org/yarpc/api/transport"
    38  	"go.uber.org/yarpc/api/transport/transporttest"
    39  	"go.uber.org/yarpc/internal/clientconfig"
    40  	"go.uber.org/yarpc/internal/testtime"
    41  	"go.uber.org/yarpc/pkg/procedure"
    42  )
    43  
    44  func valueptr(v wire.Value) *wire.Value { return &v }
    45  
    46  func TestClient(t *testing.T) {
    47  	tests := []struct {
    48  		desc                 string
    49  		giveRequestBody      envelope.Enveloper // outgoing request body
    50  		giveResponseEnvelope *wire.Envelope     // returned on DecodeEnveloped()
    51  		giveResponseBody     *wire.Value        // return on Decode()
    52  		clientOptions        []ClientOption
    53  
    54  		expectCall          bool           // whether outbound.Call is expected
    55  		wantRequestEnvelope *wire.Envelope // expect EncodeEnveloped(x)
    56  		wantRequestBody     *wire.Value    // expect Encode(x)
    57  		wantError           string         // whether an error is expected
    58  
    59  		responseBody io.ReadCloser
    60  	}{
    61  		{
    62  			desc:            "happy case",
    63  			clientOptions:   []ClientOption{Enveloped},
    64  			giveRequestBody: fakeEnveloper(wire.Call),
    65  			wantRequestEnvelope: &wire.Envelope{
    66  				Name:  "someMethod",
    67  				SeqID: 1,
    68  				Type:  wire.Call,
    69  				Value: wire.NewValueStruct(wire.Struct{}),
    70  			},
    71  			expectCall: true,
    72  			giveResponseEnvelope: &wire.Envelope{
    73  				Name:  "someMethod",
    74  				SeqID: 1,
    75  				Type:  wire.Reply,
    76  				Value: wire.NewValueStruct(wire.Struct{}),
    77  			},
    78  			responseBody: readCloser{bytes.NewReader([]byte("irrelevant"))},
    79  		},
    80  		{
    81  			desc:             "happy case without enveloping",
    82  			giveRequestBody:  fakeEnveloper(wire.Call),
    83  			wantRequestBody:  valueptr(wire.NewValueStruct(wire.Struct{})),
    84  			expectCall:       true,
    85  			giveResponseBody: valueptr(wire.NewValueStruct(wire.Struct{})),
    86  			responseBody:     readCloser{bytes.NewReader([]byte("irrelevant"))},
    87  		},
    88  		{
    89  			desc:            "wrong envelope type for request",
    90  			clientOptions:   []ClientOption{Enveloped},
    91  			giveRequestBody: fakeEnveloper(wire.Reply),
    92  			wantError: `failed to encode "thrift" request body for procedure ` +
    93  				`"MyService::someMethod" of service "service": unexpected envelope type: Reply`,
    94  			responseBody: readCloser{bytes.NewReader([]byte("irrelevant"))},
    95  		},
    96  		{
    97  			desc:            "TApplicationException",
    98  			clientOptions:   []ClientOption{Enveloped},
    99  			giveRequestBody: fakeEnveloper(wire.Call),
   100  			wantRequestEnvelope: &wire.Envelope{
   101  				Name:  "someMethod",
   102  				SeqID: 1,
   103  				Type:  wire.Call,
   104  				Value: wire.NewValueStruct(wire.Struct{}),
   105  			},
   106  			expectCall: true,
   107  			giveResponseEnvelope: &wire.Envelope{
   108  				Name:  "someMethod",
   109  				SeqID: 1,
   110  				Type:  wire.Exception,
   111  				Value: wire.NewValueStruct(wire.Struct{Fields: []wire.Field{
   112  					{ID: 1, Value: wire.NewValueString("great sadness")},
   113  					{ID: 2, Value: wire.NewValueI32(7)},
   114  				}}),
   115  			},
   116  			wantError: `thrift request to procedure "MyService::someMethod" of ` +
   117  				`service "service" encountered an internal failure: ` +
   118  				"TApplicationException{Message: great sadness, Type: PROTOCOL_ERROR}",
   119  			responseBody: readCloser{bytes.NewReader([]byte("irrelevant"))},
   120  		},
   121  		{
   122  			desc:            "wrong envelope type for response",
   123  			clientOptions:   []ClientOption{Enveloped},
   124  			giveRequestBody: fakeEnveloper(wire.Call),
   125  			wantRequestEnvelope: &wire.Envelope{
   126  				Name:  "someMethod",
   127  				SeqID: 1,
   128  				Type:  wire.Call,
   129  				Value: wire.NewValueStruct(wire.Struct{}),
   130  			},
   131  			expectCall: true,
   132  			giveResponseEnvelope: &wire.Envelope{
   133  				Name:  "someMethod",
   134  				SeqID: 1,
   135  				Type:  wire.Call,
   136  				Value: wire.NewValueStruct(wire.Struct{}),
   137  			},
   138  			wantError: `failed to decode "thrift" response body for procedure ` +
   139  				`"MyService::someMethod" of service "service": unexpected envelope type: Call`,
   140  			responseBody: ioutil.NopCloser(bytes.NewReader([]byte("irrelevant"))),
   141  		},
   142  	}
   143  
   144  	for _, tt := range tests {
   145  		mockCtrl := gomock.NewController(t)
   146  		defer mockCtrl.Finish()
   147  
   148  		proto := thrifttest.NewMockProtocol(mockCtrl)
   149  
   150  		if tt.wantRequestEnvelope != nil {
   151  			proto.EXPECT().EncodeEnveloped(*tt.wantRequestEnvelope, gomock.Any()).
   152  				Do(func(_ wire.Envelope, w io.Writer) {
   153  					_, err := w.Write([]byte("irrelevant"))
   154  					require.NoError(t, err, "Write() failed")
   155  				}).Return(nil)
   156  		}
   157  
   158  		if tt.wantRequestBody != nil {
   159  			proto.EXPECT().Encode(*tt.wantRequestBody, gomock.Any()).
   160  				Do(func(_ wire.Value, w io.Writer) {
   161  					_, err := w.Write([]byte("irrelevant"))
   162  					require.NoError(t, err, "Write() failed")
   163  				}).Return(nil)
   164  		}
   165  
   166  		ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   167  		defer cancel()
   168  
   169  		trans := transporttest.NewMockUnaryOutbound(mockCtrl)
   170  		if tt.expectCall {
   171  			trans.EXPECT().Call(ctx,
   172  				transporttest.NewRequestMatcher(t, &transport.Request{
   173  					Caller:    "caller",
   174  					Service:   "service",
   175  					Encoding:  Encoding,
   176  					Procedure: "MyService::someMethod",
   177  					Body:      bytes.NewReader([]byte("irrelevant")),
   178  				}),
   179  			).Return(&transport.Response{
   180  				Body: tt.responseBody,
   181  			}, nil)
   182  		}
   183  
   184  		if tt.giveResponseEnvelope != nil {
   185  			proto.EXPECT().DecodeEnveloped(gomock.Any()).Return(*tt.giveResponseEnvelope, nil)
   186  		}
   187  
   188  		if tt.giveResponseBody != nil {
   189  			proto.EXPECT().Decode(gomock.Any(), wire.TStruct).Return(*tt.giveResponseBody, nil)
   190  		}
   191  
   192  		opts := tt.clientOptions
   193  		opts = append(opts, Protocol(proto))
   194  		c := New(Config{
   195  			Service: "MyService",
   196  			ClientConfig: clientconfig.MultiOutbound("caller", "service",
   197  				transport.Outbounds{
   198  					Unary: trans,
   199  				}),
   200  		}, opts...)
   201  
   202  		_, err := c.Call(ctx, tt.giveRequestBody)
   203  		if tt.wantError != "" {
   204  			if assert.Error(t, err, "%v: expected failure", tt.desc) {
   205  				assert.Contains(t, err.Error(), tt.wantError, "%v: error mismatch", tt.desc)
   206  			}
   207  		} else {
   208  			assert.NoError(t, err, "%v: expected success", tt.desc)
   209  		}
   210  	}
   211  }
   212  
   213  type successAck struct{}
   214  
   215  func (a successAck) String() string {
   216  	return "success"
   217  }
   218  
   219  func TestClientOneway(t *testing.T) {
   220  	caller, service, procedureName := "caller", "MyService", "someMethod"
   221  
   222  	tests := []struct {
   223  		desc            string
   224  		giveRequestBody envelope.Enveloper // outgoing request body
   225  		clientOptions   []ClientOption
   226  
   227  		expectCall          bool           // whether outbound.Call is expected
   228  		wantRequestEnvelope *wire.Envelope // expect EncodeEnveloped(x)
   229  		wantRequestBody     *wire.Value    // expect Encode(x)
   230  		wantError           string         // whether an error is expected
   231  	}{
   232  		{
   233  			desc:            "happy case",
   234  			giveRequestBody: fakeEnveloper(wire.Call),
   235  			clientOptions:   []ClientOption{Enveloped},
   236  
   237  			expectCall: true,
   238  			wantRequestEnvelope: &wire.Envelope{
   239  				Name:  procedureName,
   240  				SeqID: 1,
   241  				Type:  wire.Call,
   242  				Value: wire.NewValueStruct(wire.Struct{}),
   243  			},
   244  		},
   245  		{
   246  			desc:            "happy case without enveloping",
   247  			giveRequestBody: fakeEnveloper(wire.Call),
   248  
   249  			expectCall:      true,
   250  			wantRequestBody: valueptr(wire.NewValueStruct(wire.Struct{})),
   251  		},
   252  		{
   253  			desc:            "wrong envelope type for request",
   254  			giveRequestBody: fakeEnveloper(wire.Reply),
   255  			clientOptions:   []ClientOption{Enveloped},
   256  
   257  			wantError: `failed to encode "thrift" request body for procedure ` +
   258  				`"MyService::someMethod" of service "MyService": unexpected envelope type: Reply`,
   259  		},
   260  	}
   261  
   262  	for _, tt := range tests {
   263  		mockCtrl := gomock.NewController(t)
   264  		defer mockCtrl.Finish()
   265  
   266  		proto := thrifttest.NewMockProtocol(mockCtrl)
   267  		bodyBytes := []byte("irrelevant")
   268  
   269  		if tt.wantRequestEnvelope != nil {
   270  			proto.EXPECT().EncodeEnveloped(*tt.wantRequestEnvelope, gomock.Any()).
   271  				Do(func(_ wire.Envelope, w io.Writer) {
   272  					_, err := w.Write(bodyBytes)
   273  					require.NoError(t, err, "Write() failed")
   274  				}).Return(nil)
   275  		}
   276  
   277  		if tt.wantRequestBody != nil {
   278  			proto.EXPECT().Encode(*tt.wantRequestBody, gomock.Any()).
   279  				Do(func(_ wire.Value, w io.Writer) {
   280  					_, err := w.Write(bodyBytes)
   281  					require.NoError(t, err, "Write() failed")
   282  				}).Return(nil)
   283  		}
   284  
   285  		ctx := context.Background()
   286  
   287  		onewayOutbound := transporttest.NewMockOnewayOutbound(mockCtrl)
   288  
   289  		requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{
   290  			Caller:    caller,
   291  			Service:   service,
   292  			Encoding:  Encoding,
   293  			Procedure: procedure.ToName(service, procedureName),
   294  			Body:      bytes.NewReader(bodyBytes),
   295  		})
   296  
   297  		if tt.expectCall {
   298  			if tt.wantError != "" {
   299  				onewayOutbound.
   300  					EXPECT().
   301  					CallOneway(ctx, requestMatcher).
   302  					Return(nil, errors.New(tt.wantError))
   303  			} else {
   304  				onewayOutbound.
   305  					EXPECT().
   306  					CallOneway(ctx, requestMatcher).
   307  					Return(&successAck{}, nil)
   308  			}
   309  		}
   310  		opts := tt.clientOptions
   311  		opts = append(opts, Protocol(proto))
   312  
   313  		c := New(Config{
   314  			Service: service,
   315  			ClientConfig: clientconfig.MultiOutbound(caller, service,
   316  				transport.Outbounds{
   317  					Oneway: onewayOutbound,
   318  				}),
   319  		}, opts...)
   320  
   321  		ack, err := c.CallOneway(ctx, tt.giveRequestBody)
   322  		if tt.wantError != "" {
   323  			if assert.Error(t, err, "%v: expected failure", tt.desc) {
   324  				assert.Contains(t, err.Error(), tt.wantError, "%v: error mismatch", tt.desc)
   325  			}
   326  		} else {
   327  			assert.NoError(t, err, "%v: expected success", tt.desc)
   328  			assert.Equal(t, "success", ack.String())
   329  		}
   330  	}
   331  }
   332  
   333  type readCloser struct {
   334  	*bytes.Reader
   335  }
   336  
   337  func (r readCloser) Close() error { return nil }