go.uber.org/yarpc@v1.72.1/encoding/thrift/outbound_nowire_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package thrift
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"encoding/binary"
    27  	"errors"
    28  	"io"
    29  	"testing"
    30  
    31  	"github.com/golang/mock/gomock"
    32  	"github.com/stretchr/testify/assert"
    33  	"github.com/stretchr/testify/require"
    34  	"go.uber.org/thriftrw/protocol"
    35  	tbinary "go.uber.org/thriftrw/protocol/binary"
    36  	"go.uber.org/thriftrw/protocol/stream"
    37  	"go.uber.org/thriftrw/thrifttest"
    38  	"go.uber.org/thriftrw/thrifttest/streamtest"
    39  	"go.uber.org/thriftrw/wire"
    40  	"go.uber.org/yarpc/api/transport"
    41  	"go.uber.org/yarpc/api/transport/transporttest"
    42  	"go.uber.org/yarpc/internal/clientconfig"
    43  	"go.uber.org/yarpc/internal/testtime"
    44  )
    45  
    46  const _response = "response"
    47  
    48  func TestNoWireClientCall(t *testing.T) {
    49  	tests := []struct {
    50  		desc             string
    51  		giveRequestBody  stream.Enveloper
    52  		giveResponseBody string
    53  		clientOptions    []ClientOption
    54  
    55  		expectCall       bool
    56  		wantRequestBody  string
    57  		wantResponseBody string
    58  		wantError        string
    59  	}{
    60  		{
    61  			desc:             "positive case, without enveloping",
    62  			giveRequestBody:  fakeEnveloper(wire.Call),
    63  			giveResponseBody: encodeThriftString(t, _response),
    64  			expectCall:       true,
    65  			wantRequestBody:  encodeThriftString(t, _irrelevant),
    66  			wantResponseBody: _response,
    67  		},
    68  		{
    69  			desc:            "positive case, with enveloping",
    70  			giveRequestBody: fakeEnveloper(wire.Call),
    71  			giveResponseBody: encodeEnvelopeType(t, wire.Reply) +
    72  				encodeThriftString(t, "someMethod") +
    73  				encodeEnvelopeSeqID(t, 1) +
    74  				encodeThriftString(t, _response),
    75  			clientOptions: []ClientOption{Enveloped},
    76  			expectCall:    true,
    77  			wantRequestBody: encodeEnvelopeType(t, wire.Call) +
    78  				encodeThriftString(t, "someMethod") +
    79  				encodeEnvelopeSeqID(t, 1) +
    80  				encodeThriftString(t, _irrelevant),
    81  			wantResponseBody: _response,
    82  		},
    83  		{
    84  			desc:            "unexpected request envelope type",
    85  			giveRequestBody: fakeEnveloper(wire.Exception),
    86  			wantError:       `failed to encode "thrift" request body for procedure "MyService::someMethod" of service "service": unexpected envelope type: Exception`,
    87  		},
    88  		{
    89  			desc:            "response envelope exception (TApplicationException) decoding error",
    90  			giveRequestBody: fakeEnveloper(wire.Call),
    91  			clientOptions:   []ClientOption{Enveloped},
    92  			giveResponseBody: encodeEnvelopeType(t, wire.Exception) +
    93  				encodeThriftString(t, "someMethod") +
    94  				encodeEnvelopeSeqID(t, 1),
    95  			expectCall: true,
    96  			wantRequestBody: encodeEnvelopeType(t, wire.Call) +
    97  				encodeThriftString(t, "someMethod") +
    98  				encodeEnvelopeSeqID(t, 1) +
    99  				encodeThriftString(t, _irrelevant),
   100  			wantResponseBody: _response,
   101  			wantError:        `failed to decode "thrift" response body for procedure "MyService::someMethod" of service "service": unexpected EOF`,
   102  		},
   103  		{
   104  			desc:            "response envelope exception (TApplicationException) error",
   105  			giveRequestBody: fakeEnveloper(wire.Call),
   106  			clientOptions:   []ClientOption{Enveloped},
   107  			giveResponseBody: encodeEnvelopeType(t, wire.Exception) +
   108  				encodeThriftString(t, "someMethod") +
   109  				encodeEnvelopeSeqID(t, 1) +
   110  				encodeThriftString(t, _response),
   111  			expectCall: true,
   112  			wantRequestBody: encodeEnvelopeType(t, wire.Call) +
   113  				encodeThriftString(t, "someMethod") +
   114  				encodeEnvelopeSeqID(t, 1) +
   115  				encodeThriftString(t, _irrelevant),
   116  			wantResponseBody: _response,
   117  			wantError:        "encountered an internal failure: TApplicationException{}",
   118  		},
   119  		{
   120  			desc:            "unexpected response envelope type",
   121  			giveRequestBody: fakeEnveloper(wire.Call),
   122  			giveResponseBody: encodeEnvelopeType(t, wire.Call) +
   123  				encodeThriftString(t, "someMethod") +
   124  				encodeEnvelopeSeqID(t, 1) +
   125  				encodeThriftString(t, _response),
   126  			clientOptions: []ClientOption{Enveloped},
   127  			expectCall:    true,
   128  			wantRequestBody: encodeEnvelopeType(t, wire.Call) +
   129  				encodeThriftString(t, "someMethod") +
   130  				encodeEnvelopeSeqID(t, 1) +
   131  				encodeThriftString(t, _irrelevant),
   132  			wantResponseBody: _response,
   133  			wantError:        `failed to decode "thrift" response body for procedure "MyService::someMethod" of service "service": unexpected envelope type: Call`,
   134  		},
   135  	}
   136  
   137  	// This type aliasing is needed because it's not possible to embed two types
   138  	// with the same name without collision.
   139  	type streamProtocol = stream.Protocol
   140  	type fakeProtocol struct {
   141  		protocol.Protocol
   142  		streamProtocol
   143  	}
   144  
   145  	for _, tt := range tests {
   146  		t.Run(tt.desc, func(t *testing.T) {
   147  			mockCtrl := gomock.NewController(t)
   148  			defer mockCtrl.Finish()
   149  
   150  			sp := streamtest.NewMockProtocol(mockCtrl)
   151  			if tt.wantRequestBody != "" {
   152  				sp.EXPECT().Writer(gomock.Any()).
   153  					DoAndReturn(func(w io.Writer) stream.Writer {
   154  						return tbinary.Default.Writer(w)
   155  					})
   156  			}
   157  
   158  			if tt.wantResponseBody != "" {
   159  				sp.EXPECT().Reader(gomock.Any()).
   160  					DoAndReturn(func(r io.Reader) stream.Reader {
   161  						return tbinary.Default.Reader(r)
   162  					})
   163  			}
   164  
   165  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   166  			defer cancel()
   167  
   168  			trans := transporttest.NewMockUnaryOutbound(mockCtrl)
   169  			if tt.expectCall {
   170  				trans.EXPECT().Call(gomock.Any(),
   171  					transporttest.NewRequestMatcher(t, &transport.Request{
   172  						Caller:    "caller",
   173  						Service:   "service",
   174  						Encoding:  Encoding,
   175  						Procedure: "MyService::someMethod",
   176  						Body:      bytes.NewReader([]byte(tt.wantRequestBody)),
   177  					}),
   178  				).Return(&transport.Response{
   179  					Body: readCloser{bytes.NewReader([]byte(tt.giveResponseBody))},
   180  				}, nil)
   181  			}
   182  
   183  			opts := tt.clientOptions
   184  			opts = append(opts, Protocol(&fakeProtocol{streamProtocol: sp}))
   185  			nwc := NewNoWire(Config{
   186  				Service: "MyService",
   187  				ClientConfig: clientconfig.MultiOutbound("caller", "service",
   188  					transport.Outbounds{
   189  						Unary: trans,
   190  					}),
   191  			}, opts...)
   192  
   193  			br := fakeBodyReader{}
   194  			err := nwc.Call(ctx, tt.giveRequestBody, &br)
   195  			if tt.wantError != "" {
   196  				require.Error(t, err)
   197  				assert.Contains(t, err.Error(), tt.wantError)
   198  				return
   199  			}
   200  
   201  			require.NoError(t, err)
   202  			assert.Equal(t, tt.wantResponseBody, br.body)
   203  		})
   204  	}
   205  }
   206  
   207  func TestNoWireClientOneway(t *testing.T) {
   208  	tests := []struct {
   209  		msg             string
   210  		giveRequestBody stream.Enveloper
   211  		clientOptions   []ClientOption
   212  
   213  		expectCall      bool
   214  		wantRequestBody string
   215  		wantError       string
   216  	}{
   217  		{
   218  			msg:             "positive case, without enveloping",
   219  			giveRequestBody: fakeEnveloper(wire.OneWay),
   220  			expectCall:      true,
   221  			wantRequestBody: encodeThriftString(t, _irrelevant),
   222  		},
   223  		{
   224  			msg:             "positive case, with enveloping",
   225  			giveRequestBody: fakeEnveloper(wire.OneWay),
   226  			clientOptions:   []ClientOption{Enveloped},
   227  			expectCall:      true,
   228  			wantRequestBody: encodeEnvelopeType(t, wire.OneWay) +
   229  				encodeThriftString(t, "someMethod") +
   230  				encodeEnvelopeSeqID(t, 1) +
   231  				encodeThriftString(t, _irrelevant),
   232  		},
   233  		{
   234  			msg:             "unexpected request envelope type",
   235  			giveRequestBody: fakeEnveloper(wire.Exception),
   236  			wantError:       `failed to encode "thrift" request body for procedure "MyService::someMethod" of service "service": unexpected envelope type: Exception`,
   237  		},
   238  		{
   239  			msg:             "oneway call error",
   240  			giveRequestBody: fakeEnveloper(wire.OneWay),
   241  			clientOptions:   []ClientOption{Enveloped},
   242  			expectCall:      true,
   243  			wantRequestBody: encodeEnvelopeType(t, wire.OneWay) +
   244  				encodeThriftString(t, "someMethod") +
   245  				encodeEnvelopeSeqID(t, 1) +
   246  				encodeThriftString(t, _irrelevant),
   247  			wantError: "oneway outbound error",
   248  		},
   249  	}
   250  
   251  	type streamProtocol = stream.Protocol
   252  	type fakeProtocol struct {
   253  		protocol.Protocol
   254  		streamProtocol
   255  	}
   256  
   257  	var _ stream.Protocol = &fakeProtocol{}
   258  
   259  	for _, tt := range tests {
   260  		t.Run(tt.msg, func(t *testing.T) {
   261  			mockCtrl := gomock.NewController(t)
   262  			defer mockCtrl.Finish()
   263  
   264  			sp := streamtest.NewMockProtocol(mockCtrl)
   265  			if tt.wantRequestBody != "" {
   266  				sp.EXPECT().Writer(gomock.Any()).
   267  					DoAndReturn(func(w io.Writer) stream.Writer {
   268  						return tbinary.Default.Writer(w)
   269  					})
   270  			}
   271  
   272  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   273  			defer cancel()
   274  
   275  			oneway := transporttest.NewMockOnewayOutbound(mockCtrl)
   276  			if tt.expectCall {
   277  				if tt.wantError != "" {
   278  					oneway.EXPECT().CallOneway(gomock.Any(),
   279  						transporttest.NewRequestMatcher(t, &transport.Request{
   280  							Caller:    "caller",
   281  							Service:   "service",
   282  							Encoding:  Encoding,
   283  							Procedure: "MyService::someMethod",
   284  							Body:      bytes.NewReader([]byte(tt.wantRequestBody)),
   285  						}),
   286  					).Return(nil, errors.New("oneway outbound error"))
   287  				} else {
   288  					oneway.EXPECT().CallOneway(gomock.Any(),
   289  						transporttest.NewRequestMatcher(t, &transport.Request{
   290  							Caller:    "caller",
   291  							Service:   "service",
   292  							Encoding:  Encoding,
   293  							Procedure: "MyService::someMethod",
   294  							Body:      bytes.NewReader([]byte(tt.wantRequestBody)),
   295  						}),
   296  					).Return(&successAck{}, nil)
   297  				}
   298  			}
   299  
   300  			opts := tt.clientOptions
   301  			opts = append(opts, Protocol(&fakeProtocol{streamProtocol: sp}))
   302  			nwc := NewNoWire(Config{
   303  				Service: "MyService",
   304  				ClientConfig: clientconfig.MultiOutbound("caller", "service",
   305  					transport.Outbounds{
   306  						Oneway: oneway,
   307  					}),
   308  			}, opts...)
   309  
   310  			ack, err := nwc.CallOneway(ctx, tt.giveRequestBody)
   311  			if tt.wantError != "" {
   312  				require.Error(t, err)
   313  				assert.Contains(t, err.Error(), tt.wantError)
   314  				return
   315  			}
   316  
   317  			require.NoError(t, err)
   318  			assert.Equal(t, "success", ack.String())
   319  		})
   320  	}
   321  }
   322  
   323  func TestNoNewWireBadProtocolConfig(t *testing.T) {
   324  	mockCtrl := gomock.NewController(t)
   325  	defer mockCtrl.Finish()
   326  
   327  	proto := thrifttest.NewMockProtocol(mockCtrl)
   328  	assert.Panics(t,
   329  		func() {
   330  			NewNoWire(Config{}, Protocol(proto))
   331  		})
   332  }
   333  
   334  func TestBuildTransportRequestWriteError(t *testing.T) {
   335  	mockCtrl := gomock.NewController(t)
   336  	defer mockCtrl.Finish()
   337  
   338  	sp := streamtest.NewMockProtocol(mockCtrl)
   339  	sw := streamtest.NewMockWriter(mockCtrl)
   340  	sp.EXPECT().Writer(gomock.Any()).Return(sw).AnyTimes()
   341  
   342  	nwc := noWireThriftClient{
   343  		cc:         clientconfig.MultiOutbound("caller", "service", transport.Outbounds{}),
   344  		p:          sp,
   345  		Enveloping: true,
   346  	}
   347  
   348  	wantEnvHeader := stream.EnvelopeHeader{
   349  		Name:  "someMethod",
   350  		Type:  wire.Call,
   351  		SeqID: 1,
   352  	}
   353  
   354  	t.Run("envelope begin", func(t *testing.T) {
   355  		sw.EXPECT().Close().Return(nil)
   356  		sw.EXPECT().WriteEnvelopeBegin(wantEnvHeader).Return(errors.New("writeenvelopebegin error"))
   357  
   358  		_, _, err := nwc.buildTransportRequest(fakeEnveloper(wire.Call))
   359  		require.Error(t, err)
   360  		assert.Contains(t, err.Error(), `failed to encode "thrift" request body for procedure "::someMethod" of service "service": writeenvelopebegin error`)
   361  	})
   362  
   363  	t.Run("encode", func(t *testing.T) {
   364  		sw.EXPECT().Close().Return(nil)
   365  		sw.EXPECT().WriteEnvelopeBegin(wantEnvHeader).Return(nil)
   366  
   367  		_, _, err := nwc.buildTransportRequest(errorEnveloper{envelopeType: wire.Call, err: errors.New("encode error")})
   368  		require.Error(t, err)
   369  		assert.Contains(t, err.Error(), `failed to encode "thrift" request body for procedure "::someMethod" of service "service": encode error`)
   370  	})
   371  
   372  	t.Run("encode", func(t *testing.T) {
   373  		sw.EXPECT().Close().Return(nil)
   374  		sw.EXPECT().WriteEnvelopeBegin(wantEnvHeader).Return(nil)
   375  		sw.EXPECT().WriteString(_irrelevant).Return(nil)
   376  		sw.EXPECT().WriteEnvelopeEnd().Return(errors.New("writeenvelopeend error"))
   377  
   378  		_, _, err := nwc.buildTransportRequest(fakeEnveloper(wire.Call))
   379  		require.Error(t, err)
   380  		assert.Contains(t, err.Error(), `failed to encode "thrift" request body for procedure "::someMethod" of service "service": writeenvelopeend error`)
   381  	})
   382  }
   383  
   384  // encodeThriftString prefixes the passed in string with an int32 that contains
   385  // the length of the string, compliant to the Thrift protocol.
   386  func encodeThriftString(t *testing.T, s string) string {
   387  	t.Helper()
   388  
   389  	buf := make([]byte, 4)
   390  	binary.BigEndian.PutUint32(buf, uint32(len(s)))
   391  	return string(buf) + s
   392  }
   393  
   394  func encodeEnvelopeSeqID(t *testing.T, seqID int) string {
   395  	t.Helper()
   396  
   397  	buf := make([]byte, 4)
   398  	binary.BigEndian.PutUint32(buf, uint32(seqID))
   399  	return string(buf)
   400  }
   401  
   402  func encodeEnvelopeType(t *testing.T, et wire.EnvelopeType) string {
   403  	t.Helper()
   404  
   405  	buf := make([]byte, 4)
   406  	version := uint32(0x80010000) | uint32(et)
   407  	binary.BigEndian.PutUint32(buf, version)
   408  	return string(buf)
   409  }