go.uber.org/yarpc@v1.72.1/internal/observability/ctx_middleware_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package observability
    22  
    23  import (
    24  	"context"
    25  	"errors"
    26  	"fmt"
    27  	"testing"
    28  
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  	"go.uber.org/yarpc/api/transport"
    32  	"go.uber.org/yarpc/api/transport/transporttest"
    33  	"go.uber.org/yarpc/yarpcerrors"
    34  	"go.uber.org/zap"
    35  	"go.uber.org/zap/zapcore"
    36  	"go.uber.org/zap/zaptest/observer"
    37  )
    38  
    39  func TestContextMiddleware(t *testing.T) {
    40  	const (
    41  		ctxDeadlineExceededMsg = `call to procedure "my-procedure" of service "my-service" from caller "my-caller" timed out`
    42  		ctxCancelledMsg        = `call to procedure "my-procedure" of service "my-service" from caller "my-caller" was canceled`
    43  	)
    44  
    45  	core, logs := observer.New(zapcore.DebugLevel)
    46  	infoLevel := zapcore.InfoLevel
    47  	mw := NewMiddleware(Config{
    48  		Logger:           zap.New(core),
    49  		ContextExtractor: NewNopContextExtractor(),
    50  		Levels: LevelsConfig{
    51  			Default: DirectionalLevelsConfig{
    52  				Success:          &infoLevel,
    53  				ApplicationError: &infoLevel,
    54  				Failure:          &infoLevel,
    55  			},
    56  		},
    57  	})
    58  
    59  	tests := []struct {
    60  		name       string
    61  		handlerErr error
    62  		appErr     bool
    63  		ctx        func() context.Context
    64  
    65  		wantDeadlineExceeded bool
    66  		wantCtxCancelled     bool
    67  	}{
    68  		{
    69  			name: "no-op/handler success",
    70  			ctx:  func() context.Context { return context.Background() },
    71  		},
    72  		{
    73  			name:       "no-op/handler err",
    74  			handlerErr: errors.New("an err"),
    75  			ctx:        func() context.Context { return context.Background() },
    76  		},
    77  		{
    78  			name: "deadline exceeded/handler success",
    79  			ctx: func() context.Context {
    80  				ctx, cancel := context.WithTimeout(context.Background(), -1)
    81  				cancel()
    82  				return ctx
    83  			},
    84  			wantDeadlineExceeded: true,
    85  		},
    86  		{
    87  			name:       "deadline exceeded/handler err",
    88  			handlerErr: fmt.Errorf("my custom error"),
    89  			ctx: func() context.Context {
    90  				ctx, cancel := context.WithTimeout(context.Background(), -1)
    91  				cancel()
    92  				return ctx
    93  			},
    94  			wantDeadlineExceeded: true,
    95  		},
    96  		{
    97  			name: "deadline exceeded/app err",
    98  			ctx: func() context.Context {
    99  				ctx, cancel := context.WithTimeout(context.Background(), -1)
   100  				cancel()
   101  				return ctx
   102  			},
   103  			appErr:               true,
   104  			wantDeadlineExceeded: true,
   105  		},
   106  		{
   107  			name: "cancelled error/handler success",
   108  			ctx: func() context.Context {
   109  				ctx, cancel := context.WithCancel(context.Background())
   110  				cancel()
   111  				return ctx
   112  			},
   113  			wantCtxCancelled: true,
   114  		},
   115  		{
   116  			name:       "cancelled error/handler err",
   117  			handlerErr: fmt.Errorf("my custom error"),
   118  			ctx: func() context.Context {
   119  				ctx, cancel := context.WithCancel(context.Background())
   120  				cancel()
   121  				return ctx
   122  			},
   123  			wantCtxCancelled: true,
   124  		},
   125  		{
   126  			name: "cancelled error/app err",
   127  			ctx: func() context.Context {
   128  				ctx, cancel := context.WithCancel(context.Background())
   129  				cancel()
   130  				return ctx
   131  			},
   132  			appErr:           true,
   133  			wantCtxCancelled: true,
   134  		},
   135  	}
   136  
   137  	req := &transport.Request{
   138  		Service:   "my-service",
   139  		Procedure: "my-procedure",
   140  		Caller:    "my-caller",
   141  	}
   142  
   143  	expectLogField := func(appErr bool, err error) *zap.Field {
   144  		dropMsg := _droppedSuccessLog
   145  		if err == nil && appErr {
   146  			dropMsg = _droppedAppErrLog
   147  		} else if err != nil {
   148  			dropMsg = fmt.Sprintf(_droppedErrLogFmt, err)
   149  		}
   150  		log := zap.String(_dropped, dropMsg)
   151  		return &log
   152  	}
   153  
   154  	getDropLogField := func(t *testing.T) *zap.Field {
   155  		entries := logs.TakeAll()
   156  		require.Equal(t, 1, len(entries), "unexpected number of logs written: %v", entries)
   157  		for _, f := range entries[0].Context {
   158  			if f.Key == _dropped {
   159  				return &f
   160  			}
   161  		}
   162  		return nil
   163  	}
   164  
   165  	for _, tt := range tests {
   166  		t.Run(tt.name, func(t *testing.T) {
   167  			defer logs.TakeAll() // throw away logs for next run
   168  
   169  			handler := &testHandler{err: tt.handlerErr, appErr: tt.appErr}
   170  			err := mw.Handle(tt.ctx(), req, &transporttest.FakeResponseWriter{}, handler)
   171  
   172  			if tt.wantDeadlineExceeded {
   173  				assert.EqualError(t,
   174  					err,
   175  					yarpcerrors.DeadlineExceededErrorf(ctxDeadlineExceededMsg).Error(),
   176  					"expected deadline exceeded error override")
   177  
   178  				assert.Equal(t, expectLogField(tt.appErr, tt.handlerErr), getDropLogField(t), "unexpected log")
   179  				return
   180  			}
   181  
   182  			if tt.wantCtxCancelled {
   183  				assert.EqualError(t,
   184  					err,
   185  					yarpcerrors.CancelledErrorf(ctxCancelledMsg).Error(),
   186  					"expected cancelled yarpcerror code")
   187  
   188  				assert.Equal(t, expectLogField(tt.appErr, tt.handlerErr), getDropLogField(t), "unexpected log")
   189  				return
   190  			}
   191  
   192  			assert.Equal(t, tt.handlerErr, err, "unexpected error")
   193  			assert.Nil(t, getDropLogField(t), "unexpectedly saw 'dropped' log field")
   194  		})
   195  	}
   196  }
   197  
   198  type testHandler struct {
   199  	err    error
   200  	appErr bool
   201  }
   202  
   203  func (h *testHandler) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error {
   204  	if h.appErr {
   205  		resw.SetApplicationError()
   206  	}
   207  	return h.err
   208  }