github.com/xmidt-org/webpa-common@v1.11.9/device/manager.go (about)

     1  package device
     2  
     3  import (
     4  	"encoding/json"
     5  	"io"
     6  	"net/http"
     7  	"strconv"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/xmidt-org/webpa-common/convey"
    13  	"github.com/xmidt-org/webpa-common/convey/conveymetric"
    14  
    15  	"github.com/go-kit/kit/log"
    16  	"github.com/gorilla/websocket"
    17  	"github.com/xmidt-org/webpa-common/convey/conveyhttp"
    18  	"github.com/xmidt-org/webpa-common/logging"
    19  	"github.com/xmidt-org/webpa-common/xhttp"
    20  	"github.com/xmidt-org/wrp-go/v3"
    21  )
    22  
    23  const MaxDevicesHeader = "X-Xmidt-Max-Devices"
    24  
    25  // DefaultWRPContentType is the content type used on inbound WRP messages which don't provide one.
    26  const DefaultWRPContentType = "application/octet-stream"
    27  
    28  // Connector is a strategy interface for managing device connections to a server.
    29  // Implementations are responsible for upgrading websocket connections and providing
    30  // for explicit disconnection.
    31  type Connector interface {
    32  	// Connect upgrade an HTTP connection to a websocket and begins concurrent
    33  	// management of the device.
    34  	Connect(http.ResponseWriter, *http.Request, http.Header) (Interface, error)
    35  
    36  	// Disconnect disconnects the device associated with the given id.
    37  	// If the id was found, this method returns true.
    38  	Disconnect(ID, CloseReason) bool
    39  
    40  	// DisconnectIf iterates over all devices known to this manager, applying the
    41  	// given predicate.  For any devices that result in true, this method disconnects them.
    42  	// Note that this method may pause connections and disconnections while it is executing.
    43  	// This method returns the number of devices that were disconnected.
    44  	//
    45  	// Only disconnection by ID is supported, which means that any identifier matching
    46  	// the predicate will result in *all* duplicate devices under that ID being removed.
    47  	//
    48  	// No methods on this Manager should be called from within the predicate function, or
    49  	// a deadlock will likely occur.
    50  	DisconnectIf(func(ID) (CloseReason, bool)) int
    51  
    52  	// DisconnectAll disconnects all devices from this instance, and returns the count of
    53  	// devices disconnected.
    54  	DisconnectAll(CloseReason) int
    55  
    56  	// GetFilter returns the Filter interface used for filtering connection requests
    57  	GetFilter() Filter
    58  }
    59  
    60  // Router handles dispatching messages to devices.
    61  type Router interface {
    62  	// Route dispatches a WRP request to exactly one device, identified by the ID
    63  	// field of the request.  Route is synchronous, and honors the cancellation semantics
    64  	// of the Request's context.
    65  	Route(*Request) (*Response, error)
    66  }
    67  
    68  // Registry is the strategy interface for querying the set of connected devices.  Methods
    69  // in this interface follow the Visitor pattern and are typically executed under a read lock.
    70  type Registry interface {
    71  	// Len returns the count of devices currently in this registry
    72  	Len() int
    73  
    74  	// Get returns the device associated with the given ID, if any
    75  	Get(ID) (Interface, bool)
    76  
    77  	// VisitAll applies the given visitor function to each device known to this manager.
    78  	//
    79  	// No methods on this Manager should be called from within the visitor function, or
    80  	// a deadlock will likely occur.
    81  	VisitAll(func(Interface) bool) int
    82  }
    83  
    84  type Filter interface {
    85  	AllowConnection(d Interface) (bool, MatchResult)
    86  }
    87  
    88  type MatchResult struct {
    89  	Location string
    90  	Key      string
    91  }
    92  
    93  type FilterFunc func(d Interface) (bool, MatchResult)
    94  
    95  func (filter FilterFunc) AllowConnection(d Interface) (bool, MatchResult) {
    96  	return filter(d)
    97  }
    98  
    99  // Manager supplies a hub for connecting and disconnecting devices as well as
   100  // an access point for obtaining device metadata.
   101  type Manager interface {
   102  	Connector
   103  	Router
   104  	Registry
   105  }
   106  
   107  // ManagerOption is a configuration option for a manager
   108  type ManagerOption func(*manager)
   109  
   110  // NewManager constructs a Manager from a set of options.  A ConnectionFactory will be
   111  // created from the options if one is not supplied.
   112  func NewManager(o *Options) Manager {
   113  	var (
   114  		logger      = o.logger()
   115  		debugLogger = logging.Debug(logger)
   116  		measures    = NewMeasures(o.metricsProvider())
   117  		wrpCheck    = o.wrpCheck()
   118  	)
   119  
   120  	debugLogger.Log(logging.MessageKey(), "source check configuration", "type", wrpCheck.Type)
   121  
   122  	return &manager{
   123  		logger:           logger,
   124  		errorLog:         logging.Error(logger),
   125  		debugLog:         debugLogger,
   126  		readDeadline:     NewDeadline(o.idlePeriod(), o.now()),
   127  		writeDeadline:    NewDeadline(o.writeTimeout(), o.now()),
   128  		upgrader:         o.upgrader(),
   129  		conveyTranslator: conveyhttp.NewHeaderTranslator("", nil),
   130  		devices: newRegistry(registryOptions{
   131  			Logger:   logger,
   132  			Limit:    o.maxDevices(),
   133  			Measures: measures,
   134  		}),
   135  		conveyHWMetric: conveymetric.NewConveyMetric(measures.Models, []conveymetric.TagLabelPair{
   136  			{
   137  				Tag:   "hw-model",
   138  				Label: "model",
   139  			},
   140  			{
   141  				Tag:   "fw-name",
   142  				Label: "firmware",
   143  			}}...),
   144  
   145  		deviceMessageQueueSize: o.deviceMessageQueueSize(),
   146  		pingPeriod:             o.pingPeriod(),
   147  
   148  		listeners:             o.listeners(),
   149  		measures:              measures,
   150  		enforceWRPSourceCheck: wrpCheck.Type == CheckTypeEnforce,
   151  		filter:                o.filter(),
   152  	}
   153  
   154  }
   155  
   156  // manager is the internal Manager implementation.
   157  type manager struct {
   158  	logger   log.Logger
   159  	errorLog log.Logger
   160  	debugLog log.Logger
   161  
   162  	readDeadline     func() time.Time
   163  	writeDeadline    func() time.Time
   164  	upgrader         *websocket.Upgrader
   165  	conveyTranslator conveyhttp.HeaderTranslator
   166  
   167  	devices        *registry
   168  	conveyHWMetric conveymetric.Interface
   169  
   170  	deviceMessageQueueSize int
   171  	pingPeriod             time.Duration
   172  
   173  	listeners             []Listener
   174  	measures              Measures
   175  	enforceWRPSourceCheck bool
   176  
   177  	filter Filter
   178  }
   179  
   180  func (m *manager) Connect(response http.ResponseWriter, request *http.Request, responseHeader http.Header) (Interface, error) {
   181  	m.debugLog.Log(logging.MessageKey(), "device connect", "url", request.URL)
   182  	ctx := request.Context()
   183  	id, ok := GetID(ctx)
   184  	if !ok {
   185  		xhttp.WriteError(
   186  			response,
   187  			http.StatusInternalServerError,
   188  			ErrorMissingDeviceNameContext,
   189  		)
   190  
   191  		return nil, ErrorMissingDeviceNameContext
   192  	}
   193  
   194  	metadata, ok := GetDeviceMetadata(ctx)
   195  	if !ok {
   196  		metadata = new(Metadata)
   197  	}
   198  
   199  	cvy, cvyErr := m.conveyTranslator.FromHeader(request.Header)
   200  	d := newDevice(deviceOptions{
   201  		ID:         id,
   202  		C:          cvy,
   203  		Compliance: convey.GetCompliance(cvyErr),
   204  		QueueSize:  m.deviceMessageQueueSize,
   205  		Metadata:   metadata,
   206  		Logger:     m.logger,
   207  	})
   208  
   209  	if allow, matchResults := m.filter.AllowConnection(d); !allow {
   210  		d.infoLog.Log("filter", "filter match found,", "location", matchResults.Location, "key", matchResults.Key)
   211  		return nil, ErrorDeviceFilteredOut
   212  	}
   213  
   214  	if len(metadata.Claims()) < 1 {
   215  		d.errorLog.Log(logging.MessageKey(), "missing security information")
   216  	}
   217  
   218  	if cvyErr == nil {
   219  		d.infoLog.Log("convey", cvy)
   220  	} else {
   221  		d.errorLog.Log(logging.MessageKey(), "bad or missing convey data", logging.ErrorKey(), cvyErr)
   222  	}
   223  
   224  	c, err := m.upgrader.Upgrade(response, request, responseHeader)
   225  	if err != nil {
   226  		d.errorLog.Log(logging.MessageKey(), "failed websocket upgrade", logging.ErrorKey(), err)
   227  		return nil, err
   228  	}
   229  
   230  	d.debugLog.Log(logging.MessageKey(), "websocket upgrade complete", "localAddress", c.LocalAddr().String())
   231  
   232  	pinger, err := NewPinger(c, m.measures.Ping, []byte(d.ID()), m.writeDeadline)
   233  	if err != nil {
   234  		d.errorLog.Log(logging.MessageKey(), "unable to create pinger", logging.ErrorKey(), err)
   235  		c.Close()
   236  		return nil, err
   237  	}
   238  
   239  	if err := m.devices.add(d); err != nil {
   240  		d.errorLog.Log(logging.MessageKey(), "unable to register device", logging.ErrorKey(), err)
   241  		c.Close()
   242  		return nil, err
   243  	}
   244  
   245  	event := &Event{
   246  		Type:   Connect,
   247  		Device: d,
   248  	}
   249  
   250  	if cvyErr == nil {
   251  		bytes, err := json.Marshal(cvy)
   252  		if err == nil {
   253  			event.Format = wrp.JSON
   254  			event.Contents = bytes
   255  		} else {
   256  			d.errorLog.Log(logging.MessageKey(), "unable to marshal the convey header", logging.ErrorKey(), err)
   257  		}
   258  	}
   259  	metricClosure, err := m.conveyHWMetric.Update(cvy, "partnerid", metadata.PartnerIDClaim(), "trust", strconv.Itoa(metadata.TrustClaim()))
   260  	if err != nil {
   261  		d.errorLog.Log(logging.MessageKey(), "failed to update convey metrics", logging.ErrorKey(), err)
   262  	}
   263  
   264  	d.conveyClosure = metricClosure
   265  	m.dispatch(event)
   266  
   267  	SetPongHandler(c, m.measures.Pong, m.readDeadline)
   268  	closeOnce := new(sync.Once)
   269  	go m.readPump(d, InstrumentReader(c, d.statistics), closeOnce)
   270  	go m.writePump(d, InstrumentWriter(c, d.statistics), pinger, closeOnce)
   271  
   272  	return d, nil
   273  }
   274  
   275  func (m *manager) dispatch(e *Event) {
   276  	for _, listener := range m.listeners {
   277  		listener(e)
   278  	}
   279  }
   280  
   281  // pumpClose handles the proper shutdown and logging of a device's pumps.
   282  // This method should be executed within a sync.Once, so that it only executes
   283  // once for a given device.
   284  //
   285  // Note that the write pump does additional cleanup.  In particular, the write pump
   286  // dispatches message failed events for any messages that were waiting to be delivered
   287  // at the time of pump closure.
   288  func (m *manager) pumpClose(d *device, c io.Closer, reason CloseReason) {
   289  	// remove will invoke requestClose()
   290  	m.devices.remove(d.id, reason)
   291  
   292  	closeError := c.Close()
   293  
   294  	d.errorLog.Log(logging.MessageKey(), "Closed device connection",
   295  		"closeError", closeError, "reasonError", reason.Err, "reason", reason.Text,
   296  		"finalStatistics", d.Statistics().String())
   297  
   298  	m.dispatch(
   299  		&Event{
   300  			Type:   Disconnect,
   301  			Device: d,
   302  		},
   303  	)
   304  	d.conveyClosure()
   305  }
   306  
   307  func (m *manager) wrpSourceIsValid(message *wrp.Message, d *device) bool {
   308  	expectedID := d.ID()
   309  	if len(strings.TrimSpace(message.Source)) == 0 {
   310  		d.errorLog.Log(logging.MessageKey(), "WRP source was empty", "trustLevel", d.Metadata().TrustClaim())
   311  		if m.enforceWRPSourceCheck {
   312  			m.measures.WRPSourceCheck.With("outcome", "rejected", "reason", "empty").Add(1)
   313  			return false
   314  		}
   315  		m.measures.WRPSourceCheck.With("outcome", "accepted", "reason", "empty").Add(1)
   316  		return true
   317  	}
   318  
   319  	actualID, err := ParseID(message.Source)
   320  	if err != nil {
   321  		d.errorLog.Log(logging.MessageKey(), "Failed to parse ID from WRP source", "trustLevel", d.Metadata().TrustClaim())
   322  		if m.enforceWRPSourceCheck {
   323  			m.measures.WRPSourceCheck.With("outcome", "rejected", "reason", "parse_error").Add(1)
   324  			return false
   325  		}
   326  		m.measures.WRPSourceCheck.With("outcome", "accepted", "reason", "parse_error").Add(1)
   327  		return true
   328  	}
   329  
   330  	if expectedID != actualID {
   331  		d.errorLog.Log(logging.MessageKey(), "ID in WRP source does not match device's ID", "spoofedID", actualID, "trustLevel", d.Metadata().TrustClaim())
   332  		if m.enforceWRPSourceCheck {
   333  			m.measures.WRPSourceCheck.With("outcome", "rejected", "reason", "id_mismatch").Add(1)
   334  			return false
   335  		}
   336  		m.measures.WRPSourceCheck.With("outcome", "accepted", "reason", "id_mismatch").Add(1)
   337  		return true
   338  	}
   339  
   340  	m.measures.WRPSourceCheck.With("outcome", "accepted", "reason", "id_match").Add(1)
   341  	return true
   342  }
   343  
   344  func addDeviceMetadataContext(message *wrp.Message, deviceMetadata *Metadata) {
   345  	message.PartnerIDs = []string{deviceMetadata.PartnerIDClaim()}
   346  
   347  	if message.Type == wrp.SimpleEventMessageType {
   348  		message.SessionID = deviceMetadata.SessionID()
   349  	}
   350  }
   351  
   352  // readPump is the goroutine which handles the stream of WRP messages from a device.
   353  // This goroutine exits when any error occurs on the connection.
   354  func (m *manager) readPump(d *device, r ReadCloser, closeOnce *sync.Once) {
   355  	defer d.debugLog.Log(logging.MessageKey(), "readPump exiting")
   356  	d.debugLog.Log(logging.MessageKey(), "readPump starting")
   357  
   358  	var (
   359  		readError error
   360  		decoder   = wrp.NewDecoder(nil, wrp.Msgpack)
   361  		encoder   = wrp.NewEncoder(nil, wrp.Msgpack)
   362  	)
   363  
   364  	// all the read pump has to do is ensure the device and the connection are closed
   365  	// it is the write pump's responsibility to do further cleanup
   366  	defer func() {
   367  		closeOnce.Do(func() { m.pumpClose(d, r, CloseReason{Err: readError, Text: "readerror"}) })
   368  	}()
   369  
   370  	for {
   371  		messageType, data, readError := r.ReadMessage()
   372  		if readError != nil {
   373  			d.errorLog.Log(logging.MessageKey(), "read error", logging.ErrorKey(), readError)
   374  			return
   375  		}
   376  
   377  		if messageType != websocket.BinaryMessage {
   378  			d.errorLog.Log(logging.MessageKey(), "skipping non-binary frame", "messageType", messageType)
   379  			continue
   380  		}
   381  
   382  		var (
   383  			message = new(wrp.Message)
   384  			event   = Event{
   385  				Type:     MessageReceived,
   386  				Device:   d,
   387  				Message:  message,
   388  				Format:   wrp.Msgpack,
   389  				Contents: data,
   390  			}
   391  		)
   392  
   393  		decoder.ResetBytes(data)
   394  		err := decoder.Decode(message)
   395  		if err != nil {
   396  			d.errorLog.Log(logging.MessageKey(), "skipping malformed WRP message", logging.ErrorKey(), err)
   397  			continue
   398  		}
   399  
   400  		if !m.wrpSourceIsValid(message, d) {
   401  			d.errorLog.Log(logging.MessageKey(), "skipping WRP message with invalid source")
   402  			continue
   403  		}
   404  
   405  		if len(strings.TrimSpace(message.ContentType)) == 0 {
   406  			message.ContentType = DefaultWRPContentType
   407  		}
   408  
   409  		addDeviceMetadataContext(message, d.Metadata())
   410  
   411  		if message.Type == wrp.SimpleRequestResponseMessageType {
   412  			m.measures.RequestResponse.Add(1.0)
   413  		}
   414  
   415  		encoder.ResetBytes(&event.Contents)
   416  		err = encoder.Encode(message)
   417  
   418  		if err != nil {
   419  			d.errorLog.Log(logging.MessageKey(), "unable to encode WRP message", logging.ErrorKey(), err)
   420  			continue
   421  		}
   422  
   423  		// update any waiting transaction
   424  		if message.IsTransactionPart() {
   425  			err := d.transactions.Complete(
   426  				message.TransactionKey(),
   427  				&Response{
   428  					Device:   d,
   429  					Message:  message,
   430  					Format:   wrp.Msgpack,
   431  					Contents: event.Contents,
   432  				},
   433  			)
   434  
   435  			if err != nil {
   436  				d.errorLog.Log(logging.MessageKey(), "Error while completing transaction", "transactionKey", message.TransactionKey(), logging.ErrorKey(), err)
   437  				event.Type = TransactionBroken
   438  				event.Error = err
   439  			} else {
   440  				event.Type = TransactionComplete
   441  			}
   442  		}
   443  		m.dispatch(&event)
   444  	}
   445  }
   446  
   447  // writePump is the goroutine which services messages addressed to the device.
   448  // this goroutine exits when either an explicit shutdown is requested or any
   449  // error occurs on the connection.
   450  func (m *manager) writePump(d *device, w WriteCloser, pinger func() error, closeOnce *sync.Once) {
   451  	defer d.debugLog.Log(logging.MessageKey(), "writePump exiting")
   452  	d.debugLog.Log(logging.MessageKey(), "writePump starting")
   453  
   454  	var (
   455  		envelope   *envelope
   456  		encoder    = wrp.NewEncoder(nil, wrp.Msgpack)
   457  		writeError error
   458  
   459  		pingTicker = time.NewTicker(m.pingPeriod)
   460  	)
   461  
   462  	// cleanup: we not only ensure that the device and connection are closed but also
   463  	// ensure that any messages that were waiting and/or failed are dispatched to
   464  	// the configured listener
   465  	defer func() {
   466  		pingTicker.Stop()
   467  		closeOnce.Do(func() { m.pumpClose(d, w, CloseReason{Err: writeError, Text: "write-error"}) })
   468  
   469  		// notify listener of any message that just now failed
   470  		// any writeError is passed via this event
   471  		if envelope != nil {
   472  			m.dispatch(&Event{
   473  				Type:     MessageFailed,
   474  				Device:   d,
   475  				Message:  envelope.request.Message,
   476  				Format:   envelope.request.Format,
   477  				Contents: envelope.request.Contents,
   478  				Error:    writeError,
   479  			})
   480  		}
   481  
   482  		// drain the messages, dispatching them as message failed events.  we never close
   483  		// the message channel, so just drain until a receive would block.
   484  		//
   485  		// Nil is passed explicitly as the error to indicate that these messages failed due
   486  		// to the device disconnecting, not due to an actual I/O error.
   487  		for {
   488  			select {
   489  			case undeliverable := <-d.messages:
   490  				d.errorLog.Log(logging.MessageKey(), "undeliverable message", "deviceMessage", undeliverable)
   491  				m.dispatch(&Event{
   492  					Type:     MessageFailed,
   493  					Device:   d,
   494  					Message:  undeliverable.request.Message,
   495  					Format:   undeliverable.request.Format,
   496  					Contents: undeliverable.request.Contents,
   497  					Error:    writeError,
   498  				})
   499  			default:
   500  				return
   501  			}
   502  		}
   503  	}()
   504  
   505  	for writeError == nil {
   506  		envelope = nil
   507  
   508  		select {
   509  		case <-d.shutdown:
   510  			d.debugLog.Log(logging.MessageKey(), "explicit shutdown")
   511  			writeError = w.Close()
   512  			return
   513  
   514  		case envelope = <-d.messages:
   515  			var frameContents []byte
   516  			if envelope.request.Format == wrp.Msgpack && len(envelope.request.Contents) > 0 {
   517  				frameContents = envelope.request.Contents
   518  			} else {
   519  				// if the request was in a format other than Msgpack, or if the caller did not pass
   520  				// Contents, then do the encoding here.
   521  				encoder.ResetBytes(&frameContents)
   522  				writeError = encoder.Encode(envelope.request.Message)
   523  				encoder.ResetBytes(nil)
   524  			}
   525  
   526  			if writeError == nil {
   527  				writeError = w.WriteMessage(websocket.BinaryMessage, frameContents)
   528  			}
   529  
   530  			event := Event{
   531  				Device:   d,
   532  				Message:  envelope.request.Message,
   533  				Format:   envelope.request.Format,
   534  				Contents: envelope.request.Contents,
   535  				Error:    writeError,
   536  			}
   537  
   538  			if writeError != nil {
   539  				envelope.complete <- writeError
   540  				event.Type = MessageFailed
   541  			} else {
   542  				event.Type = MessageSent
   543  			}
   544  
   545  			close(envelope.complete)
   546  			m.dispatch(&event)
   547  
   548  		case <-pingTicker.C:
   549  			writeError = pinger()
   550  		}
   551  	}
   552  }
   553  
   554  func (m *manager) Disconnect(id ID, reason CloseReason) bool {
   555  	_, ok := m.devices.remove(id, reason)
   556  	return ok
   557  }
   558  
   559  func (m *manager) DisconnectIf(filter func(ID) (CloseReason, bool)) int {
   560  	return m.devices.removeIf(func(d *device) (CloseReason, bool) {
   561  		return filter(d.id)
   562  	})
   563  }
   564  
   565  func (m *manager) DisconnectAll(reason CloseReason) int {
   566  	return m.devices.removeAll(reason)
   567  }
   568  
   569  func (m *manager) GetFilter() Filter {
   570  	return m.filter
   571  }
   572  
   573  func defaultFilterFunc() FilterFunc {
   574  	return func(d Interface) (bool, MatchResult) {
   575  		return true, MatchResult{}
   576  	}
   577  }
   578  
   579  func (m *manager) Len() int {
   580  	return m.devices.len()
   581  }
   582  
   583  func (m *manager) Get(id ID) (Interface, bool) {
   584  	return m.devices.get(id)
   585  }
   586  
   587  func (m *manager) VisitAll(visitor func(Interface) bool) int {
   588  	return m.devices.visit(func(d *device) bool {
   589  		return visitor(d)
   590  	})
   591  }
   592  
   593  func (m *manager) Route(request *Request) (*Response, error) {
   594  	if destination, err := request.ID(); err != nil {
   595  		return nil, err
   596  	} else if d, ok := m.devices.get(destination); ok {
   597  		return d.Send(request)
   598  	} else {
   599  		return nil, ErrorDeviceNotFound
   600  	}
   601  }