github.com/qri-io/qri@v0.10.1-0.20220104210721-c771715036cb/lib/websocket/websocket.go (about)

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"io"
     8  	"net/http"
     9  	"sync"
    10  
    11  	"github.com/google/uuid"
    12  	golog "github.com/ipfs/go-log"
    13  	"github.com/qri-io/qri/auth/key"
    14  	"github.com/qri-io/qri/auth/token"
    15  	"github.com/qri-io/qri/event"
    16  	"nhooyr.io/websocket"
    17  	"nhooyr.io/websocket/wsjson"
    18  )
    19  
// qriWebsocketProtocol is the websocket subprotocol name this handler offers
// when accepting a connection
const qriWebsocketProtocol = "qri-websocket"
    21  
var (
	// errNotFound is returned by lookups when a connection ID or profile ID
	// has no usable entry in the handler's maps
	errNotFound = fmt.Errorf("connection not found")

	// log is the package logger, created under the name "websocket"
	log = golog.Logger("websocket")
)
    27  
    28  // newID returns a new websocket connection ID
    29  func newID() string {
    30  	return uuid.New().String()
    31  }
    32  
// setIDRand sets the reader that newID uses as its source of random bytes.
// Passing nil restores the uuid package's default source. Supplying a fixed
// reader makes ID generation deterministic for tests. eg:
//    myString := "SomeRandomStringThatIsLong-SoYouCanCallItAsMuchAsNeeded..."
//    setIDRand(strings.NewReader(myString))
//    a := newID()
//    setIDRand(strings.NewReader(myString))
//    b := newID()
// NOTE(review): this forwards to uuid.SetRand, which configures the uuid
// package itself rather than this handler, so other users of the uuid package
// in the same process are presumably affected too
func setIDRand(r io.Reader) {
	uuid.SetRand(r)
}
    44  
// Handler defines the handler interface: a type that can serve websocket
// connection requests arriving over HTTP
type Handler interface {
	// ConnectionHandler has the signature of an http.HandlerFunc, so it can be
	// registered directly on an HTTP router
	ConnectionHandler(w http.ResponseWriter, r *http.Request)
}
    49  
// connectionSet is a set of connection IDs, stored as map keys with empty
// struct values
type connectionSet map[string]struct{}
    51  
// connections maintains the set of active websocket connections & associated
// connection metadata. It implements Handler and is registered on the event
// bus by NewHandler
type connections struct {
	// conns maps a connection ID to its connection. Guarded by connsLock
	conns         map[string]*conn
	connsLock     sync.Mutex
	// keystore is handed to token.ParseAuthToken when a connection subscribes
	keystore      key.Store
	// subscriptions maps a profile ID string to the IDs of the connections
	// subscribed under that profile. Guarded by subsLock
	subscriptions map[string]connectionSet
	subsLock      sync.Mutex
}
    61  
// conn pairs an accepted websocket connection with its ID and, once the
// connection has subscribed, the profile ID it subscribed as
type conn struct {
	// id is the connection's key in connections.conns
	id        string
	// profileID is empty until subscribeConn succeeds, and is reset to empty
	// by unsubscribeConn
	profileID string
	// conn is the underlying websocket connection
	conn      *websocket.Conn
}
    67  
