// Copyright (c) 2021 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package middleware 22 23 import ( 24 "context" 25 "errors" 26 "net/http" 27 "net/http/httptest" 28 "testing" 29 30 "github.com/gorilla/mux" 31 "github.com/stretchr/testify/require" 32 "go.uber.org/zap" 33 "go.uber.org/zap/zapcore" 34 "go.uber.org/zap/zaptest/observer" 35 36 "github.com/m3db/m3/src/query/source" 37 "github.com/m3db/m3/src/query/util/logging" 38 "github.com/m3db/m3/src/x/headers" 39 "github.com/m3db/m3/src/x/instrument" 40 ) 41 42 type testSource struct { 43 name string 44 } 45 46 var testDeserialize = func(bytes []byte) (interface{}, error) { 47 return testSource{string(bytes)}, nil 48 } 49 50 func TestMiddleware(t *testing.T) { 51 cases := []struct { 52 name string 53 sourceHeader string 54 expected testSource 55 expectedLog string 56 deserializer source.Deserializer 57 invalidErr bool 58 }{ 59 { 60 name: "happy path", 61 sourceHeader: "foobar", 62 expected: testSource{"foobar"}, 63 expectedLog: "foobar", 64 }, 65 { 66 name: "no source header", 67 sourceHeader: "", 68 expected: testSource{""}, 69 }, 70 { 71 name: "deserialize error", 72 sourceHeader: "foobar", 73 invalidErr: true, 74 deserializer: func(bytes []byte) (interface{}, error) { 75 return nil, errors.New("boom") 76 }, 77 }, 78 } 79 80 for _, tc := range cases { 81 tc := tc 82 core, recorded := observer.New(zapcore.InfoLevel) 83 l := zap.New(core) 84 iOpts := instrument.NewOptions().SetLogger(l) 85 t.Run(tc.name, func(t *testing.T) { 86 r := mux.NewRouter() 87 r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 88 l = logging.WithContext(r.Context(), iOpts) 89 l.Info("test") 90 typed, ok := source.FromContext(r.Context()) 91 if tc.expected.name == "" { 92 require.False(t, ok) 93 require.Nil(t, typed) 94 } else { 95 require.True(t, ok) 96 require.Equal(t, tc.expected, typed.(testSource)) 97 } 98 }) 99 if tc.deserializer == nil { 100 tc.deserializer = testDeserialize 101 } 102 r.Use(Source(Options{ 103 InstrumentOpts: iOpts, 104 Source: SourceOptions{ 105 Deserializer: 
tc.deserializer, 106 }, 107 })) 108 s := httptest.NewServer(r) 109 defer s.Close() 110 111 req, err := http.NewRequestWithContext(context.Background(), "GET", s.URL, nil) 112 require.NoError(t, err) 113 req.Header.Set(headers.SourceHeader, tc.sourceHeader) 114 resp, err := s.Client().Do(req) 115 require.NoError(t, err) 116 require.NoError(t, resp.Body.Close()) 117 if tc.invalidErr { 118 require.Equal(t, http.StatusBadRequest, resp.StatusCode) 119 } else { 120 require.Equal(t, http.StatusOK, resp.StatusCode) 121 testMsgs := recorded.FilterMessage("test").All() 122 require.Len(t, testMsgs, 1) 123 entry := testMsgs[0] 124 require.Equal(t, "test", entry.Message) 125 fields := entry.ContextMap() 126 if tc.expectedLog != "" { 127 require.Len(t, fields, 1) 128 require.Equal(t, tc.expectedLog, fields["source"]) 129 } 130 } 131 }) 132 } 133 }