github.com/mendersoftware/go-lib-micro@v0.0.0-20240304135804-e8e39c59b148/requestid/middleware_test.go (about)

     1  // Copyright 2023 Northern.tech AS
     2  //
     3  //    Licensed under the Apache License, Version 2.0 (the "License");
     4  //    you may not use this file except in compliance with the License.
     5  //    You may obtain a copy of the License at
     6  //
     7  //        http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  //    Unless required by applicable law or agreed to in writing, software
    10  //    distributed under the License is distributed on an "AS IS" BASIS,
    11  //    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  //    See the License for the specific language governing permissions and
    13  //    limitations under the License.
    14  
    15  package requestid
    16  
    17  import (
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"testing"
    21  
    22  	"github.com/ant0ine/go-json-rest/rest"
    23  	"github.com/ant0ine/go-json-rest/rest/test"
    24  	"github.com/gin-gonic/gin"
    25  	"github.com/google/uuid"
    26  	"github.com/stretchr/testify/assert"
    27  )
    28  
    29  func init() {
    30  	gin.SetMode(gin.ReleaseMode) // please just shut up
    31  }
    32  
    33  func TestGinMiddleware(t *testing.T) {
    34  	t.Parallel()
    35  	testCases := []struct {
    36  		Name string
    37  
    38  		Options *MiddlewareOptions
    39  
    40  		Headers http.Header
    41  	}{{
    42  		Name: "Request with ID",
    43  
    44  		Headers: func() http.Header {
    45  			hdr := http.Header{}
    46  			hdr.Set(RequestIdHeader, "test")
    47  			return hdr
    48  		}(),
    49  	}, {
    50  		Name: "Request generated ID",
    51  
    52  		Options: NewMiddlewareOptions().
    53  			SetGenerateRequestID(true),
    54  	}}
    55  
    56  	for i := range testCases {
    57  		tc := testCases[i]
    58  		t.Run(tc.Name, func(t *testing.T) {
    59  			t.Parallel()
    60  
    61  			router := gin.New()
    62  			router.Use(Middleware(tc.Options))
    63  			router.GET("/test", func(c *gin.Context) {})
    64  
    65  			w := httptest.NewRecorder()
    66  			req, _ := http.NewRequest("GET", "http://mender.io/test", nil)
    67  			for k, v := range tc.Headers {
    68  				for _, vv := range v {
    69  					req.Header.Add(k, vv)
    70  				}
    71  			}
    72  			router.ServeHTTP(w, req)
    73  
    74  			rsp := w.Result()
    75  
    76  			if id := tc.Headers.Get(RequestIdHeader); id != "" {
    77  				rspID := rsp.Header.Get(RequestIdHeader)
    78  				assert.Equal(t, id, rspID)
    79  			} else {
    80  				if tc.Options.GenerateRequestID != nil &&
    81  					*tc.Options.GenerateRequestID {
    82  					_, err := uuid.Parse(rsp.Header.Get(RequestIdHeader))
    83  					assert.NoError(t, err, "Generated requestID is not a UUID")
    84  				} else {
    85  					assert.Empty(t, rsp.Header.Get(RequestIdHeader))
    86  				}
    87  			}
    88  		})
    89  	}
    90  }
    91  
    92  func TestRequestIdMiddlewareWithReqID(t *testing.T) {
    93  	api := rest.NewApi()
    94  
    95  	api.Use(&RequestIdMiddleware{})
    96  
    97  	reqid := "4420a5b9-dbf2-4e5d-8b4f-3cf2013d04af"
    98  	api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) {
    99  		assert.Equal(t, reqid, FromContext(r.Context()))
   100  		w.WriteJson(map[string]string{"foo": "bar"})
   101  	}))
   102  
   103  	handler := api.MakeHandler()
   104  
   105  	req := test.MakeSimpleRequest("GET", "http://localhost/", nil)
   106  	req.Header.Set(RequestIdHeader, reqid)
   107  
   108  	recorded := test.RunRequest(t, handler, req)
   109  	recorded.CodeIs(200)
   110  	recorded.ContentTypeIsJson()
   111  	recorded.HeaderIs(RequestIdHeader, reqid)
   112  
   113  }
   114  
   115  func TestRequestIdMiddlewareNoReqID(t *testing.T) {
   116  	api := rest.NewApi()
   117  
   118  	api.Use(&RequestIdMiddleware{})
   119  
   120  	api.SetApp(rest.AppSimple(func(w rest.ResponseWriter, r *rest.Request) {
   121  		reqid := FromContext(r.Context())
   122  		_, err := uuid.Parse(reqid)
   123  		assert.NoError(t, err)
   124  		w.WriteJson(map[string]string{"foo": "bar"})
   125  	}))
   126  
   127  	handler := api.MakeHandler()
   128  
   129  	req := test.MakeSimpleRequest("GET", "http://localhost/", nil)
   130  	recorded := test.RunRequest(t, handler, req)
   131  	recorded.CodeIs(200)
   132  	recorded.ContentTypeIsJson()
   133  	outReqIdStr := recorded.Recorder.HeaderMap.Get(RequestIdHeader)
   134  	_, err := uuid.Parse(outReqIdStr)
   135  	assert.NoError(t, err)
   136  }