// compile-time assertion that *connections implements Handler
var _ Handler = (*connections)(nil)
    69  
    70  // NewHandler creates a new connections instance that clients
    71  // can connect to in order to get realtime events
    72  func NewHandler(ctx context.Context, bus event.Bus, keystore key.Store) (Handler, error) {
    73  	ws := &connections{
    74  		conns:         map[string]*conn{},
    75  		connsLock:     sync.Mutex{},
    76  		keystore:      keystore,
    77  		subscriptions: map[string]connectionSet{},
    78  		subsLock:      sync.Mutex{},
    79  	}
    80  
    81  	bus.SubscribeAll(ws.messageHandler)
    82  	return ws, nil
    83  }
    84  
    85  // ConnectionHandler handles websocket upgrade requests and accepts the connection
    86  func (h *connections) ConnectionHandler(w http.ResponseWriter, r *http.Request) {
    87  	wsc, err := websocket.Accept(w, r, &websocket.AcceptOptions{
    88  		Subprotocols:       []string{qriWebsocketProtocol},
    89  		InsecureSkipVerify: true,
    90  	})
    91  	if err != nil {
    92  		log.Debugf("Websocket accept error: %s", err)
    93  		return
    94  	}
    95  	id := newID()
    96  	c := &conn{
    97  		id:   id,
    98  		conn: wsc,
    99  	}
   100  	h.connsLock.Lock()
   101  	defer h.connsLock.Unlock()
   102  	h.conns[id] = c
   103  	go h.read(id)
   104  }
   105  
   106  func (h *connections) messageHandler(_ context.Context, e event.Event) error {
   107  	ctx := context.Background()
   108  	evt := map[string]interface{}{
   109  		"type":      string(e.Type),
   110  		"ts":        e.Timestamp,
   111  		"sessionID": e.SessionID,
   112  		"data":      e.Payload,
   113  	}
   114  
   115  	profileIDString := e.ProfileID
   116  	if profileIDString == "" {
   117  		return nil
   118  	}
   119  	connIDs, err := h.getConnIDs(profileIDString)
   120  	if err != nil {
   121  		log.Errorf("profile %q: %w", profileIDString, err)
   122  		return nil
   123  	}
   124  
   125  	for connID := range connIDs {
   126  		c, err := h.getConn(connID)
   127  		if err != nil {
   128  			h.unsubscribeConn(profileIDString, connID)
   129  			log.Errorf("connection %q, profile %q: %w", connID, profileIDString, err)
   130  			return nil
   131  		}
   132  		log.Debugf("sending event %q to websocket conns %q", e.Type, profileIDString)
   133  		if err := wsjson.Write(ctx, c.conn, evt); err != nil {
   134  			log.Errorf("connection %q: wsjson write error: %s", profileIDString, err)
   135  			return nil
   136  		}
   137  	}
   138  	return nil
   139  }
   140  
   141  // getConn gets a *conn from the map of connections
   142  func (h *connections) getConn(id string) (*conn, error) {
   143  	h.connsLock.Lock()
   144  	defer h.connsLock.Unlock()
   145  	c, ok := h.conns[id]
   146  	if !ok {
   147  		return nil, errNotFound
   148  	}
   149  	if c == nil {
   150  		return nil, errNotFound
   151  	}
   152  	return c, nil
   153  }
   154  
   155  // getConnID returns the connection ID associated with the given profile.ID string
   156  func (h *connections) getConnIDs(profileID string) (connectionSet, error) {
   157  	h.subsLock.Lock()
   158  	defer h.subsLock.Unlock()
   159  	ids, ok := h.subscriptions[profileID]
   160  	if !ok {
   161  		return nil, errNotFound
   162  	}
   163  	if ids == nil {
   164  		delete(h.subscriptions, profileID)
   165  		return nil, errNotFound
   166  	}
   167  	return ids, nil
   168  }
   169  
   170  // subscribeConn authenticates the given token and adds the connID to the map
   171  // of "subscribed" connections
   172  func (h *connections) subscribeConn(connID, tokenString string) error {
   173  	ctx := context.TODO()
   174  	tok, err := token.ParseAuthToken(ctx, tokenString, h.keystore)
   175  	if err != nil {
   176  		return err
   177  	}
   178  
   179  	claims, ok := tok.Claims.(*token.Claims)
   180  	if !ok || claims.Subject == "" {
   181  		h.removeConn(connID)
   182  		return fmt.Errorf("cannot get profile.ID from token")
   183  	}
   184  	// TODO(b5): at this point we have a valid signature of a profileID string
   185  	// but no proof that this profile is owned by the key that signed the
   186  	// token. We either need ProfileID == KeyID, or we need a UCAN. we need to
   187  	// check for those, ideally in a method within the profile package that
   188  	// abstracts over profile & key agreement
   189  
   190  	c, err := h.getConn(connID)
   191  	if err != nil {
   192  		return fmt.Errorf("connection %q: %w", connID, err)
   193  	}
   194  	c.profileID = claims.Subject
   195  
   196  	h.subsLock.Lock()
   197  	defer h.subsLock.Unlock()
   198  	connIDs, ok := h.subscriptions[claims.Subject]
   199  	if !ok || connIDs == nil {
   200  		connIDs = connectionSet{}
   201  	}
   202  	connIDs[connID] = struct{}{}
   203  	h.subscriptions[claims.Subject] = connIDs
   204  	log.Debugw("subscribeConn", "id", connID)
   205  	return nil
   206  }
   207  
   208  // unsubscribeConn remove the profileID and connID from the map of "subscribed"
   209  // connections
   210  func (h *connections) unsubscribeConn(profileID, connID string) {
   211  	connIDs, err := h.getConnIDs(profileID)
   212  	if err != nil {
   213  		return
   214  	}
   215  	for cid := range connIDs {
   216  		if connID == "" || cid == connID {
   217  			c, err := h.getConn(cid)
   218  			if err != nil || c == nil {
   219  				continue
   220  			}
   221  			c.profileID = ""
   222  		}
   223  	}
   224  
   225  	h.subsLock.Lock()
   226  	defer h.subsLock.Unlock()
   227  	if connID == "" {
   228  		delete(h.subscriptions, profileID)
   229  	} else {
   230  		if _, ok := h.subscriptions[profileID]; ok {
   231  			delete(h.subscriptions[profileID], connID)
   232  		}
   233  		if len(h.subscriptions[profileID]) == 0 {
   234  			delete(h.subscriptions, profileID)
   235  		}
   236  	}
   237  }
   238  
   239  // removeConn removes the conn from the map of connections and subscriptions
   240  // closing the connection if needed
   241  func (h *connections) removeConn(id string) {
   242  	c, err := h.getConn(id)
   243  	if err != nil {
   244  		return
   245  	}
   246  	defer func() {
   247  		c.conn.Close(websocket.StatusNormalClosure, "pruning connection")
   248  	}()
   249  	if c.profileID != "" {
   250  		h.unsubscribeConn(c.profileID, id)
   251  	}
   252  	h.connsLock.Lock()
   253  	defer h.connsLock.Unlock()
   254  	delete(h.conns, id)
   255  }
   256  
   257  // read listens to the given connection, handling any messages that come through
   258  // stops listening if it encounters any error
   259  func (h *connections) read(id string) error {
   260  	msg := &message{}
   261  
   262  	c, err := h.getConn(id)
   263  	if err != nil {
   264  		return fmt.Errorf("connection %q: %w", id, err)
   265  	}
   266  	ctx := context.Background()
   267  	for {
   268  		err = wsjson.Read(ctx, c.conn, msg)
   269  		if err != nil {
   270  			// all websocket methods that return w/ failure are closed
   271  			// we must prune the closed connection
   272  			h.removeConn(id)
   273  			return err
   274  		}
   275  		h.handleMessage(ctx, c, msg)
   276  	}
   277  }
   278  
   279  // handleMessage handles each message based on msgType
   280  func (h *connections) handleMessage(ctx context.Context, c *conn, msg *message) {
   281  	switch msg.Type {
   282  	case subscribeRequest:
   283  		subMsg := &subscribeMessage{}
   284  		if err := json.Unmarshal(msg.Payload, subMsg); err != nil {
   285  			log.Debugw("websocket unmarshal", "error", err, "connection id", c.id, "msg", msg)
   286  			h.write(ctx, c, &message{Type: subscribeFailure, Error: err})
   287  			return
   288  		}
   289  		if err := h.subscribeConn(c.id, subMsg.Token); err != nil {
   290  			log.Debugw("subscribeConn", "error", err, "connection id", c.id, "msg", msg)
   291  			h.write(ctx, c, &message{Type: subscribeFailure, Error: err})
   292  			return
   293  		}
   294  		h.write(ctx, c, &message{Type: subscribeSuccess})
   295  	case unsubscribeRequest:
   296  		h.unsubscribeConn(c.profileID, c.id)
   297  	default:
   298  		log.Debug("unknown message type over websocket %s: %q", c.id, msg.Type)
   299  	}
   300  }
   301  
   302  // write sends a json message over the connection
   303  func (h *connections) write(ctx context.Context, c *conn, msg *message) {
   304  	log.Debugf("sending message %q to websocket conns %q", msg.Type, c.id)
   305  	if err := wsjson.Write(ctx, c.conn, msg); err != nil {
   306  		log.Errorf("connection %q: wsjson write error: %s", c.id, err)
   307  		// the connection will close if there is any `write` error
   308  		// we must remove it from our own stores, so as not to hold
   309  		// onto any dead connections
   310  		h.removeConn(c.id)
   311  	}
   312  }
   313  
