github.com/xmidt-org/webpa-common@v1.11.9/device/devicegate/filterHandler_test.go (about) 1 package devicegate 2 3 import ( 4 "bytes" 5 "context" 6 "net/http" 7 "net/http/httptest" 8 "testing" 9 10 "github.com/stretchr/testify/assert" 11 "github.com/stretchr/testify/mock" 12 "github.com/xmidt-org/webpa-common/logging" 13 ) 14 15 func TestServeHTTPGet(t *testing.T) { 16 var ( 17 assert = assert.New(t) 18 logger = logging.NewTestLogger(nil, t) 19 ctx = logging.WithLogger(context.Background(), logger) 20 21 response = httptest.NewRecorder() 22 request = httptest.NewRequest("GET", "/", nil) 23 24 mockDeviceGate = new(mockDeviceGate) 25 26 f = FilterHandler{ 27 Gate: mockDeviceGate, 28 } 29 ) 30 31 mockDeviceGate.On("VisitAll", mock.Anything).Return(0) 32 mockDeviceGate.On("MarshalJSON").Return([]byte(`{}`), nil).Once() 33 f.GetFilters(response, request.WithContext(ctx)) 34 assert.Equal(http.StatusOK, response.Code) 35 assert.NotEmpty(response.Body) 36 37 } 38 39 func TestBadRequest(t *testing.T) { 40 var ( 41 logger = logging.NewTestLogger(nil, t) 42 ctx = logging.WithLogger(context.Background(), logger) 43 44 mockDeviceGate = new(mockDeviceGate) 45 f = FilterHandler{ 46 Gate: mockDeviceGate, 47 } 48 ) 49 50 tests := []struct { 51 description string 52 reqBody []byte 53 expectedStatusCode int 54 testDelete bool 55 }{ 56 { 57 description: "Unmarshal error", 58 reqBody: []byte(`this is not a filter request`), 59 expectedStatusCode: http.StatusBadRequest, 60 testDelete: true, 61 }, 62 { 63 description: "No filter key parameter", 64 reqBody: []byte(`{"test": "test"}`), 65 expectedStatusCode: http.StatusBadRequest, 66 testDelete: true, 67 }, 68 { 69 description: "No filter values", 70 reqBody: []byte(`{"key": "test"}`), 71 expectedStatusCode: http.StatusBadRequest, 72 }, 73 { 74 description: "Filter key not allowed", 75 reqBody: []byte(`{"key": "test", "values": ["test", "test1"]}`), 76 expectedStatusCode: http.StatusBadRequest, 77 }, 78 } 79 80 mockDeviceGate.On("GetAllowedFilters").Return(&FilterSet{}, true) 81 82 for _, tc := range tests { 83 t.Run(tc.description, func(t *testing.T) { 84 assert := assert.New(t) 85 86 requests := []*http.Request{ 87 httptest.NewRequest("POST", "/", bytes.NewBuffer(tc.reqBody)), 88 httptest.NewRequest("PUT", "/", bytes.NewBuffer(tc.reqBody)), 89 } 90 91 if tc.testDelete { 92 requests = append(requests, httptest.NewRequest("DELETE", "/", bytes.NewBuffer(tc.reqBody))) 93 } 94 95 response := httptest.NewRecorder() 96 97 for _, req := range requests { 98 f.UpdateFilters(response, req.WithContext(ctx)) 99 assert.Equal(tc.expectedStatusCode, response.Code) 100 } 101 102 }) 103 } 104 } 105 106 func TestSuccessfulAdd(t *testing.T) { 107 var ( 108 logger = logging.NewTestLogger(nil, t) 109 ctx = logging.WithLogger(context.Background(), logger) 110 ) 111 112 tests := []struct { 113 description string 114 request *http.Request 115 newKey bool 116 expectedStatusCode int 117 allowedFilters *FilterSet 118 allowedFiltersSet bool 119 }{ 120 { 121 description: "Successful POST", 122 request: httptest.NewRequest("POST", "/", bytes.NewBuffer([]byte(`{"key": "test", "values": ["test1", "test2"]}`))).WithContext(ctx), 123 newKey: true, 124 expectedStatusCode: http.StatusCreated, 125 }, 126 { 127 description: "Successful POST Update", 128 request: httptest.NewRequest("POST", "/", bytes.NewBuffer([]byte(`{"key": "test", "values": ["random new value"]}`))).WithContext(ctx), 129 expectedStatusCode: http.StatusOK, 130 }, 131 { 132 description: "Successful PUT", 133 request: httptest.NewRequest("PUT", "/", bytes.NewBuffer([]byte(`{"key": "test", "values": ["test1", "test2"]}`))).WithContext(ctx), 134 newKey: true, 135 expectedStatusCode: http.StatusCreated, 136 }, 137 { 138 description: "Successful PUT Update", 139 request: httptest.NewRequest("PUT", "/", bytes.NewBuffer([]byte(`{"key": "test", "values": ["random new value"]}`))).WithContext(ctx), 140 expectedStatusCode: http.StatusOK, 141 }, 142 { 143 description: "Successful with Allowed Filters", 144 request: httptest.NewRequest("POST", "/", bytes.NewBuffer([]byte(`{"key": "test", "values": ["test1", "test2"]}`))).WithContext(ctx), 145 newKey: true, 146 expectedStatusCode: http.StatusCreated, 147 allowedFilters: &FilterSet{Set: map[interface{}]bool{"test": true}}, 148 allowedFiltersSet: true, 149 }, 150 } 151 152 for _, tc := range tests { 153 t.Run(tc.description, func(t *testing.T) { 154 assert := assert.New(t) 155 mockDeviceGate := new(mockDeviceGate) 156 f := FilterHandler{ 157 Gate: mockDeviceGate, 158 } 159 160 mockDeviceGate.On("MarshalJSON").Return([]byte(`{}`), nil) 161 mockDeviceGate.On("GetAllowedFilters").Return(tc.allowedFilters, tc.allowedFiltersSet).Once() 162 mockDeviceGate.On("SetFilter", mock.AnythingOfType("string"), mock.Anything).Return(nil, tc.newKey).Once() 163 mockDeviceGate.On("VisitAll", mock.Anything).Return(0).Once() 164 165 response := httptest.NewRecorder() 166 f.UpdateFilters(response, tc.request) 167 assert.Equal(tc.expectedStatusCode, response.Code) 168 169 }) 170 } 171 172 } 173 174 func TestDelete(t *testing.T) { 175 var ( 176 logger = logging.NewTestLogger(nil, t) 177 ctx = logging.WithLogger(context.Background(), logger) 178 assert = assert.New(t) 179 response = httptest.NewRecorder() 180 181 mockDeviceGate = new(mockDeviceGate) 182 f = FilterHandler{ 183 Gate: mockDeviceGate, 184 } 185 ) 186 187 mockDeviceGate.On("DeleteFilter", "test").Return(true).Once() 188 mockDeviceGate.On("VisitAll", mock.Anything).Return(0).Once() 189 mockDeviceGate.On("MarshalJSON").Return([]byte(`{}`), nil).Once() 190 191 req := httptest.NewRequest("DELETE", "/", bytes.NewBuffer([]byte(`{"key": "test"}`))) 192 f.DeleteFilter(response, req.WithContext(ctx)) 193 assert.Equal(http.StatusOK, response.Code) 194 } 195 196 func TestGateLogger(t *testing.T) { 197 198 var ( 199 logger = logging.NewTestLogger(nil, t) 200 gate = &FilterGate{ 201 FilterStore: FilterStore(map[string]Set{ 202 "partner-id": &FilterSet{ 203 Set: map[interface{}]bool{"comcast": true}, 204 }, 205 }), 206 } 207 208 gl = GateLogger{ 209 Logger: logger, 210 } 211 212 assert = assert.New(t) 213 ) 214 215 tests := []struct { 216 description string 217 next http.Handler 218 expectedEmpty bool 219 }{ 220 { 221 description: "Success", 222 next: http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { 223 response.WriteHeader(201) 224 newCtx := context.WithValue(request.Context(), gateKey, gate) 225 *request = *request.WithContext(newCtx) 226 }), 227 }, 228 { 229 description: "No gate set in context", 230 next: http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) { 231 response.WriteHeader(201) 232 }), 233 expectedEmpty: true, 234 }, 235 } 236 237 for _, tc := range tests { 238 response := httptest.NewRecorder() 239 req := httptest.NewRequest("GET", "/", nil) 240 handler := gl.LogFilters(tc.next) 241 handler.ServeHTTP(response, req) 242 243 if tc.expectedEmpty { 244 assert.Empty(response.Body) 245 } else { 246 assert.NotEmpty(response.Body) 247 } 248 249 } 250 251 }