go.uber.org/yarpc@v1.72.1/internal/inboundmiddleware/chain_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package inboundmiddleware
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"errors"
    27  	"testing"
    28  
    29  	"github.com/golang/mock/gomock"
    30  	"github.com/stretchr/testify/assert"
    31  	"github.com/stretchr/testify/require"
    32  	"go.uber.org/yarpc/api/middleware"
    33  	"go.uber.org/yarpc/api/transport"
    34  	"go.uber.org/yarpc/api/transport/transporttest"
    35  	"go.uber.org/yarpc/internal/testtime"
    36  )
    37  
    38  type countInboundMiddleware struct{ Count int }
    39  
    40  func (c *countInboundMiddleware) Handle(
    41  	ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error {
    42  	c.Count++
    43  	return h.Handle(ctx, req, resw)
    44  }
    45  
    46  func (c *countInboundMiddleware) HandleOneway(ctx context.Context, req *transport.Request, h transport.OnewayHandler) error {
    47  	c.Count++
    48  	return h.HandleOneway(ctx, req)
    49  }
    50  
    51  func (c *countInboundMiddleware) HandleStream(s *transport.ServerStream, h transport.StreamHandler) error {
    52  	c.Count++
    53  	return h.HandleStream(s)
    54  }
    55  
    56  var retryUnaryInbound middleware.UnaryInboundFunc = func(
    57  	ctx context.Context, req *transport.Request, resw transport.ResponseWriter, h transport.UnaryHandler) error {
    58  	if err := h.Handle(ctx, req, resw); err != nil {
    59  		return h.Handle(ctx, req, resw)
    60  	}
    61  	return nil
    62  }
    63  
    64  func TestUnaryChain(t *testing.T) {
    65  	before := &countInboundMiddleware{}
    66  	after := &countInboundMiddleware{}
    67  
    68  	tests := []struct {
    69  		desc string
    70  		mw   middleware.UnaryInbound
    71  	}{
    72  		{"flat chain", UnaryChain(before, retryUnaryInbound, after, nil)},
    73  		{"nested chain", UnaryChain(before, UnaryChain(retryUnaryInbound, nil, after))},
    74  	}
    75  
    76  	for _, tt := range tests {
    77  		t.Run(tt.desc, func(t *testing.T) {
    78  			before.Count, after.Count = 0, 0
    79  			mockCtrl := gomock.NewController(t)
    80  			defer mockCtrl.Finish()
    81  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
    82  			defer cancel()
    83  
    84  			req := &transport.Request{
    85  				Caller:    "somecaller",
    86  				Service:   "someservice",
    87  				Encoding:  transport.Encoding("raw"),
    88  				Procedure: "hello",
    89  				Body:      bytes.NewReader([]byte{1, 2, 3}),
    90  			}
    91  			resw := new(transporttest.FakeResponseWriter)
    92  			h := transporttest.NewMockUnaryHandler(mockCtrl)
    93  			h.EXPECT().Handle(ctx, req, resw).After(
    94  				h.EXPECT().Handle(ctx, req, resw).Return(errors.New("great sadness")),
    95  			).Return(nil)
    96  
    97  			err := middleware.ApplyUnaryInbound(h, tt.mw).Handle(ctx, req, resw)
    98  
    99  			assert.NoError(t, err, "expected success")
   100  			assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once")
   101  			assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice")
   102  		})
   103  	}
   104  }
   105  
   106  var retryOnewayInbound middleware.OnewayInboundFunc = func(
   107  	ctx context.Context, req *transport.Request, h transport.OnewayHandler) error {
   108  	if err := h.HandleOneway(ctx, req); err != nil {
   109  		return h.HandleOneway(ctx, req)
   110  	}
   111  	return nil
   112  }
   113  
   114  func TestOnewayChain(t *testing.T) {
   115  	before := &countInboundMiddleware{}
   116  	after := &countInboundMiddleware{}
   117  
   118  	tests := []struct {
   119  		desc string
   120  		mw   middleware.OnewayInbound
   121  	}{
   122  		{"flat chain", OnewayChain(before, retryOnewayInbound, after, nil)},
   123  		{"nested chain", OnewayChain(before, OnewayChain(retryOnewayInbound, nil, after))},
   124  	}
   125  
   126  	for _, tt := range tests {
   127  		t.Run(tt.desc, func(t *testing.T) {
   128  			before.Count, after.Count = 0, 0
   129  			mockCtrl := gomock.NewController(t)
   130  			defer mockCtrl.Finish()
   131  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   132  			defer cancel()
   133  
   134  			req := &transport.Request{
   135  				Caller:    "somecaller",
   136  				Service:   "someservice",
   137  				Encoding:  transport.Encoding("raw"),
   138  				Procedure: "hello",
   139  				Body:      bytes.NewReader([]byte{1, 2, 3}),
   140  			}
   141  			h := transporttest.NewMockOnewayHandler(mockCtrl)
   142  			h.EXPECT().HandleOneway(ctx, req).After(
   143  				h.EXPECT().HandleOneway(ctx, req).Return(errors.New("great sadness")),
   144  			).Return(nil)
   145  
   146  			err := middleware.ApplyOnewayInbound(h, tt.mw).HandleOneway(ctx, req)
   147  
   148  			assert.NoError(t, err, "expected success")
   149  			assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once")
   150  			assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice")
   151  		})
   152  	}
   153  }
   154  
   155  var retryStreamInbound middleware.StreamInboundFunc = func(
   156  	s *transport.ServerStream, h transport.StreamHandler) error {
   157  	if err := h.HandleStream(s); err != nil {
   158  		return h.HandleStream(s)
   159  	}
   160  	return nil
   161  }
   162  
   163  func TestStreamChain(t *testing.T) {
   164  	before := &countInboundMiddleware{}
   165  	after := &countInboundMiddleware{}
   166  
   167  	tests := []struct {
   168  		desc string
   169  		mw   middleware.StreamInbound
   170  	}{
   171  		{"flat chain", StreamChain(before, retryStreamInbound, after, nil)},
   172  		{"nested chain", StreamChain(before, StreamChain(retryStreamInbound, nil, after))},
   173  		{"single chain", StreamChain(StreamChain(before), retryStreamInbound, StreamChain(after), StreamChain())},
   174  	}
   175  
   176  	for _, tt := range tests {
   177  		t.Run(tt.desc, func(t *testing.T) {
   178  			before.Count, after.Count = 0, 0
   179  			mockCtrl := gomock.NewController(t)
   180  			defer mockCtrl.Finish()
   181  			s, err := transport.NewServerStream(transporttest.NewMockStream(mockCtrl))
   182  			require.NoError(t, err)
   183  
   184  			h := transporttest.NewMockStreamHandler(mockCtrl)
   185  			h.EXPECT().HandleStream(s).After(
   186  				h.EXPECT().HandleStream(s).Return(errors.New("great sadness")),
   187  			).Return(nil)
   188  
   189  			err = middleware.ApplyStreamInbound(h, tt.mw).HandleStream(s)
   190  
   191  			assert.NoError(t, err, "expected success")
   192  			assert.Equal(t, 1, before.Count, "expected outer inbound middleware to be called once")
   193  			assert.Equal(t, 2, after.Count, "expected inner inbound middleware to be called twice")
   194  		})
   195  	}
   196  }