// go.uber.org/yarpc@v1.72.1/internal/observability/ctx_middleware_test.go
//
// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package observability 22 23 import ( 24 "context" 25 "errors" 26 "fmt" 27 "testing" 28 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 "go.uber.org/yarpc/api/transport" 32 "go.uber.org/yarpc/api/transport/transporttest" 33 "go.uber.org/yarpc/yarpcerrors" 34 "go.uber.org/zap" 35 "go.uber.org/zap/zapcore" 36 "go.uber.org/zap/zaptest/observer" 37 ) 38 39 func TestContextMiddleware(t *testing.T) { 40 const ( 41 ctxDeadlineExceededMsg = `call to procedure "my-procedure" of service "my-service" from caller "my-caller" timed out` 42 ctxCancelledMsg = `call to procedure "my-procedure" of service "my-service" from caller "my-caller" was canceled` 43 ) 44 45 core, logs := observer.New(zapcore.DebugLevel) 46 infoLevel := zapcore.InfoLevel 47 mw := NewMiddleware(Config{ 48 Logger: zap.New(core), 49 ContextExtractor: NewNopContextExtractor(), 50 Levels: LevelsConfig{ 51 Default: DirectionalLevelsConfig{ 52 Success: &infoLevel, 53 ApplicationError: &infoLevel, 54 Failure: &infoLevel, 55 }, 56 }, 57 }) 58 59 tests := []struct { 60 name string 61 handlerErr error 62 appErr bool 63 ctx func() context.Context 64 65 wantDeadlineExceeded bool 66 wantCtxCancelled bool 67 }{ 68 { 69 name: "no-op/handler success", 70 ctx: func() context.Context { return context.Background() }, 71 }, 72 { 73 name: "no-op/handler err", 74 handlerErr: errors.New("an err"), 75 ctx: func() context.Context { return context.Background() }, 76 }, 77 { 78 name: "deadline exceeded/handler success", 79 ctx: func() context.Context { 80 ctx, cancel := context.WithTimeout(context.Background(), -1) 81 cancel() 82 return ctx 83 }, 84 wantDeadlineExceeded: true, 85 }, 86 { 87 name: "deadline exceeded/handler err", 88 handlerErr: fmt.Errorf("my custom error"), 89 ctx: func() context.Context { 90 ctx, cancel := context.WithTimeout(context.Background(), -1) 91 cancel() 92 return ctx 93 }, 94 wantDeadlineExceeded: true, 95 }, 96 { 97 name: "deadline exceeded/app err", 98 ctx: func() 
context.Context { 99 ctx, cancel := context.WithTimeout(context.Background(), -1) 100 cancel() 101 return ctx 102 }, 103 appErr: true, 104 wantDeadlineExceeded: true, 105 }, 106 { 107 name: "cancelled error/handler success", 108 ctx: func() context.Context { 109 ctx, cancel := context.WithCancel(context.Background()) 110 cancel() 111 return ctx 112 }, 113 wantCtxCancelled: true, 114 }, 115 { 116 name: "cancelled error/handler err", 117 handlerErr: fmt.Errorf("my custom error"), 118 ctx: func() context.Context { 119 ctx, cancel := context.WithCancel(context.Background()) 120 cancel() 121 return ctx 122 }, 123 wantCtxCancelled: true, 124 }, 125 { 126 name: "cancelled error/app err", 127 ctx: func() context.Context { 128 ctx, cancel := context.WithCancel(context.Background()) 129 cancel() 130 return ctx 131 }, 132 appErr: true, 133 wantCtxCancelled: true, 134 }, 135 } 136 137 req := &transport.Request{ 138 Service: "my-service", 139 Procedure: "my-procedure", 140 Caller: "my-caller", 141 } 142 143 expectLogField := func(appErr bool, err error) *zap.Field { 144 dropMsg := _droppedSuccessLog 145 if err == nil && appErr { 146 dropMsg = _droppedAppErrLog 147 } else if err != nil { 148 dropMsg = fmt.Sprintf(_droppedErrLogFmt, err) 149 } 150 log := zap.String(_dropped, dropMsg) 151 return &log 152 } 153 154 getDropLogField := func(t *testing.T) *zap.Field { 155 entries := logs.TakeAll() 156 require.Equal(t, 1, len(entries), "unexpected number of logs written: %v", entries) 157 for _, f := range entries[0].Context { 158 if f.Key == _dropped { 159 return &f 160 } 161 } 162 return nil 163 } 164 165 for _, tt := range tests { 166 t.Run(tt.name, func(t *testing.T) { 167 defer logs.TakeAll() // throw away logs for next run 168 169 handler := &testHandler{err: tt.handlerErr, appErr: tt.appErr} 170 err := mw.Handle(tt.ctx(), req, &transporttest.FakeResponseWriter{}, handler) 171 172 if tt.wantDeadlineExceeded { 173 assert.EqualError(t, 174 err, 175 
yarpcerrors.DeadlineExceededErrorf(ctxDeadlineExceededMsg).Error(), 176 "expected deadline exceeded error override") 177 178 assert.Equal(t, expectLogField(tt.appErr, tt.handlerErr), getDropLogField(t), "unexpected log") 179 return 180 } 181 182 if tt.wantCtxCancelled { 183 assert.EqualError(t, 184 err, 185 yarpcerrors.CancelledErrorf(ctxCancelledMsg).Error(), 186 "expected cancelled yarpcerror code") 187 188 assert.Equal(t, expectLogField(tt.appErr, tt.handlerErr), getDropLogField(t), "unexpected log") 189 return 190 } 191 192 assert.Equal(t, tt.handlerErr, err, "unexpected error") 193 assert.Nil(t, getDropLogField(t), "unexpectedly saw 'dropped' log field") 194 }) 195 } 196 } 197 198 type testHandler struct { 199 err error 200 appErr bool 201 } 202 203 func (h *testHandler) Handle(ctx context.Context, req *transport.Request, resw transport.ResponseWriter) error { 204 if h.appErr { 205 resw.SetApplicationError() 206 } 207 return h.err 208 }