github.com/m3db/m3@v1.5.0/src/query/api/v1/middleware/source_test.go (about)

     1  // Copyright (c) 2021 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package middleware
    22  
    23  import (
    24  	"context"
    25  	"errors"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"testing"
    29  
    30  	"github.com/gorilla/mux"
    31  	"github.com/stretchr/testify/require"
    32  	"go.uber.org/zap"
    33  	"go.uber.org/zap/zapcore"
    34  	"go.uber.org/zap/zaptest/observer"
    35  
    36  	"github.com/m3db/m3/src/query/source"
    37  	"github.com/m3db/m3/src/query/util/logging"
    38  	"github.com/m3db/m3/src/x/headers"
    39  	"github.com/m3db/m3/src/x/instrument"
    40  )
    41  
    42  type testSource struct {
    43  	name string
    44  }
    45  
    46  var testDeserialize = func(bytes []byte) (interface{}, error) {
    47  	return testSource{string(bytes)}, nil
    48  }
    49  
    50  func TestMiddleware(t *testing.T) {
    51  	cases := []struct {
    52  		name         string
    53  		sourceHeader string
    54  		expected     testSource
    55  		expectedLog  string
    56  		deserializer source.Deserializer
    57  		invalidErr   bool
    58  	}{
    59  		{
    60  			name:         "happy path",
    61  			sourceHeader: "foobar",
    62  			expected:     testSource{"foobar"},
    63  			expectedLog:  "foobar",
    64  		},
    65  		{
    66  			name:         "no source header",
    67  			sourceHeader: "",
    68  			expected:     testSource{""},
    69  		},
    70  		{
    71  			name:         "deserialize error",
    72  			sourceHeader: "foobar",
    73  			invalidErr:   true,
    74  			deserializer: func(bytes []byte) (interface{}, error) {
    75  				return nil, errors.New("boom")
    76  			},
    77  		},
    78  	}
    79  
    80  	for _, tc := range cases {
    81  		tc := tc
    82  		core, recorded := observer.New(zapcore.InfoLevel)
    83  		l := zap.New(core)
    84  		iOpts := instrument.NewOptions().SetLogger(l)
    85  		t.Run(tc.name, func(t *testing.T) {
    86  			r := mux.NewRouter()
    87  			r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
    88  				l = logging.WithContext(r.Context(), iOpts)
    89  				l.Info("test")
    90  				typed, ok := source.FromContext(r.Context())
    91  				if tc.expected.name == "" {
    92  					require.False(t, ok)
    93  					require.Nil(t, typed)
    94  				} else {
    95  					require.True(t, ok)
    96  					require.Equal(t, tc.expected, typed.(testSource))
    97  				}
    98  			})
    99  			if tc.deserializer == nil {
   100  				tc.deserializer = testDeserialize
   101  			}
   102  			r.Use(Source(Options{
   103  				InstrumentOpts: iOpts,
   104  				Source: SourceOptions{
   105  					Deserializer: tc.deserializer,
   106  				},
   107  			}))
   108  			s := httptest.NewServer(r)
   109  			defer s.Close()
   110  
   111  			req, err := http.NewRequestWithContext(context.Background(), "GET", s.URL, nil)
   112  			require.NoError(t, err)
   113  			req.Header.Set(headers.SourceHeader, tc.sourceHeader)
   114  			resp, err := s.Client().Do(req)
   115  			require.NoError(t, err)
   116  			require.NoError(t, resp.Body.Close())
   117  			if tc.invalidErr {
   118  				require.Equal(t, http.StatusBadRequest, resp.StatusCode)
   119  			} else {
   120  				require.Equal(t, http.StatusOK, resp.StatusCode)
   121  				testMsgs := recorded.FilterMessage("test").All()
   122  				require.Len(t, testMsgs, 1)
   123  				entry := testMsgs[0]
   124  				require.Equal(t, "test", entry.Message)
   125  				fields := entry.ContextMap()
   126  				if tc.expectedLog != "" {
   127  					require.Len(t, fields, 1)
   128  					require.Equal(t, tc.expectedLog, fields["source"])
   129  				}
   130  			}
   131  		})
   132  	}
   133  }