github.com/mendersoftware/go-lib-micro@v0.0.0-20240304135804-e8e39c59b148/rest_utils/response_helpers_test.go (about) 1 // Copyright 2024 Northern.tech AS 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rest_utils 16 17 import ( 18 "bytes" 19 "context" 20 "encoding/json" 21 "errors" 22 "net/http" 23 "net/http/httptest" 24 "testing" 25 26 "github.com/ant0ine/go-json-rest/rest" 27 "github.com/mendersoftware/go-lib-micro/accesslog" 28 "github.com/mendersoftware/go-lib-micro/log" 29 "github.com/sirupsen/logrus" 30 "github.com/stretchr/testify/assert" 31 ) 32 33 type logCounter struct { 34 n int 35 } 36 37 func (logCounter) Levels() []logrus.Level { 38 return logrus.AllLevels 39 } 40 41 func (l *logCounter) Fire(*logrus.Entry) error { 42 l.n++ 43 return nil 44 } 45 46 func TestResponseHelpers(t *testing.T) { 47 t.Parallel() 48 testCases := []struct { 49 Name string 50 CTX context.Context 51 52 HandlerFunc rest.HandlerFunc 53 NumEntries int 54 55 Fields []string 56 ExpectedBody string 57 }{{ 58 Name: "internal", 59 60 NumEntries: 1, 61 HandlerFunc: func(w rest.ResponseWriter, r *rest.Request) { 62 RestErrWithLogInternal(w, r, log.NewEmpty(), errors.New("test error")) 63 }, 64 Fields: []string{ 65 `level=error`, 66 `error="(?P<callerFrame>rest_utils.TestResponseHelpers[^@]+` + 67 `@[^:]+:[0-9]+:) internal error: test error"`, 68 }, 69 ExpectedBody: func() string { 70 b, _ := json.Marshal(ApiError{Err: "internal error"}) 71 return string(b) 72 }(), 73 }, { 74 Name: "client error", 75 76 NumEntries: 1, 77 HandlerFunc: func(w rest.ResponseWriter, r *rest.Request) { 78 RestErrWithWarningMsg(w, r, log.NewEmpty(), 79 errors.New("test error"), http.StatusBadRequest, "bad request") 80 }, 81 Fields: []string{ 82 `level=warn`, 83 `error="(?P<callerFrame>rest_utils.TestResponseHelpers[^@]+` + 84 `@[^:]+:[0-9]+:) bad request: test error"`, 85 }, 86 ExpectedBody: func() string { 87 b, _ := json.Marshal(ApiError{Err: "bad request"}) 88 return string(b) 89 }(), 90 }, { 91 Name: "fallback to logger", 92 93 NumEntries: 2, 94 HandlerFunc: func(w rest.ResponseWriter, r *rest.Request) { 95 lc := accesslog.GetContext(r.Request.Context()) 96 e := errors.New("test") 97 i := 0 98 for lc.PushError(e) && i < 10000 { 99 i++ 100 } 101 if i >= 10000 { 102 // Guard against breaking the accesslog 103 t.Error("should not be able to push 10000 errors to accesslog") 104 t.FailNow() 105 } 106 RestErrWithWarningMsg(w, r, log.NewEmpty(), 107 errors.New("test error"), http.StatusBadRequest, "bad request") 108 }, 109 Fields: []string{ 110 `level=warn`, 111 `msg="bad request: test error"`, 112 }, 113 ExpectedBody: func() string { 114 b, _ := json.Marshal(ApiError{Err: "bad request"}) 115 return string(b) 116 }(), 117 }} 118 119 for i := range testCases { 120 tc := testCases[i] 121 t.Run(tc.Name, func(t *testing.T) { 122 app, err := rest.MakeRouter(rest.Get("/test", tc.HandlerFunc)) 123 if err != nil { 124 t.Error(err) 125 t.FailNow() 126 } 127 counter := &logCounter{} 128 api := rest.NewApi() 129 var logBuf = bytes.NewBuffer(nil) 130 api.Use(rest.MiddlewareSimple( 131 func(h rest.HandlerFunc) rest.HandlerFunc { 132 logger := log.NewEmpty() 133 logger.Logger.SetLevel(logrus.DebugLevel) 134 logger.Logger.SetOutput(logBuf) 135 logger.Logger.SetFormatter(&logrus.TextFormatter{ 136 DisableColors: true, 137 FullTimestamp: true, 138 }) 139 logger.Logger.AddHook(counter) 140 return func(w rest.ResponseWriter, r *rest.Request) { 141 ctx := r.Request.Context() 142 ctx = log.WithContext(ctx, logger) 143 r.Request = r.Request.WithContext(ctx) 144 h(w, r) 145 } 146 })) 147 api.Use(&accesslog.AccessLogMiddleware{}) 148 api.SetApp(app) 149 handler := api.MakeHandler() 150 w := httptest.NewRecorder() 151 ctx := context.Background() 152 if tc.CTX != nil { 153 ctx = tc.CTX 154 } 155 req, _ := http.NewRequestWithContext( 156 ctx, 157 http.MethodGet, 158 "http://localhost/test?foo=bar", 159 nil, 160 ) 161 req.Header.Set("User-Agent", "tester") 162 163 handler.ServeHTTP(w, req) 164 165 logEntry := logBuf.String() 166 for _, field := range tc.Fields { 167 assert.Regexp(t, field, logEntry) 168 } 169 if tc.Fields == nil { 170 assert.Empty(t, logEntry) 171 } 172 if tc.ExpectedBody != "" { 173 if assert.NotNil(t, w.Body) { 174 assert.JSONEq(t, tc.ExpectedBody, w.Body.String()) 175 } 176 } 177 assert.Equal(t, tc.NumEntries, counter.n) 178 }) 179 } 180 }