github.com/xmidt-org/webpa-common@v1.11.9/device/drain/start_test.go (about) 1 package drain 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "net/http" 8 "net/http/httptest" 9 "testing" 10 "time" 11 12 "github.com/xmidt-org/webpa-common/device/devicegate" 13 14 "github.com/stretchr/testify/assert" 15 "github.com/xmidt-org/webpa-common/logging" 16 ) 17 18 func testStartServeHTTPDefaultLogger(t *testing.T) { 19 var ( 20 assert = assert.New(t) 21 22 d = new(mockDrainer) 23 done <-chan struct{} = make(chan struct{}) 24 start = Start{d} 25 26 response = httptest.NewRecorder() 27 request = httptest.NewRequest("POST", "/", nil) 28 ) 29 30 d.On("Start", Job{}).Return(done, Job{Count: 126, Percent: 10, Rate: 12, Tick: 5 * time.Minute}, error(nil)) 31 start.ServeHTTP(response, request) 32 assert.Equal(http.StatusOK, response.Code) 33 assert.Equal("application/json", response.Header().Get("Content-Type")) 34 assert.JSONEq( 35 `{"count": 126, "percent": 10, "rate": 12, "tick": "5m0s"}`, 36 response.Body.String(), 37 ) 38 39 d.AssertExpectations(t) 40 } 41 42 func testStartServeHTTPValid(t *testing.T) { 43 testData := []struct { 44 uri string 45 expected Job 46 }{ 47 { 48 uri: "/foo", 49 expected: Job{}, 50 }, 51 { 52 uri: "/foo?count=100", 53 expected: Job{Count: 100}, 54 }, 55 { 56 uri: "/foo?rate=10", 57 expected: Job{Rate: 10}, 58 }, 59 { 60 uri: "/foo?rate=23&tick=1m", 61 expected: Job{Rate: 23, Tick: time.Minute}, 62 }, 63 { 64 uri: "/foo?count=22&rate=10&tick=20s", 65 expected: Job{Count: 22, Rate: 10, Tick: 20 * time.Second}, 66 }, 67 } 68 69 for _, record := range testData { 70 t.Run(record.uri, func(t *testing.T) { 71 var ( 72 assert = assert.New(t) 73 74 d = new(mockDrainer) 75 done <-chan struct{} = make(chan struct{}) 76 start = Start{d} 77 78 ctx = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t)) 79 response = httptest.NewRecorder() 80 request = httptest.NewRequest("POST", record.uri, nil).WithContext(ctx) 81 ) 82 83 d.On("Start", record.expected).Return(done, Job{Count: 47192, Percent: 57, Rate: 500, Tick: 37 * time.Second, DrainFilter: record.expected.DrainFilter}, error(nil)).Once() 84 start.ServeHTTP(response, request) 85 assert.Equal(http.StatusOK, response.Code) 86 assert.Equal("application/json", response.Header().Get("Content-Type")) 87 assert.JSONEq( 88 `{"count": 47192, "percent": 57, "rate": 500, "tick": "37s"}`, 89 response.Body.String(), 90 ) 91 d.AssertExpectations(t) 92 }) 93 } 94 } 95 96 func testStartServeHTTPWithBody(t *testing.T) { 97 df := &drainFilter{ 98 filter: &devicegate.FilterGate{ 99 FilterStore: devicegate.FilterStore(map[string]devicegate.Set{ 100 "test": &devicegate.FilterSet{Set: map[interface{}]bool{ 101 "test1": true, 102 "test2": true, 103 }}, 104 }), 105 }, 106 filterRequest: devicegate.FilterRequest{ 107 Key: "test", 108 Values: []interface{}{"test1", "test2"}, 109 }, 110 } 111 112 testData := []struct { 113 description string 114 body []byte 115 expected Job 116 expectedJSON string 117 expectedStatusCode int 118 }{ 119 { 120 description: "Success with body", 121 body: []byte(`{"key": "test", "values": ["test1", "test2"]}`), 122 expected: Job{Count: 22, Rate: 10, Tick: 20 * time.Second, DrainFilter: df}, 123 expectedJSON: `{"count": 47192, "percent": 57, "rate": 500, "tick": "37s", "filter":{"key": "test", "values": ["test1", "test2"]}}`, 124 expectedStatusCode: http.StatusOK, 125 }, 126 { 127 description: "Unmarshal error", 128 body: []byte(`this is not a filter request`), 129 expected: Job{Count: 22, Rate: 10, Tick: 20 * time.Second}, 130 expectedStatusCode: http.StatusBadRequest, 131 }, 132 { 133 description: "Empty Body", 134 body: []byte(`{}`), 135 expected: Job{Count: 22, Rate: 10, Tick: 20 * time.Second}, 136 expectedStatusCode: http.StatusOK, 137 }, 138 { 139 description: "No value field", 140 body: []byte(`{"key": "test"}`), 141 expected: Job{Count: 22, Rate: 10, Tick: 20 * time.Second}, 142 expectedStatusCode: http.StatusOK, 143 }, 144 { 145 description: "No key", 146 body: []byte(`{"values": ["test1", "test2"]}`), 147 expected: Job{Count: 22, Rate: 10, Tick: 20 * time.Second}, 148 expectedStatusCode: http.StatusOK, 149 }, 150 { 151 description: "Empty values array", 152 body: []byte(`{"key": "test", "values": []}`), 153 expected: Job{Count: 22, Rate: 10, Tick: 20 * time.Second}, 154 expectedStatusCode: http.StatusOK, 155 }, 156 } 157 158 for _, record := range testData { 159 t.Run(record.description, func(t *testing.T) { 160 var ( 161 assert = assert.New(t) 162 163 d = new(mockDrainer) 164 done <-chan struct{} = make(chan struct{}) 165 start = Start{d} 166 167 ctx = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t)) 168 response = httptest.NewRecorder() 169 request = httptest.NewRequest("POST", "/foo?count=22&rate=10&tick=20s", bytes.NewBuffer(record.body)).WithContext(ctx) 170 ) 171 172 if record.expectedStatusCode == http.StatusOK { 173 d.On("Start", record.expected).Return(done, Job{Count: 47192, Percent: 57, Rate: 500, Tick: 37 * time.Second, DrainFilter: record.expected.DrainFilter}, error(nil)).Once() 174 } 175 start.ServeHTTP(response, request) 176 assert.Equal(record.expectedStatusCode, response.Code) 177 assert.Equal("application/json", response.Header().Get("Content-Type")) 178 if record.expectedStatusCode == http.StatusOK { 179 if len(record.expectedJSON) == 0 { 180 assert.JSONEq( 181 `{"count": 47192, "percent": 57, "rate": 500, "tick": "37s"}`, 182 response.Body.String(), 183 ) 184 } else { 185 assert.JSONEq(record.expectedJSON, response.Body.String()) 186 } 187 } 188 189 d.AssertExpectations(t) 190 }) 191 } 192 193 } 194 195 func testStartServeHTTPParseFormError(t *testing.T) { 196 var ( 197 assert = assert.New(t) 198 199 d = new(mockDrainer) 200 start = Start{d} 201 202 ctx = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t)) 203 response = httptest.NewRecorder() 204 request = httptest.NewRequest("POST", "/foo?%TT*&&", nil).WithContext(ctx) 205 ) 206 207 start.ServeHTTP(response, request) 208 assert.Equal(http.StatusBadRequest, response.Code) 209 d.AssertExpectations(t) 210 } 211 212 func testStartServeHTTPInvalidQuery(t *testing.T) { 213 var ( 214 assert = assert.New(t) 215 216 d = new(mockDrainer) 217 start = Start{d} 218 219 ctx = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t)) 220 response = httptest.NewRecorder() 221 request = httptest.NewRequest("POST", "/foo?count=asdf", nil).WithContext(ctx) 222 ) 223 224 start.ServeHTTP(response, request) 225 assert.Equal(http.StatusBadRequest, response.Code) 226 d.AssertExpectations(t) 227 } 228 229 func testStartServeHTTPStartError(t *testing.T) { 230 var ( 231 assert = assert.New(t) 232 233 d = new(mockDrainer) 234 done <-chan struct{} 235 start = Start{d} 236 expectedError = errors.New("expected") 237 238 ctx = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t)) 239 response = httptest.NewRecorder() 240 request = httptest.NewRequest("POST", "/foo?count=100", nil).WithContext(ctx) 241 ) 242 243 d.On("Start", Job{Count: 100}).Return(done, Job{}, expectedError).Once() 244 start.ServeHTTP(response, request) 245 assert.Equal(http.StatusConflict, response.Code) 246 d.AssertExpectations(t) 247 } 248 249 func TestStart(t *testing.T) { 250 t.Run("ServeHTTP", func(t *testing.T) { 251 t.Run("DefaultLogger", testStartServeHTTPDefaultLogger) 252 t.Run("Valid", testStartServeHTTPValid) 253 t.Run("WithBody", testStartServeHTTPWithBody) 254 t.Run("ParseFormError", testStartServeHTTPParseFormError) 255 t.Run("InvalidQuery", testStartServeHTTPInvalidQuery) 256 t.Run("StartError", testStartServeHTTPStartError) 257 }) 258 }