github.com/qri-io/qri@v0.10.1-0.20220104210721-c771715036cb/lib/websocket/websocket.go (about) 1 package websocket 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "io" 8 "net/http" 9 "sync" 10 11 "github.com/google/uuid" 12 golog "github.com/ipfs/go-log" 13 "github.com/qri-io/qri/auth/key" 14 "github.com/qri-io/qri/auth/token" 15 "github.com/qri-io/qri/event" 16 "nhooyr.io/websocket" 17 "nhooyr.io/websocket/wsjson" 18 ) 19 20 const qriWebsocketProtocol = "qri-websocket" 21 22 var ( 23 errNotFound = fmt.Errorf("connection not found") 24 25 log = golog.Logger("websocket") 26 ) 27 28 // newID returns a new websocket connection ID 29 func newID() string { 30 return uuid.New().String() 31 } 32 33 // setIDRand sets the random reader that NewID uses as a source of random bytes 34 // passing in nil will default to crypto.Rand. This can be used to make ID 35 // generation deterministic for tests. eg: 36 // myString := "SomeRandomStringThatIsLong-SoYouCanCallItAsMuchAsNeeded..." 37 // lib.SetIDRand(strings.NewReader(myString)) 38 // a := NewID() 39 // lib.SetIDRand(strings.NewReader(myString)) 40 // b := NewID() 41 func setIDRand(r io.Reader) { 42 uuid.SetRand(r) 43 } 44 45 // Handler defines the handler interface 46 type Handler interface { 47 ConnectionHandler(w http.ResponseWriter, r *http.Request) 48 } 49 50 type connectionSet map[string]struct{} 51 52 // connections maintains the set of active websocket connections & associated 53 // connection metadata 54 type connections struct { 55 conns map[string]*conn 56 connsLock sync.Mutex 57 keystore key.Store 58 subscriptions map[string]connectionSet 59 subsLock sync.Mutex 60 } 61 62 type conn struct { 63 id string 64 profileID string 65 conn *websocket.Conn 66 } 67 68 var _ Handler = (*connections)(nil) 69 70 // NewHandler creates a new connections instance that clients 71 // can connect to in order to get realtime events 72 func NewHandler(ctx context.Context, bus event.Bus, keystore key.Store) (Handler, error) { 73 ws := &connections{ 74 conns: map[string]*conn{}, 75 connsLock: sync.Mutex{}, 76 keystore: keystore, 77 subscriptions: map[string]connectionSet{}, 78 subsLock: sync.Mutex{}, 79 } 80 81 bus.SubscribeAll(ws.messageHandler) 82 return ws, nil 83 } 84 85 // ConnectionHandler handles websocket upgrade requests and accepts the connection 86 func (h *connections) ConnectionHandler(w http.ResponseWriter, r *http.Request) { 87 wsc, err := websocket.Accept(w, r, &websocket.AcceptOptions{ 88 Subprotocols: []string{qriWebsocketProtocol}, 89 InsecureSkipVerify: true, 90 }) 91 if err != nil { 92 log.Debugf("Websocket accept error: %s", err) 93 return 94 } 95 id := newID() 96 c := &conn{ 97 id: id, 98 conn: wsc, 99 } 100 h.connsLock.Lock() 101 defer h.connsLock.Unlock() 102 h.conns[id] = c 103 go h.read(id) 104 } 105 106 func (h *connections) messageHandler(_ context.Context, e event.Event) error { 107 ctx := context.Background() 108 evt := map[string]interface{}{ 109 "type": string(e.Type), 110 "ts": e.Timestamp, 111 "sessionID": e.SessionID, 112 "data": e.Payload, 113 } 114 115 profileIDString := e.ProfileID 116 if profileIDString == "" { 117 return nil 118 } 119 connIDs, err := h.getConnIDs(profileIDString) 120 if err != nil { 121 log.Errorf("profile %q: %w", profileIDString, err) 122 return nil 123 } 124 125 for connID := range connIDs { 126 c, err := h.getConn(connID) 127 if err != nil { 128 h.unsubscribeConn(profileIDString, connID) 129 log.Errorf("connection %q, profile %q: %w", connID, profileIDString, err) 130 return nil 131 } 132 log.Debugf("sending event %q to websocket conns %q", e.Type, profileIDString) 133 if err := wsjson.Write(ctx, c.conn, evt); err != nil { 134 log.Errorf("connection %q: wsjson write error: %s", profileIDString, err) 135 return nil 136 } 137 } 138 return nil 139 } 140 141 // getConn gets a *conn from the map of connections 142 func (h *connections) getConn(id string) (*conn, error) { 143 h.connsLock.Lock() 144 defer h.connsLock.Unlock() 145 c, ok := h.conns[id] 146 if !ok { 147 return nil, errNotFound 148 } 149 if c == nil { 150 return nil, errNotFound 151 } 152 return c, nil 153 } 154 155 // getConnID returns the connection ID associated with the given profile.ID string 156 func (h *connections) getConnIDs(profileID string) (connectionSet, error) { 157 h.subsLock.Lock() 158 defer h.subsLock.Unlock() 159 ids, ok := h.subscriptions[profileID] 160 if !ok { 161 return nil, errNotFound 162 } 163 if ids == nil { 164 delete(h.subscriptions, profileID) 165 return nil, errNotFound 166 } 167 return ids, nil 168 } 169 170 // subscribeConn authenticates the given token and adds the connID to the map 171 // of "subscribed" connections 172 func (h *connections) subscribeConn(connID, tokenString string) error { 173 ctx := context.TODO() 174 tok, err := token.ParseAuthToken(ctx, tokenString, h.keystore) 175 if err != nil { 176 return err 177 } 178 179 claims, ok := tok.Claims.(*token.Claims) 180 if !ok || claims.Subject == "" { 181 h.removeConn(connID) 182 return fmt.Errorf("cannot get profile.ID from token") 183 } 184 // TODO(b5): at this point we have a valid signature of a profileID string 185 // but no proof that this profile is owned by the key that signed the 186 // token. We either need ProfileID == KeyID, or we need a UCAN. we need to 187 // check for those, ideally in a method within the profile package that 188 // abstracts over profile & key agreement 189 190 c, err := h.getConn(connID) 191 if err != nil { 192 return fmt.Errorf("connection %q: %w", connID, err) 193 } 194 c.profileID = claims.Subject 195 196 h.subsLock.Lock() 197 defer h.subsLock.Unlock() 198 connIDs, ok := h.subscriptions[claims.Subject] 199 if !ok || connIDs == nil { 200 connIDs = connectionSet{} 201 } 202 connIDs[connID] = struct{}{} 203 h.subscriptions[claims.Subject] = connIDs 204 log.Debugw("subscribeConn", "id", connID) 205 return nil 206 } 207 208 // unsubscribeConn remove the profileID and connID from the map of "subscribed" 209 // connections 210 func (h *connections) unsubscribeConn(profileID, connID string) { 211 connIDs, err := h.getConnIDs(profileID) 212 if err != nil { 213 return 214 } 215 for cid := range connIDs { 216 if connID == "" || cid == connID { 217 c, err := h.getConn(cid) 218 if err != nil || c == nil { 219 continue 220 } 221 c.profileID = "" 222 } 223 } 224 225 h.subsLock.Lock() 226 defer h.subsLock.Unlock() 227 if connID == "" { 228 delete(h.subscriptions, profileID) 229 } else { 230 if _, ok := h.subscriptions[profileID]; ok { 231 delete(h.subscriptions[profileID], connID) 232 } 233 if len(h.subscriptions[profileID]) == 0 { 234 delete(h.subscriptions, profileID) 235 } 236 } 237 } 238 239 // removeConn removes the conn from the map of connections and subscriptions 240 // closing the connection if needed 241 func (h *connections) removeConn(id string) { 242 c, err := h.getConn(id) 243 if err != nil { 244 return 245 } 246 defer func() { 247 c.conn.Close(websocket.StatusNormalClosure, "pruning connection") 248 }() 249 if c.profileID != "" { 250 h.unsubscribeConn(c.profileID, id) 251 } 252 h.connsLock.Lock() 253 defer h.connsLock.Unlock() 254 delete(h.conns, id) 255 } 256 257 // read listens to the given connection, handling any messages that come through 258 // stops listening if it encounters any error 259 func (h *connections) read(id string) error { 260 msg := &message{} 261 262 c, err := h.getConn(id) 263 if err != nil { 264 return fmt.Errorf("connection %q: %w", id, err) 265 } 266 ctx := context.Background() 267 for { 268 err = wsjson.Read(ctx, c.conn, msg) 269 if err != nil { 270 // all websocket methods that return w/ failure are closed 271 // we must prune the closed connection 272 h.removeConn(id) 273 return err 274 } 275 h.handleMessage(ctx, c, msg) 276 } 277 } 278 279 // handleMessage handles each message based on msgType 280 func (h *connections) handleMessage(ctx context.Context, c *conn, msg *message) { 281 switch msg.Type { 282 case subscribeRequest: 283 subMsg := &subscribeMessage{} 284 if err := json.Unmarshal(msg.Payload, subMsg); err != nil { 285 log.Debugw("websocket unmarshal", "error", err, "connection id", c.id, "msg", msg) 286 h.write(ctx, c, &message{Type: subscribeFailure, Error: err}) 287 return 288 } 289 if err := h.subscribeConn(c.id, subMsg.Token); err != nil { 290 log.Debugw("subscribeConn", "error", err, "connection id", c.id, "msg", msg) 291 h.write(ctx, c, &message{Type: subscribeFailure, Error: err}) 292 return 293 } 294 h.write(ctx, c, &message{Type: subscribeSuccess}) 295 case unsubscribeRequest: 296 h.unsubscribeConn(c.profileID, c.id) 297 default: 298 log.Debug("unknown message type over websocket %s: %q", c.id, msg.Type) 299 } 300 } 301 302 // write sends a json message over the connection 303 func (h *connections) write(ctx context.Context, c *conn, msg *message) { 304 log.Debugf("sending message %q to websocket conns %q", msg.Type, c.id) 305 if err := wsjson.Write(ctx, c.conn, msg); err != nil { 306 log.Errorf("connection %q: wsjson write error: %s", c.id, err) 307 // the connection will close if there is any `write` error 308 // we must remove it from our own stores, so as not to hold 309 // onto any dead connections 310 h.removeConn(c.id) 311 } 312 } 313 314 // msgType is the type of message that we receive on the 315 type msgType string 316 317 const ( 318 // subscribeRequest indicates the connection is trying to become 319 // an authenticated connection 320 // payload is a `subscribeMessage` 321 subscribeRequest = msgType("subscribe:request") 322 // subscribeSuccess indicates that the connection successfully 323 // upgraded to an authenticated connection 324 // payload is nil 325 subscribeSuccess = msgType("subscribe:success") 326 // subscribeFailure indicates that the connection did not 327 // upgrade to an authenticated connection 328 // payload is nil 329 subscribeFailure = msgType("subscribe:failure") 330 // unsubscribeRequest indicates the connection no longer wants 331 // to be authenticated 332 // payload is nil 333 unsubscribeRequest = msgType("unsubscribe:request") 334 ) 335 336 // message is the expected structure of an incoming websocket message 337 type message struct { 338 Type msgType `json:"type"` 339 Payload json.RawMessage `json:"payload"` 340 Error error `json:"error"` 341 } 342 343 // subscribeMessage is the expected structure of an incoming "subscribe" 344 // message 345 type subscribeMessage struct { 346 Token string `json:"token"` 347 }