github.com/mendersoftware/go-lib-micro@v0.0.0-20240304135804-e8e39c59b148/accesslog/middleware_test.go (about)

     1  // Copyright 2023 Northern.tech AS
     2  //
     3  //    Licensed under the Apache License, Version 2.0 (the "License");
     4  //    you may not use this file except in compliance with the License.
     5  //    You may obtain a copy of the License at
     6  //
     7  //        http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  //    Unless required by applicable law or agreed to in writing, software
    10  //    distributed under the License is distributed on an "AS IS" BASIS,
    11  //    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  //    See the License for the specific language governing permissions and
    13  //    limitations under the License.
    14  
    15  package accesslog
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"net/http"
    21  	"net/http/httptest"
    22  	"testing"
    23  
    24  	"github.com/ant0ine/go-json-rest/rest"
    25  	"github.com/mendersoftware/go-lib-micro/log"
    26  	"github.com/sirupsen/logrus"
    27  	"github.com/stretchr/testify/assert"
    28  )
    29  
    30  func TestMiddlewareLegacy(t *testing.T) {
    31  	testCases := []struct {
    32  		Name string
    33  		CTX  context.Context
    34  
    35  		HandlerFunc rest.HandlerFunc
    36  
    37  		Fields       []string
    38  		ExpectedBody string
    39  	}{{
    40  		Name: "ok",
    41  
    42  		HandlerFunc: func(w rest.ResponseWriter, r *rest.Request) {
    43  			w.WriteHeader(http.StatusNoContent)
    44  		},
    45  		Fields: []string{
    46  			"status=204",
    47  			`path=/test`,
    48  			`qs="foo=bar"`,
    49  			"method=GET",
    50  			"responsetime=",
    51  			"ts=",
    52  		},
    53  	}, {
    54  		Name: "canceled context",
    55  
    56  		CTX: func() context.Context {
    57  			ctx, cancel := context.WithCancel(context.Background())
    58  			cancel()
    59  			return ctx
    60  		}(),
    61  		HandlerFunc: func(w rest.ResponseWriter, r *rest.Request) {
    62  			w.WriteHeader(http.StatusNoContent)
    63  		},
    64  		Fields: []string{
    65  			"status=499",
    66  			`path=/test`,
    67  			`qs="foo=bar"`,
    68  			"method=GET",
    69  			"responsetime=",
    70  			"ts=",
    71  		},
    72  	}, {
    73  		Name: "error, panic in handler",
    74  
    75  		HandlerFunc: func(w rest.ResponseWriter, r *rest.Request) {
    76  			panic("!!!!!")
    77  		},
    78  
    79  		Fields: []string{
    80  			"status=500",
    81  			`path=/test`,
    82  			`qs="foo=bar"`,
    83  			"method=GET",
    84  			"responsetime=",
    85  			"ts=",
    86  			// First three entries in the trace should match this:
    87  			`trace=".+TestMiddlewareLegacy\.func[0-9.]*@middleware_test\.go:[0-9.]+\\n`,
    88  		},
    89  		ExpectedBody: `{"Error": "Internal Server Error"}`,
    90  	}}
    91  
    92  	for i := range testCases {
    93  		tc := testCases[i]
    94  		t.Run(tc.Name, func(t *testing.T) {
    95  			app, err := rest.MakeRouter(rest.Get("/test", tc.HandlerFunc))
    96  			if err != nil {
    97  				t.Error(err)
    98  				t.FailNow()
    99  			}
   100  			api := rest.NewApi()
   101  			var logBuf = bytes.NewBuffer(nil)
   102  			api.Use(rest.MiddlewareSimple(
   103  				func(h rest.HandlerFunc) rest.HandlerFunc {
   104  					logger := log.NewEmpty()
   105  					logger.Logger.SetLevel(logrus.InfoLevel)
   106  					logger.Logger.SetOutput(logBuf)
   107  					logger.Logger.SetFormatter(&logrus.TextFormatter{
   108  						DisableColors: true,
   109  						FullTimestamp: true,
   110  					})
   111  					return func(w rest.ResponseWriter, r *rest.Request) {
   112  						ctx := r.Request.Context()
   113  						ctx = log.WithContext(ctx, logger)
   114  						r.Request = r.Request.WithContext(ctx)
   115  						h(w, r)
   116  						t.Log(r.Env)
   117  					}
   118  				}))
   119  			api.Use(&AccessLogMiddleware{})
   120  			api.SetApp(app)
   121  			handler := api.MakeHandler()
   122  			w := httptest.NewRecorder()
   123  			ctx := context.Background()
   124  			if tc.CTX != nil {
   125  				ctx = tc.CTX
   126  			}
   127  			req, _ := http.NewRequestWithContext(
   128  				ctx,
   129  				http.MethodGet,
   130  				"http://localhost/test?foo=bar",
   131  				nil,
   132  			)
   133  			req.Header.Set("User-Agent", "tester")
   134  
   135  			handler.ServeHTTP(w, req)
   136  
   137  			logEntry := logBuf.String()
   138  			for _, field := range tc.Fields {
   139  				assert.Regexp(t, field, logEntry)
   140  			}
   141  			if tc.Fields == nil {
   142  				assert.Empty(t, logEntry)
   143  			}
   144  			if tc.ExpectedBody != "" {
   145  				if assert.NotNil(t, w.Body) {
   146  					assert.JSONEq(t, tc.ExpectedBody, w.Body.String())
   147  				}
   148  			}
   149  		})
   150  	}
   151  }