// msgType is the type of message that we send & receive on the websocket
// connection. It is the value of the "type" field of a message
type msgType string
   316  
// message types understood or emitted by handleMessage
const (
	// subscribeRequest indicates the connection is trying to become
	// an authenticated connection
	// payload is a `subscribeMessage`
	subscribeRequest = msgType("subscribe:request")
	// subscribeSuccess indicates that the connection successfully
	// upgraded to an authenticated connection
	// payload is nil
	subscribeSuccess = msgType("subscribe:success")
	// subscribeFailure indicates that the connection did not
	// upgrade to an authenticated connection
	// payload is nil, the cause is carried in the message's Error field
	subscribeFailure = msgType("subscribe:failure")
	// unsubscribeRequest indicates the connection no longer wants
	// to be authenticated
	// payload is nil
	unsubscribeRequest = msgType("unsubscribe:request")
)
   335  
   336  // message is the expected structure of an incoming websocket message
   337  type message struct {
   338  	Type    msgType         `json:"type"`
   339  	Payload json.RawMessage `json:"payload"`
   340  	Error   error           `json:"error"`
   341  }
   342  
// subscribeMessage is the expected structure of the payload of an incoming
// "subscribe" message
type subscribeMessage struct {
	// Token is the auth token string that subscribeConn parses to determine
	// the profile to subscribe the connection under
	Token string `json:"token"`
}