go.uber.org/yarpc@v1.72.1/encoding/thrift/inbound_nowire_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package thrift
    22  
    23  import (
    24  	"context"
    25  	"fmt"
    26  	"testing"
    27  
    28  	"github.com/golang/mock/gomock"
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  	"go.uber.org/thriftrw/protocol/binary"
    32  	"go.uber.org/thriftrw/protocol/stream"
    33  	"go.uber.org/thriftrw/thrifttest/streamtest"
    34  	"go.uber.org/thriftrw/wire"
    35  	"go.uber.org/yarpc/api/transport/transporttest"
    36  	"go.uber.org/yarpc/internal/testtime"
    37  )
    38  
    39  const _body = "decoded"
    40  
    41  type bodyReader struct {
    42  	sr   stream.Reader
    43  	body string
    44  	err  error
    45  }
    46  
    47  func (br *bodyReader) Decode(sr stream.Reader) error {
    48  	br.sr = sr
    49  	br.body = _body
    50  	return br.err
    51  }
    52  
    53  type responseHandler struct {
    54  	t   *testing.T
    55  	nwc *NoWireCall
    56  
    57  	reqBody  stream.BodyReader
    58  	body     stream.Enveloper
    59  	appError bool
    60  }
    61  
    62  var _ NoWireHandler = (*responseHandler)(nil)
    63  
    64  func (rh *responseHandler) HandleNoWire(ctx context.Context, nwc *NoWireCall) (NoWireResponse, error) {
    65  	rh.t.Helper()
    66  
    67  	// All calls to Handle must have everything in a NoWireCall set.
    68  	require.NotNil(rh.t, nwc)
    69  	assert.NotNil(rh.t, nwc.Reader)
    70  	assert.NotNil(rh.t, nwc.RequestReader)
    71  	assert.NotNil(rh.t, nwc.EnvelopeType)
    72  
    73  	rh.nwc = nwc
    74  	rw, err := nwc.RequestReader.ReadRequest(ctx, nwc.EnvelopeType, nwc.Reader, rh.reqBody)
    75  	return NoWireResponse{
    76  		Body:               rh.body,
    77  		ResponseWriter:     rw,
    78  		IsApplicationError: rh.appError,
    79  	}, err
    80  }
    81  
    82  func TestDecodeNoWireRequestUnary(t *testing.T) {
    83  	mockCtrl := gomock.NewController(t)
    84  	defer mockCtrl.Finish()
    85  
    86  	env := streamtest.NewMockEnveloper(mockCtrl)
    87  	env.EXPECT().EnvelopeType().Return(wire.Reply).Times(1)
    88  	env.EXPECT().Encode(gomock.Any()).Return(nil).Times(1)
    89  
    90  	rh := responseHandler{
    91  		t:       t,
    92  		reqBody: &bodyReader{},
    93  		body:    env,
    94  	}
    95  	proto := binary.Default
    96  	h := thriftNoWireHandler{
    97  		Handler:       &rh,
    98  		RequestReader: proto,
    99  	}
   100  
   101  	ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   102  	defer cancel()
   103  
   104  	req := request()
   105  	rw := new(transporttest.FakeResponseWriter)
   106  	require.NoError(t, h.Handle(ctx, req, rw))
   107  	assert.Equal(t, req.Body, rh.nwc.Reader)
   108  	assert.Equal(t, proto, rh.nwc.RequestReader)
   109  	assert.Equal(t, wire.Call, rh.nwc.EnvelopeType) // Unary call
   110  }
   111  
   112  func TestDecodeNoWireRequestOneway(t *testing.T) {
   113  	mockCtrl := gomock.NewController(t)
   114  	defer mockCtrl.Finish()
   115  
   116  	// OneWay calls have no calls to the response body
   117  	env := streamtest.NewMockEnveloper(mockCtrl)
   118  	rh := responseHandler{
   119  		t:       t,
   120  		reqBody: &bodyReader{},
   121  		body:    env,
   122  	}
   123  	proto := binary.Default
   124  	h := thriftNoWireHandler{
   125  		Handler:       &rh,
   126  		RequestReader: proto,
   127  	}
   128  
   129  	ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   130  	defer cancel()
   131  
   132  	req := request()
   133  	require.NoError(t, h.HandleOneway(ctx, req))
   134  	assert.Equal(t, req.Body, rh.nwc.Reader)
   135  	assert.Equal(t, proto, rh.nwc.RequestReader)
   136  	assert.Equal(t, wire.OneWay, rh.nwc.EnvelopeType) // OneWay call
   137  }
   138  
   139  func TestNoWireHandleIncorrectResponseEnvelope(t *testing.T) {
   140  	mockCtrl := gomock.NewController(t)
   141  	defer mockCtrl.Finish()
   142  
   143  	env := streamtest.NewMockEnveloper(mockCtrl)
   144  	env.EXPECT().EnvelopeType().Return(wire.Exception).Times(1)
   145  
   146  	br := &bodyReader{}
   147  	rh := responseHandler{
   148  		t:       t,
   149  		reqBody: br,
   150  		body:    env,
   151  	}
   152  	proto := binary.Default
   153  	h := thriftNoWireHandler{
   154  		Handler:       &rh,
   155  		RequestReader: proto,
   156  	}
   157  
   158  	ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   159  	defer cancel()
   160  
   161  	req := request()
   162  	rw := new(transporttest.FakeResponseWriter)
   163  	err := h.Handle(ctx, req, rw)
   164  	require.Error(t, err)
   165  	assert.Contains(t, err.Error(), "unexpected envelope type")
   166  }
   167  
   168  func TestNoWireHandleWriteResponseError(t *testing.T) {
   169  	mockCtrl := gomock.NewController(t)
   170  	defer mockCtrl.Finish()
   171  
   172  	env := streamtest.NewMockEnveloper(mockCtrl)
   173  	env.EXPECT().EnvelopeType().Return(wire.Reply).Times(1)
   174  
   175  	rw := new(transporttest.FakeResponseWriter)
   176  	streamRw := streamtest.NewMockResponseWriter(mockCtrl)
   177  	streamRw.EXPECT().WriteResponse(wire.Reply, rw, env).Return(fmt.Errorf("write response error")).Times(1)
   178  
   179  	req := request()
   180  	br := &bodyReader{}
   181  	proto := streamtest.NewMockRequestReader(mockCtrl)
   182  	proto.EXPECT().ReadRequest(gomock.Any(), wire.Call, req.Body, br).
   183  		Return(streamRw, nil)
   184  
   185  	rh := responseHandler{t: t, reqBody: br, body: env}
   186  	h := thriftNoWireHandler{
   187  		Handler:       &rh,
   188  		RequestReader: proto,
   189  	}
   190  
   191  	ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   192  	defer cancel()
   193  
   194  	err := h.Handle(ctx, req, rw)
   195  	require.Error(t, err)
   196  	assert.Contains(t, err.Error(), "write response error")
   197  }
   198  
   199  func TestDecodeNoWireRequestExpectEncodingsError(t *testing.T) {
   200  	mockCtrl := gomock.NewController(t)
   201  	defer mockCtrl.Finish()
   202  
   203  	// incorrect encoding in response should result in no calls to the response body
   204  	env := streamtest.NewMockEnveloper(mockCtrl)
   205  	h := thriftNoWireHandler{
   206  		Handler:       &responseHandler{t: t, body: env},
   207  		RequestReader: binary.Default,
   208  	}
   209  
   210  	ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   211  	defer cancel()
   212  
   213  	req := request()
   214  	req.Encoding = "grpc"
   215  
   216  	rw := new(transporttest.FakeResponseWriter)
   217  	err := h.Handle(ctx, req, rw)
   218  	require.Error(t, err)
   219  	assert.Contains(t, err.Error(), `expected encoding "thrift" but got "grpc"`)
   220  }
   221  
   222  func TestDecodeNoWireAppliationError(t *testing.T) {
   223  	mockCtrl := gomock.NewController(t)
   224  	defer mockCtrl.Finish()
   225  
   226  	env := streamtest.NewMockEnveloper(mockCtrl)
   227  	env.EXPECT().EnvelopeType().Return(wire.Reply).Times(1)
   228  	env.EXPECT().Encode(gomock.Any()).Return(nil).Times(1)
   229  
   230  	br := &bodyReader{}
   231  	h := thriftNoWireHandler{
   232  		Handler: &responseHandler{
   233  			t:        t,
   234  			reqBody:  br,
   235  			body:     env,
   236  			appError: true,
   237  		},
   238  		RequestReader: binary.Default,
   239  	}
   240  
   241  	ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   242  	defer cancel()
   243  
   244  	req := request()
   245  	rw := new(transporttest.FakeResponseWriter)
   246  	require.NoError(t, h.Handle(ctx, req, rw))
   247  	assert.True(t, rw.IsApplicationError)
   248  }