github.com/philippseith/signalr@v0.6.3/httpmux.go (about) 1 package signalr 2 3 import ( 4 "crypto/rand" 5 "encoding/base64" 6 "encoding/json" 7 "fmt" 8 "net/http" 9 "strconv" 10 "strings" 11 "sync" 12 "time" 13 14 "github.com/teivah/onecontext" 15 "nhooyr.io/websocket" 16 ) 17 18 type httpMux struct { 19 mx sync.RWMutex 20 connectionMap map[string]Connection 21 server Server 22 } 23 24 func newHTTPMux(server Server) *httpMux { 25 return &httpMux{ 26 connectionMap: make(map[string]Connection), 27 server: server, 28 } 29 } 30 31 func (h *httpMux) ServeHTTP(writer http.ResponseWriter, request *http.Request) { 32 switch request.Method { 33 case "POST": 34 h.handlePost(writer, request) 35 case "GET": 36 h.handleGet(writer, request) 37 default: 38 writer.WriteHeader(http.StatusBadRequest) 39 } 40 } 41 42 func (h *httpMux) handlePost(writer http.ResponseWriter, request *http.Request) { 43 connectionID := request.URL.Query().Get("id") 44 if connectionID == "" { 45 writer.WriteHeader(http.StatusBadRequest) 46 return 47 } 48 info, _ := h.server.prefixLoggers("") 49 for { 50 h.mx.RLock() 51 c, ok := h.connectionMap[connectionID] 52 h.mx.RUnlock() 53 if ok { 54 // Connection is initiated 55 switch conn := c.(type) { 56 case *serverSSEConnection: 57 writer.WriteHeader(conn.consumeRequest(request)) 58 return 59 case *negotiateConnection: 60 // connection start initiated but not completed 61 default: 62 // ConnectionID already used for WebSocket(?) 63 writer.WriteHeader(http.StatusConflict) 64 return 65 } 66 } else { 67 writer.WriteHeader(http.StatusNotFound) 68 return 69 } 70 <-time.After(10 * time.Millisecond) 71 _ = info.Log("event", "handlePost for SSE connection repeated") 72 } 73 } 74 75 func (h *httpMux) handleGet(writer http.ResponseWriter, request *http.Request) { 76 upgrade := false 77 for _, connHead := range strings.Split(request.Header.Get("Connection"), ",") { 78 if strings.ToLower(strings.TrimSpace(connHead)) == "upgrade" { 79 upgrade = true 80 break 81 } 82 } 83 if upgrade && 84 strings.ToLower(request.Header.Get("Upgrade")) == "websocket" { 85 h.handleWebsocket(writer, request) 86 } else if strings.ToLower(request.Header.Get("Accept")) == "text/event-stream" { 87 h.handleServerSentEvent(writer, request) 88 } else { 89 writer.WriteHeader(http.StatusBadRequest) 90 } 91 } 92 93 func (h *httpMux) handleServerSentEvent(writer http.ResponseWriter, request *http.Request) { 94 connectionID := request.URL.Query().Get("id") 95 if connectionID == "" { 96 writer.WriteHeader(http.StatusBadRequest) 97 return 98 } 99 h.mx.RLock() 100 c, ok := h.connectionMap[connectionID] 101 h.mx.RUnlock() 102 if ok { 103 if _, ok := c.(*negotiateConnection); ok { 104 ctx, _ := onecontext.Merge(h.server.context(), request.Context()) 105 sseConn, jobChan, jobResultChan, err := newServerSSEConnection(ctx, c.ConnectionID()) 106 if err != nil { 107 writer.WriteHeader(http.StatusInternalServerError) 108 return 109 } 110 flusher, ok := writer.(http.Flusher) 111 if !ok { 112 writer.WriteHeader(http.StatusInternalServerError) 113 return 114 } 115 // Connection is negotiated but not initiated 116 // We compose http and send it over sse 117 writer.Header().Set("Content-Type", "text/event-stream") 118 writer.Header().Set("Connection", "keep-alive") 119 writer.Header().Set("Cache-Control", "no-cache") 120 writer.WriteHeader(http.StatusOK) 121 // End this Server Sent Event (yes, your response now is one and the client will wait for this initial event to end) 122 _, _ = fmt.Fprint(writer, ":\r\n\r\n") 123 writer.(http.Flusher).Flush() 124 go func() { 125 // We can't WriteHeader 500 if we get an error as we already wrote the header, so ignore it. 126 _ = h.serveConnection(sseConn) 127 }() 128 // Loop for write jobs from the sseServerConnection 129 for buf := range jobChan { 130 n, err := writer.Write(buf) 131 if err == nil { 132 flusher.Flush() 133 } 134 jobResultChan <- RWJobResult{n: n, err: err} 135 } 136 close(jobResultChan) 137 } else { 138 // connectionID in use 139 writer.WriteHeader(http.StatusConflict) 140 } 141 } else { 142 writer.WriteHeader(http.StatusNotFound) 143 } 144 } 145 146 func (h *httpMux) handleWebsocket(writer http.ResponseWriter, request *http.Request) { 147 accOptions := &websocket.AcceptOptions{ 148 CompressionMode: websocket.CompressionContextTakeover, 149 InsecureSkipVerify: h.server.insecureSkipVerify(), 150 OriginPatterns: h.server.originPatterns(), 151 } 152 websocketConn, err := websocket.Accept(writer, request, accOptions) 153 if err != nil { 154 _, debug := h.server.loggers() 155 _ = debug.Log(evt, "handleWebsocket", msg, "error accepting websockets", "error", err) 156 // don't need to write an error header here as websocket.Accept has already used http.Error 157 return 158 } 159 websocketConn.SetReadLimit(int64(h.server.maximumReceiveMessageSize())) 160 connectionMapKey := request.URL.Query().Get("id") 161 if connectionMapKey == "" { 162 // Support websocket connection without negotiate 163 connectionMapKey = newConnectionID() 164 h.mx.Lock() 165 h.connectionMap[connectionMapKey] = &negotiateConnection{ 166 ConnectionBase{connectionID: connectionMapKey}, 167 } 168 h.mx.Unlock() 169 } 170 h.mx.RLock() 171 c, ok := h.connectionMap[connectionMapKey] 172 h.mx.RUnlock() 173 if ok { 174 if _, ok := c.(*negotiateConnection); ok { 175 // Connection is negotiated but not initiated 176 ctx, _ := onecontext.Merge(h.server.context(), request.Context()) 177 err = h.serveConnection(newWebSocketConnection(ctx, c.ConnectionID(), websocketConn)) 178 if err != nil { 179 _ = websocketConn.Close(1005, err.Error()) 180 } 181 } else { 182 // Already initiated 183 _ = websocketConn.Close(1002, "Bad request") 184 } 185 } else { 186 // Not negotiated 187 _ = websocketConn.Close(1002, "Not found") 188 } 189 } 190 191 func (h *httpMux) negotiate(w http.ResponseWriter, req *http.Request) { 192 if req.Method != "POST" { 193 w.WriteHeader(http.StatusBadRequest) 194 } else { 195 connectionID := newConnectionID() 196 connectionMapKey := connectionID 197 negotiateVersion, err := strconv.Atoi(req.Header.Get("negotiateVersion")) 198 if err != nil { 199 negotiateVersion = 0 200 } 201 connectionToken := "" 202 if negotiateVersion == 1 { 203 connectionToken = newConnectionID() 204 connectionMapKey = connectionToken 205 } 206 h.mx.Lock() 207 h.connectionMap[connectionMapKey] = &negotiateConnection{ 208 ConnectionBase{connectionID: connectionID}, 209 } 210 h.mx.Unlock() 211 var availableTransports []availableTransport 212 for _, transport := range h.server.availableTransports() { 213 switch transport { 214 case TransportServerSentEvents: 215 availableTransports = append(availableTransports, 216 availableTransport{ 217 Transport: string(TransportServerSentEvents), 218 TransferFormats: []string{string(TransferFormatText)}, 219 }) 220 case TransportWebSockets: 221 availableTransports = append(availableTransports, 222 availableTransport{ 223 Transport: string(TransportWebSockets), 224 TransferFormats: []string{string(TransferFormatText), string(TransferFormatBinary)}, 225 }) 226 } 227 } 228 response := negotiateResponse{ 229 ConnectionToken: connectionToken, 230 ConnectionID: connectionID, 231 NegotiateVersion: negotiateVersion, 232 AvailableTransports: availableTransports, 233 } 234 235 w.WriteHeader(http.StatusOK) 236 _ = json.NewEncoder(w).Encode(response) // Can't imagine an error when encoding 237 } 238 } 239 240 func (h *httpMux) serveConnection(c Connection) error { 241 h.mx.Lock() 242 h.connectionMap[c.ConnectionID()] = c 243 h.mx.Unlock() 244 return h.server.Serve(c) 245 } 246 247 func newConnectionID() string { 248 bytes := make([]byte, 16) 249 // rand.Read only fails when the systems random number generator fails. Rare case, ignore 250 _, _ = rand.Read(bytes) 251 // Important: Use URLEncoding. StdEncoding contains "/" which will be randomly part of the connectionID and cause parsing problems 252 return base64.URLEncoding.EncodeToString(bytes) 253 } 254 255 type negotiateConnection struct { 256 ConnectionBase 257 } 258 259 func (n *negotiateConnection) Read([]byte) (int, error) { 260 return 0, nil 261 } 262 263 func (n *negotiateConnection) Write([]byte) (int, error) { 264 return 0, nil 265 }