github.com/xmidt-org/webpa-common@v1.11.9/device/manager_test.go (about) 1 package device 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "net/http" 7 "net/http/httptest" 8 "net/url" 9 "sync" 10 "testing" 11 "time" 12 13 "github.com/go-kit/kit/log" 14 "github.com/go-kit/kit/metrics" 15 16 "github.com/xmidt-org/webpa-common/convey" 17 "github.com/xmidt-org/webpa-common/xmetrics" 18 19 "github.com/justinas/alice" 20 "github.com/stretchr/testify/assert" 21 "github.com/stretchr/testify/mock" 22 "github.com/stretchr/testify/require" 23 "github.com/xmidt-org/webpa-common/logging" 24 "github.com/xmidt-org/wrp-go/v3" 25 ) 26 27 var ( 28 testDeviceIDs = []ID{ 29 IntToMAC(0xDEADBEEF), 30 IntToMAC(0x112233445566), 31 IntToMAC(0xFE881212CDCD), 32 IntToMAC(0x7F551928ABCD), 33 } 34 ) 35 36 // startWebsocketServer sets up a server-side environment for testing device-related websocket code 37 func startWebsocketServer(o *Options) (Manager, *httptest.Server, string) { 38 var ( 39 manager = NewManager(o) 40 server = httptest.NewServer( 41 alice.New(Timeout(o), UseID.FromHeader).Then( 42 &ConnectHandler{ 43 Logger: o.logger(), 44 Connector: manager, 45 }, 46 ), 47 ) 48 49 websocketURL, err = url.Parse(server.URL) 50 ) 51 52 if err != nil { 53 server.Close() 54 panic(fmt.Errorf("Unable to parse test server URL: %s", err)) 55 } 56 57 websocketURL.Scheme = "ws" 58 return manager, server, websocketURL.String() 59 } 60 61 func connectTestDevices(t *testing.T, dialer Dialer, connectURL string) map[ID]Connection { 62 devices := make(map[ID]Connection, len(testDeviceIDs)) 63 64 for _, id := range testDeviceIDs { 65 deviceConnection, _, err := dialer.DialDevice(string(id), connectURL, nil) 66 if err != nil { 67 t.Fatalf("Unable to dial test device: %s", err) 68 break 69 } 70 71 devices[id] = deviceConnection 72 } 73 74 return devices 75 } 76 77 func closeTestDevices(assert *assert.Assertions, devices map[ID]Connection) { 78 for _, connection := range devices { 79 assert.Nil(connection.Close()) 80 } 81 } 82 83 func testManagerConnectFilterDeny(t *testing.T) { 84 assert := assert.New(t) 85 mockFilter := new(mockFilter) 86 options := &Options{ 87 Logger: log.NewNopLogger(), 88 Filter: mockFilter, 89 } 90 91 manager := NewManager(options) 92 response := httptest.NewRecorder() 93 request := WithIDRequest(ID("mac:123412341234"), httptest.NewRequest("POST", "http://localhost.com", nil)) 94 95 mockFilter.On("AllowConnection", mock.Anything).Return(false, MatchResult{}).Once() 96 97 device, err := manager.Connect(response, request, nil) 98 assert.Nil(device) 99 assert.Equal(err, ErrorDeviceFilteredOut) 100 101 } 102 103 func testManagerConnectMissingDeviceContext(t *testing.T) { 104 assert := assert.New(t) 105 options := &Options{ 106 Logger: log.NewNopLogger(), 107 } 108 109 manager := NewManager(options) 110 response := httptest.NewRecorder() 111 request := httptest.NewRequest("POST", "http://localhost.com", nil) 112 113 device, err := manager.Connect(response, request, nil) 114 assert.Nil(device) 115 assert.Error(err) 116 assert.Equal(response.Code, http.StatusInternalServerError) 117 } 118 119 func testManagerConnectUpgradeError(t *testing.T) { 120 var ( 121 assert = assert.New(t) 122 options = &Options{ 123 Logger: log.NewNopLogger(), 124 Listeners: []Listener{ 125 func(e *Event) { 126 assert.Fail("The listener should not have been called") 127 }, 128 }, 129 } 130 131 manager = NewManager(options) 132 response = httptest.NewRecorder() 133 request = WithIDRequest(ID("mac:123412341234"), httptest.NewRequest("POST", "http://localhost.com", nil)) 134 responseHeader http.Header 135 ) 136 137 device, actualError := manager.Connect(response, request, responseHeader) 138 assert.Nil(device) 139 assert.Error(actualError) 140 } 141 142 func testManagerConnectVisit(t *testing.T) { 143 var ( 144 assert = assert.New(t) 145 connectWait = new(sync.WaitGroup) 146 connections = make(chan Interface, len(testDeviceIDs)) 147 148 options = &Options{ 149 Logger: log.NewNopLogger(), 150 Listeners: []Listener{ 151 func(event *Event) { 152 if event.Type == Connect { 153 defer connectWait.Done() 154 select { 155 case connections <- event.Device: 156 default: 157 assert.Fail("The connect listener should not block") 158 } 159 } 160 }, 161 }, 162 } 163 164 manager, server, connectURL = startWebsocketServer(options) 165 ) 166 167 defer server.Close() 168 connectWait.Add(len(testDeviceIDs)) 169 170 testDevices := connectTestDevices(t, DefaultDialer(), connectURL) 171 defer closeTestDevices(assert, testDevices) 172 173 connectWait.Wait() 174 close(connections) 175 assert.Equal(len(testDeviceIDs), len(connections)) 176 177 deviceSet := make(deviceSet) 178 for candidate := range connections { 179 deviceSet.add(candidate) 180 } 181 182 assert.Equal(len(testDeviceIDs), deviceSet.len()) 183 deviceSet.reset() 184 manager.VisitAll(deviceSet.managerCapture()) 185 assert.Equal(len(testDeviceIDs), deviceSet.len()) 186 } 187 188 func testManagerDisconnect(t *testing.T) { 189 assert := assert.New(t) 190 connectWait := new(sync.WaitGroup) 191 connectWait.Add(len(testDeviceIDs)) 192 193 disconnectWait := new(sync.WaitGroup) 194 disconnectWait.Add(len(testDeviceIDs)) 195 disconnections := make(chan Interface, len(testDeviceIDs)) 196 197 options := &Options{ 198 Logger: logging.NewTestLogger(nil, t), 199 Listeners: []Listener{ 200 func(event *Event) { 201 switch event.Type { 202 case Connect: 203 connectWait.Done() 204 case Disconnect: 205 defer disconnectWait.Done() 206 assert.True(event.Device.Closed()) 207 disconnections <- event.Device 208 } 209 }, 210 }, 211 } 212 213 manager, server, connectURL := startWebsocketServer(options) 214 defer server.Close() 215 216 testDevices := connectTestDevices(t, DefaultDialer(), connectURL) 217 defer closeTestDevices(assert, testDevices) 218 219 connectWait.Wait() 220 assert.Zero(manager.Disconnect(ID("nosuch"), CloseReason{})) 221 for _, id := range testDeviceIDs { 222 assert.Equal(true, manager.Disconnect(id, CloseReason{})) 223 } 224 225 disconnectWait.Wait() 226 close(disconnections) 227 assert.Equal(len(testDeviceIDs), len(disconnections)) 228 229 deviceSet := make(deviceSet) 230 deviceSet.drain(disconnections) 231 assert.Equal(len(testDeviceIDs), deviceSet.len()) 232 } 233 234 func testManagerDisconnectIf(t *testing.T) { 235 assert := assert.New(t) 236 connectWait := new(sync.WaitGroup) 237 connectWait.Add(len(testDeviceIDs)) 238 disconnections := make(chan Interface, len(testDeviceIDs)) 239 240 options := &Options{ 241 Logger: logging.NewTestLogger(nil, t), 242 Listeners: []Listener{ 243 func(event *Event) { 244 switch event.Type { 245 case Connect: 246 connectWait.Done() 247 case Disconnect: 248 assert.True(event.Device.Closed()) 249 disconnections <- event.Device 250 } 251 }, 252 }, 253 } 254 255 manager, server, connectURL := startWebsocketServer(options) 256 defer server.Close() 257 258 testDevices := connectTestDevices(t, DefaultDialer(), connectURL) 259 defer closeTestDevices(assert, testDevices) 260 261 connectWait.Wait() 262 deviceSet := make(deviceSet) 263 manager.VisitAll(deviceSet.managerCapture()) 264 assert.Equal(len(testDeviceIDs), deviceSet.len()) 265 266 assert.Zero(manager.DisconnectIf(func(ID) (CloseReason, bool) { return CloseReason{}, false })) 267 select { 268 case <-disconnections: 269 assert.Fail("No disconnections should have occurred") 270 default: 271 // the passing case 272 } 273 274 for _, id := range testDeviceIDs { 275 assert.Equal(1, manager.DisconnectIf(func(candidate ID) (CloseReason, bool) { return CloseReason{}, candidate == id })) 276 select { 277 case actual := <-disconnections: 278 assert.Equal(id, actual.ID()) 279 assert.True(actual.Closed()) 280 case <-time.After(10 * time.Second): 281 assert.Fail("No disconnection occurred within the timeout") 282 } 283 } 284 } 285 286 func testManagerRouteBadDestination(t *testing.T) { 287 var ( 288 assert = assert.New(t) 289 request = &Request{ 290 Message: &wrp.Message{ 291 Destination: "this is a bad destination", 292 }, 293 } 294 295 manager = NewManager(nil) 296 ) 297 298 response, err := manager.Route(request) 299 assert.Nil(response) 300 assert.Error(err) 301 } 302 303 func testManagerRouteDeviceNotFound(t *testing.T) { 304 var ( 305 assert = assert.New(t) 306 request = &Request{ 307 Message: &wrp.Message{ 308 Destination: "mac:112233445566", 309 }, 310 } 311 312 manager = NewManager(nil) 313 ) 314 315 response, err := manager.Route(request) 316 assert.Nil(response) 317 assert.Equal(ErrorDeviceNotFound, err) 318 } 319 320 func testManagerConnectIncludesConvey(t *testing.T) { 321 var ( 322 assert = assert.New(t) 323 require = require.New(t) 324 connectWait = new(sync.WaitGroup) 325 contents = make(chan []byte, 1) 326 327 options = &Options{ 328 Logger: log.NewNopLogger(), 329 Listeners: []Listener{ 330 func(event *Event) { 331 if event.Type == Connect { 332 defer connectWait.Done() 333 select { 334 case contents <- event.Contents: 335 default: 336 assert.Fail("The connect listener should not block") 337 } 338 } 339 }, 340 }, 341 } 342 343 _, server, connectURL = startWebsocketServer(options) 344 ) 345 346 defer server.Close() 347 connectWait.Add(1) 348 349 dialer := DefaultDialer() 350 351 /* 352 Convey header in base 64: 353 { 354 "hw-serial-number":123456789, 355 "webpa-protocol":"WebPA-1.6" 356 } 357 358 */ 359 header := &http.Header{ 360 "X-Webpa-Convey": {"eyAgDQogICAiaHctc2VyaWFsLW51bWJlciI6MTIzNDU2Nzg5LA0KICAgIndlYnBhLXByb3RvY29sIjoiV2ViUEEtMS42Ig0KfQ=="}, 361 } 362 363 deviceConnection, _, err := dialer.DialDevice(string(testDeviceIDs[0]), connectURL, *header) 364 require.NotNil(deviceConnection) 365 require.NoError(err) 366 367 defer assert.NoError(deviceConnection.Close()) 368 369 connectWait.Wait() 370 close(contents) 371 assert.Equal(1, len(contents)) 372 373 content := <-contents 374 convey := make(map[string]interface{}) 375 err = json.Unmarshal(content, &convey) 376 377 assert.Nil(err) 378 assert.Equal(2, len(convey)) 379 assert.Equal(float64(123456789), convey["hw-serial-number"]) 380 assert.Equal("WebPA-1.6", convey["webpa-protocol"]) 381 } 382 383 func TestManager(t *testing.T) { 384 t.Run("Connect", func(t *testing.T) { 385 t.Run("MissingDeviceContext", testManagerConnectMissingDeviceContext) 386 t.Run("FilterOutDevice", testManagerConnectFilterDeny) 387 t.Run("UpgradeError", testManagerConnectUpgradeError) 388 t.Run("Visit", testManagerConnectVisit) 389 t.Run("IncludesConvey", testManagerConnectIncludesConvey) 390 }) 391 392 t.Run("Route", func(t *testing.T) { 393 t.Run("BadDestination", testManagerRouteBadDestination) 394 t.Run("DeviceNotFound", testManagerRouteDeviceNotFound) 395 }) 396 397 t.Run("Disconnect", testManagerDisconnect) 398 t.Run("DisconnectIf", testManagerDisconnectIf) 399 } 400 401 func TestGaugeCardinality(t *testing.T) { 402 var ( 403 assert = assert.New(t) 404 r, err = xmetrics.NewRegistry(nil, Metrics) 405 m = NewManager(&Options{ 406 MetricsProvider: r, 407 }) 408 ) 409 assert.NoError(err) 410 411 assert.NotPanics(func() { 412 dec, err := m.(*manager).conveyHWMetric.Update(convey.C{"hw-model": "cardinality", "fw-name": "firmware-number", "model": "f"}, "partnerid", "comcast", "trust", "0") 413 assert.NoError(err) 414 dec() 415 }) 416 417 assert.Panics(func() { 418 m.(*manager).measures.Models.With("neat", "bad").Add(-1) 419 }) 420 } 421 422 func TestWRPSourceIsValid(t *testing.T) { 423 assert := assert.New(t) 424 canonicalID := ID("mac:112233445566") 425 testData := []struct { 426 Name string 427 Source string 428 IsValid bool 429 BaseLabelPairs map[string]string 430 }{ 431 { 432 Name: "EmptySource", 433 IsValid: false, 434 Source: " ", 435 BaseLabelPairs: map[string]string{"reason": "empty"}, 436 }, 437 438 { 439 Name: "ParseFailure", 440 IsValid: false, 441 Source: "serial>hacker/service", 442 BaseLabelPairs: map[string]string{"reason": "parse_error"}, 443 }, 444 { 445 Name: "IDMismatch", 446 IsValid: false, 447 Source: "mac:665544332211/service/some/path", 448 BaseLabelPairs: map[string]string{"reason": "id_mismatch"}, 449 }, 450 { 451 Name: "IDMatch", 452 IsValid: true, 453 Source: "mac:112233445566/service/some/path", 454 BaseLabelPairs: map[string]string{"reason": "id_match"}, 455 }, 456 } 457 458 for _, record := range testData { 459 t.Run(record.Name, func(t *testing.T) { 460 expectedStrictLabels, expectedLenientLabels := createLabelMaps(!record.IsValid, record.BaseLabelPairs) 461 462 d := new(device) 463 d.id = canonicalID 464 d.errorLog = log.WithPrefix(logging.NewTestLogger(nil, t), "id", canonicalID) 465 d.metadata = new(Metadata) 466 467 // strict mode 468 counter := newTestCounter() 469 message := &wrp.Message{Source: record.Source} 470 m := &manager{enforceWRPSourceCheck: true, measures: Measures{WRPSourceCheck: counter}} 471 ok := m.wrpSourceIsValid(message, d) 472 assert.Equal(record.IsValid, ok) 473 assert.Equal(expectedStrictLabels, counter.labelPairs) 474 475 // lenient mode 476 counter = newTestCounter() 477 message = &wrp.Message{Source: record.Source} 478 m = &manager{enforceWRPSourceCheck: false, measures: Measures{WRPSourceCheck: counter}} 479 480 ok = m.wrpSourceIsValid(message, d) 481 assert.True(ok) 482 assert.Equal(expectedLenientLabels, counter.labelPairs) 483 }) 484 } 485 486 } 487 488 func createLabelMaps(rejected bool, baseLabelPairs map[string]string) (strict map[string]string, lenient map[string]string) { 489 strict = make(map[string]string) 490 lenient = make(map[string]string) 491 492 for k, v := range baseLabelPairs { 493 strict[k] = v 494 lenient[k] = v 495 } 496 497 if rejected { 498 strict["outcome"] = "rejected" 499 } else { 500 strict["outcome"] = "accepted" 501 } 502 lenient["outcome"] = "accepted" 503 504 return 505 } 506 507 type testCounter struct { 508 count float64 509 labelPairs map[string]string 510 } 511 512 func (c *testCounter) Add(delta float64) { 513 c.count += delta 514 } 515 516 func (c *testCounter) With(labelValues ...string) metrics.Counter { 517 for i := 0; i < len(labelValues)-1; i += 2 { 518 c.labelPairs[labelValues[i]] = labelValues[i+1] 519 } 520 return c 521 } 522 523 func newTestCounter() *testCounter { 524 return &testCounter{ 525 labelPairs: make(map[string]string), 526 } 527 }