github.com/xmidt-org/webpa-common@v1.11.9/device/manager_test.go (about)

     1  package device
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"net/url"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/go-kit/kit/log"
    14  	"github.com/go-kit/kit/metrics"
    15  
    16  	"github.com/xmidt-org/webpa-common/convey"
    17  	"github.com/xmidt-org/webpa-common/xmetrics"
    18  
    19  	"github.com/justinas/alice"
    20  	"github.com/stretchr/testify/assert"
    21  	"github.com/stretchr/testify/mock"
    22  	"github.com/stretchr/testify/require"
    23  	"github.com/xmidt-org/webpa-common/logging"
    24  	"github.com/xmidt-org/wrp-go/v3"
    25  )
    26  
    27  var (
    28  	testDeviceIDs = []ID{
    29  		IntToMAC(0xDEADBEEF),
    30  		IntToMAC(0x112233445566),
    31  		IntToMAC(0xFE881212CDCD),
    32  		IntToMAC(0x7F551928ABCD),
    33  	}
    34  )
    35  
    36  // startWebsocketServer sets up a server-side environment for testing device-related websocket code
    37  func startWebsocketServer(o *Options) (Manager, *httptest.Server, string) {
    38  	var (
    39  		manager = NewManager(o)
    40  		server  = httptest.NewServer(
    41  			alice.New(Timeout(o), UseID.FromHeader).Then(
    42  				&ConnectHandler{
    43  					Logger:    o.logger(),
    44  					Connector: manager,
    45  				},
    46  			),
    47  		)
    48  
    49  		websocketURL, err = url.Parse(server.URL)
    50  	)
    51  
    52  	if err != nil {
    53  		server.Close()
    54  		panic(fmt.Errorf("Unable to parse test server URL: %s", err))
    55  	}
    56  
    57  	websocketURL.Scheme = "ws"
    58  	return manager, server, websocketURL.String()
    59  }
    60  
    61  func connectTestDevices(t *testing.T, dialer Dialer, connectURL string) map[ID]Connection {
    62  	devices := make(map[ID]Connection, len(testDeviceIDs))
    63  
    64  	for _, id := range testDeviceIDs {
    65  		deviceConnection, _, err := dialer.DialDevice(string(id), connectURL, nil)
    66  		if err != nil {
    67  			t.Fatalf("Unable to dial test device: %s", err)
    68  			break
    69  		}
    70  
    71  		devices[id] = deviceConnection
    72  	}
    73  
    74  	return devices
    75  }
    76  
    77  func closeTestDevices(assert *assert.Assertions, devices map[ID]Connection) {
    78  	for _, connection := range devices {
    79  		assert.Nil(connection.Close())
    80  	}
    81  }
    82  
    83  func testManagerConnectFilterDeny(t *testing.T) {
    84  	assert := assert.New(t)
    85  	mockFilter := new(mockFilter)
    86  	options := &Options{
    87  		Logger: log.NewNopLogger(),
    88  		Filter: mockFilter,
    89  	}
    90  
    91  	manager := NewManager(options)
    92  	response := httptest.NewRecorder()
    93  	request := WithIDRequest(ID("mac:123412341234"), httptest.NewRequest("POST", "http://localhost.com", nil))
    94  
    95  	mockFilter.On("AllowConnection", mock.Anything).Return(false, MatchResult{}).Once()
    96  
    97  	device, err := manager.Connect(response, request, nil)
    98  	assert.Nil(device)
    99  	assert.Equal(err, ErrorDeviceFilteredOut)
   100  
   101  }
   102  
   103  func testManagerConnectMissingDeviceContext(t *testing.T) {
   104  	assert := assert.New(t)
   105  	options := &Options{
   106  		Logger: log.NewNopLogger(),
   107  	}
   108  
   109  	manager := NewManager(options)
   110  	response := httptest.NewRecorder()
   111  	request := httptest.NewRequest("POST", "http://localhost.com", nil)
   112  
   113  	device, err := manager.Connect(response, request, nil)
   114  	assert.Nil(device)
   115  	assert.Error(err)
   116  	assert.Equal(response.Code, http.StatusInternalServerError)
   117  }
   118  
   119  func testManagerConnectUpgradeError(t *testing.T) {
   120  	var (
   121  		assert  = assert.New(t)
   122  		options = &Options{
   123  			Logger: log.NewNopLogger(),
   124  			Listeners: []Listener{
   125  				func(e *Event) {
   126  					assert.Fail("The listener should not have been called")
   127  				},
   128  			},
   129  		}
   130  
   131  		manager        = NewManager(options)
   132  		response       = httptest.NewRecorder()
   133  		request        = WithIDRequest(ID("mac:123412341234"), httptest.NewRequest("POST", "http://localhost.com", nil))
   134  		responseHeader http.Header
   135  	)
   136  
   137  	device, actualError := manager.Connect(response, request, responseHeader)
   138  	assert.Nil(device)
   139  	assert.Error(actualError)
   140  }
   141  
   142  func testManagerConnectVisit(t *testing.T) {
   143  	var (
   144  		assert      = assert.New(t)
   145  		connectWait = new(sync.WaitGroup)
   146  		connections = make(chan Interface, len(testDeviceIDs))
   147  
   148  		options = &Options{
   149  			Logger: log.NewNopLogger(),
   150  			Listeners: []Listener{
   151  				func(event *Event) {
   152  					if event.Type == Connect {
   153  						defer connectWait.Done()
   154  						select {
   155  						case connections <- event.Device:
   156  						default:
   157  							assert.Fail("The connect listener should not block")
   158  						}
   159  					}
   160  				},
   161  			},
   162  		}
   163  
   164  		manager, server, connectURL = startWebsocketServer(options)
   165  	)
   166  
   167  	defer server.Close()
   168  	connectWait.Add(len(testDeviceIDs))
   169  
   170  	testDevices := connectTestDevices(t, DefaultDialer(), connectURL)
   171  	defer closeTestDevices(assert, testDevices)
   172  
   173  	connectWait.Wait()
   174  	close(connections)
   175  	assert.Equal(len(testDeviceIDs), len(connections))
   176  
   177  	deviceSet := make(deviceSet)
   178  	for candidate := range connections {
   179  		deviceSet.add(candidate)
   180  	}
   181  
   182  	assert.Equal(len(testDeviceIDs), deviceSet.len())
   183  	deviceSet.reset()
   184  	manager.VisitAll(deviceSet.managerCapture())
   185  	assert.Equal(len(testDeviceIDs), deviceSet.len())
   186  }
   187  
   188  func testManagerDisconnect(t *testing.T) {
   189  	assert := assert.New(t)
   190  	connectWait := new(sync.WaitGroup)
   191  	connectWait.Add(len(testDeviceIDs))
   192  
   193  	disconnectWait := new(sync.WaitGroup)
   194  	disconnectWait.Add(len(testDeviceIDs))
   195  	disconnections := make(chan Interface, len(testDeviceIDs))
   196  
   197  	options := &Options{
   198  		Logger: logging.NewTestLogger(nil, t),
   199  		Listeners: []Listener{
   200  			func(event *Event) {
   201  				switch event.Type {
   202  				case Connect:
   203  					connectWait.Done()
   204  				case Disconnect:
   205  					defer disconnectWait.Done()
   206  					assert.True(event.Device.Closed())
   207  					disconnections <- event.Device
   208  				}
   209  			},
   210  		},
   211  	}
   212  
   213  	manager, server, connectURL := startWebsocketServer(options)
   214  	defer server.Close()
   215  
   216  	testDevices := connectTestDevices(t, DefaultDialer(), connectURL)
   217  	defer closeTestDevices(assert, testDevices)
   218  
   219  	connectWait.Wait()
   220  	assert.Zero(manager.Disconnect(ID("nosuch"), CloseReason{}))
   221  	for _, id := range testDeviceIDs {
   222  		assert.Equal(true, manager.Disconnect(id, CloseReason{}))
   223  	}
   224  
   225  	disconnectWait.Wait()
   226  	close(disconnections)
   227  	assert.Equal(len(testDeviceIDs), len(disconnections))
   228  
   229  	deviceSet := make(deviceSet)
   230  	deviceSet.drain(disconnections)
   231  	assert.Equal(len(testDeviceIDs), deviceSet.len())
   232  }
   233  
   234  func testManagerDisconnectIf(t *testing.T) {
   235  	assert := assert.New(t)
   236  	connectWait := new(sync.WaitGroup)
   237  	connectWait.Add(len(testDeviceIDs))
   238  	disconnections := make(chan Interface, len(testDeviceIDs))
   239  
   240  	options := &Options{
   241  		Logger: logging.NewTestLogger(nil, t),
   242  		Listeners: []Listener{
   243  			func(event *Event) {
   244  				switch event.Type {
   245  				case Connect:
   246  					connectWait.Done()
   247  				case Disconnect:
   248  					assert.True(event.Device.Closed())
   249  					disconnections <- event.Device
   250  				}
   251  			},
   252  		},
   253  	}
   254  
   255  	manager, server, connectURL := startWebsocketServer(options)
   256  	defer server.Close()
   257  
   258  	testDevices := connectTestDevices(t, DefaultDialer(), connectURL)
   259  	defer closeTestDevices(assert, testDevices)
   260  
   261  	connectWait.Wait()
   262  	deviceSet := make(deviceSet)
   263  	manager.VisitAll(deviceSet.managerCapture())
   264  	assert.Equal(len(testDeviceIDs), deviceSet.len())
   265  
   266  	assert.Zero(manager.DisconnectIf(func(ID) (CloseReason, bool) { return CloseReason{}, false }))
   267  	select {
   268  	case <-disconnections:
   269  		assert.Fail("No disconnections should have occurred")
   270  	default:
   271  		// the passing case
   272  	}
   273  
   274  	for _, id := range testDeviceIDs {
   275  		assert.Equal(1, manager.DisconnectIf(func(candidate ID) (CloseReason, bool) { return CloseReason{}, candidate == id }))
   276  		select {
   277  		case actual := <-disconnections:
   278  			assert.Equal(id, actual.ID())
   279  			assert.True(actual.Closed())
   280  		case <-time.After(10 * time.Second):
   281  			assert.Fail("No disconnection occurred within the timeout")
   282  		}
   283  	}
   284  }
   285  
   286  func testManagerRouteBadDestination(t *testing.T) {
   287  	var (
   288  		assert  = assert.New(t)
   289  		request = &Request{
   290  			Message: &wrp.Message{
   291  				Destination: "this is a bad destination",
   292  			},
   293  		}
   294  
   295  		manager = NewManager(nil)
   296  	)
   297  
   298  	response, err := manager.Route(request)
   299  	assert.Nil(response)
   300  	assert.Error(err)
   301  }
   302  
   303  func testManagerRouteDeviceNotFound(t *testing.T) {
   304  	var (
   305  		assert  = assert.New(t)
   306  		request = &Request{
   307  			Message: &wrp.Message{
   308  				Destination: "mac:112233445566",
   309  			},
   310  		}
   311  
   312  		manager = NewManager(nil)
   313  	)
   314  
   315  	response, err := manager.Route(request)
   316  	assert.Nil(response)
   317  	assert.Equal(ErrorDeviceNotFound, err)
   318  }
   319  
   320  func testManagerConnectIncludesConvey(t *testing.T) {
   321  	var (
   322  		assert      = assert.New(t)
   323  		require     = require.New(t)
   324  		connectWait = new(sync.WaitGroup)
   325  		contents    = make(chan []byte, 1)
   326  
   327  		options = &Options{
   328  			Logger: log.NewNopLogger(),
   329  			Listeners: []Listener{
   330  				func(event *Event) {
   331  					if event.Type == Connect {
   332  						defer connectWait.Done()
   333  						select {
   334  						case contents <- event.Contents:
   335  						default:
   336  							assert.Fail("The connect listener should not block")
   337  						}
   338  					}
   339  				},
   340  			},
   341  		}
   342  
   343  		_, server, connectURL = startWebsocketServer(options)
   344  	)
   345  
   346  	defer server.Close()
   347  	connectWait.Add(1)
   348  
   349  	dialer := DefaultDialer()
   350  
   351  	/*
   352  		Convey header in base 64:
   353  			{
   354  				"hw-serial-number":123456789,
   355  				"webpa-protocol":"WebPA-1.6"
   356  			}
   357  
   358  	*/
   359  	header := &http.Header{
   360  		"X-Webpa-Convey": {"eyAgDQogICAiaHctc2VyaWFsLW51bWJlciI6MTIzNDU2Nzg5LA0KICAgIndlYnBhLXByb3RvY29sIjoiV2ViUEEtMS42Ig0KfQ=="},
   361  	}
   362  
   363  	deviceConnection, _, err := dialer.DialDevice(string(testDeviceIDs[0]), connectURL, *header)
   364  	require.NotNil(deviceConnection)
   365  	require.NoError(err)
   366  
   367  	defer assert.NoError(deviceConnection.Close())
   368  
   369  	connectWait.Wait()
   370  	close(contents)
   371  	assert.Equal(1, len(contents))
   372  
   373  	content := <-contents
   374  	convey := make(map[string]interface{})
   375  	err = json.Unmarshal(content, &convey)
   376  
   377  	assert.Nil(err)
   378  	assert.Equal(2, len(convey))
   379  	assert.Equal(float64(123456789), convey["hw-serial-number"])
   380  	assert.Equal("WebPA-1.6", convey["webpa-protocol"])
   381  }
   382  
   383  func TestManager(t *testing.T) {
   384  	t.Run("Connect", func(t *testing.T) {
   385  		t.Run("MissingDeviceContext", testManagerConnectMissingDeviceContext)
   386  		t.Run("FilterOutDevice", testManagerConnectFilterDeny)
   387  		t.Run("UpgradeError", testManagerConnectUpgradeError)
   388  		t.Run("Visit", testManagerConnectVisit)
   389  		t.Run("IncludesConvey", testManagerConnectIncludesConvey)
   390  	})
   391  
   392  	t.Run("Route", func(t *testing.T) {
   393  		t.Run("BadDestination", testManagerRouteBadDestination)
   394  		t.Run("DeviceNotFound", testManagerRouteDeviceNotFound)
   395  	})
   396  
   397  	t.Run("Disconnect", testManagerDisconnect)
   398  	t.Run("DisconnectIf", testManagerDisconnectIf)
   399  }
   400  
   401  func TestGaugeCardinality(t *testing.T) {
   402  	var (
   403  		assert = assert.New(t)
   404  		r, err = xmetrics.NewRegistry(nil, Metrics)
   405  		m      = NewManager(&Options{
   406  			MetricsProvider: r,
   407  		})
   408  	)
   409  	assert.NoError(err)
   410  
   411  	assert.NotPanics(func() {
   412  		dec, err := m.(*manager).conveyHWMetric.Update(convey.C{"hw-model": "cardinality", "fw-name": "firmware-number", "model": "f"}, "partnerid", "comcast", "trust", "0")
   413  		assert.NoError(err)
   414  		dec()
   415  	})
   416  
   417  	assert.Panics(func() {
   418  		m.(*manager).measures.Models.With("neat", "bad").Add(-1)
   419  	})
   420  }
   421  
   422  func TestWRPSourceIsValid(t *testing.T) {
   423  	assert := assert.New(t)
   424  	canonicalID := ID("mac:112233445566")
   425  	testData := []struct {
   426  		Name           string
   427  		Source         string
   428  		IsValid        bool
   429  		BaseLabelPairs map[string]string
   430  	}{
   431  		{
   432  			Name:    "EmptySource",
   433  			IsValid: false,
   434  			Source: "   	",
   435  			BaseLabelPairs: map[string]string{"reason": "empty"},
   436  		},
   437  
   438  		{
   439  			Name:           "ParseFailure",
   440  			IsValid:        false,
   441  			Source:         "serial>hacker/service",
   442  			BaseLabelPairs: map[string]string{"reason": "parse_error"},
   443  		},
   444  		{
   445  			Name:           "IDMismatch",
   446  			IsValid:        false,
   447  			Source:         "mac:665544332211/service/some/path",
   448  			BaseLabelPairs: map[string]string{"reason": "id_mismatch"},
   449  		},
   450  		{
   451  			Name:           "IDMatch",
   452  			IsValid:        true,
   453  			Source:         "mac:112233445566/service/some/path",
   454  			BaseLabelPairs: map[string]string{"reason": "id_match"},
   455  		},
   456  	}
   457  
   458  	for _, record := range testData {
   459  		t.Run(record.Name, func(t *testing.T) {
   460  			expectedStrictLabels, expectedLenientLabels := createLabelMaps(!record.IsValid, record.BaseLabelPairs)
   461  
   462  			d := new(device)
   463  			d.id = canonicalID
   464  			d.errorLog = log.WithPrefix(logging.NewTestLogger(nil, t), "id", canonicalID)
   465  			d.metadata = new(Metadata)
   466  
   467  			// strict mode
   468  			counter := newTestCounter()
   469  			message := &wrp.Message{Source: record.Source}
   470  			m := &manager{enforceWRPSourceCheck: true, measures: Measures{WRPSourceCheck: counter}}
   471  			ok := m.wrpSourceIsValid(message, d)
   472  			assert.Equal(record.IsValid, ok)
   473  			assert.Equal(expectedStrictLabels, counter.labelPairs)
   474  
   475  			// lenient mode
   476  			counter = newTestCounter()
   477  			message = &wrp.Message{Source: record.Source}
   478  			m = &manager{enforceWRPSourceCheck: false, measures: Measures{WRPSourceCheck: counter}}
   479  
   480  			ok = m.wrpSourceIsValid(message, d)
   481  			assert.True(ok)
   482  			assert.Equal(expectedLenientLabels, counter.labelPairs)
   483  		})
   484  	}
   485  
   486  }
   487  
   488  func createLabelMaps(rejected bool, baseLabelPairs map[string]string) (strict map[string]string, lenient map[string]string) {
   489  	strict = make(map[string]string)
   490  	lenient = make(map[string]string)
   491  
   492  	for k, v := range baseLabelPairs {
   493  		strict[k] = v
   494  		lenient[k] = v
   495  	}
   496  
   497  	if rejected {
   498  		strict["outcome"] = "rejected"
   499  	} else {
   500  		strict["outcome"] = "accepted"
   501  	}
   502  	lenient["outcome"] = "accepted"
   503  
   504  	return
   505  }
   506  
   507  type testCounter struct {
   508  	count      float64
   509  	labelPairs map[string]string
   510  }
   511  
   512  func (c *testCounter) Add(delta float64) {
   513  	c.count += delta
   514  }
   515  
   516  func (c *testCounter) With(labelValues ...string) metrics.Counter {
   517  	for i := 0; i < len(labelValues)-1; i += 2 {
   518  		c.labelPairs[labelValues[i]] = labelValues[i+1]
   519  	}
   520  	return c
   521  }
   522  
   523  func newTestCounter() *testCounter {
   524  	return &testCounter{
   525  		labelPairs: make(map[string]string),
   526  	}
   527  }