github.com/philippseith/signalr@v0.6.3/httpmux.go (about)

     1  package signalr
     2  
     3  import (
     4  	"crypto/rand"
     5  	"encoding/base64"
     6  	"encoding/json"
     7  	"fmt"
     8  	"net/http"
     9  	"strconv"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/teivah/onecontext"
    15  	"nhooyr.io/websocket"
    16  )
    17  
// httpMux is the HTTP entry point of the package. It implements
// http.Handler (see ServeHTTP) and keeps a table of connections keyed by the
// id handed out during negotiate.
//
// mx guards connectionMap; server is consulted for loggers, contexts,
// transport options and for actually serving a connection.
type httpMux struct {
	mx            sync.RWMutex
	connectionMap map[string]Connection
	server        Server
}
    23  
    24  func newHTTPMux(server Server) *httpMux {
    25  	return &httpMux{
    26  		connectionMap: make(map[string]Connection),
    27  		server:        server,
    28  	}
    29  }
    30  
    31  func (h *httpMux) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    32  	switch request.Method {
    33  	case "POST":
    34  		h.handlePost(writer, request)
    35  	case "GET":
    36  		h.handleGet(writer, request)
    37  	default:
    38  		writer.WriteHeader(http.StatusBadRequest)
    39  	}
    40  }
    41  
    42  func (h *httpMux) handlePost(writer http.ResponseWriter, request *http.Request) {
    43  	connectionID := request.URL.Query().Get("id")
    44  	if connectionID == "" {
    45  		writer.WriteHeader(http.StatusBadRequest)
    46  		return
    47  	}
    48  	info, _ := h.server.prefixLoggers("")
    49  	for {
    50  		h.mx.RLock()
    51  		c, ok := h.connectionMap[connectionID]
    52  		h.mx.RUnlock()
    53  		if ok {
    54  			// Connection is initiated
    55  			switch conn := c.(type) {
    56  			case *serverSSEConnection:
    57  				writer.WriteHeader(conn.consumeRequest(request))
    58  				return
    59  			case *negotiateConnection:
    60  				// connection start initiated but not completed
    61  			default:
    62  				// ConnectionID already used for WebSocket(?)
    63  				writer.WriteHeader(http.StatusConflict)
    64  				return
    65  			}
    66  		} else {
    67  			writer.WriteHeader(http.StatusNotFound)
    68  			return
    69  		}
    70  		<-time.After(10 * time.Millisecond)
    71  		_ = info.Log("event", "handlePost for SSE connection repeated")
    72  	}
    73  }
    74  
    75  func (h *httpMux) handleGet(writer http.ResponseWriter, request *http.Request) {
    76  	upgrade := false
    77  	for _, connHead := range strings.Split(request.Header.Get("Connection"), ",") {
    78  		if strings.ToLower(strings.TrimSpace(connHead)) == "upgrade" {
    79  			upgrade = true
    80  			break
    81  		}
    82  	}
    83  	if upgrade &&
    84  		strings.ToLower(request.Header.Get("Upgrade")) == "websocket" {
    85  		h.handleWebsocket(writer, request)
    86  	} else if strings.ToLower(request.Header.Get("Accept")) == "text/event-stream" {
    87  		h.handleServerSentEvent(writer, request)
    88  	} else {
    89  		writer.WriteHeader(http.StatusBadRequest)
    90  	}
    91  }
    92  
    93  func (h *httpMux) handleServerSentEvent(writer http.ResponseWriter, request *http.Request) {
    94  	connectionID := request.URL.Query().Get("id")
    95  	if connectionID == "" {
    96  		writer.WriteHeader(http.StatusBadRequest)
    97  		return
    98  	}
    99  	h.mx.RLock()
   100  	c, ok := h.connectionMap[connectionID]
   101  	h.mx.RUnlock()
   102  	if ok {
   103  		if _, ok := c.(*negotiateConnection); ok {
   104  			ctx, _ := onecontext.Merge(h.server.context(), request.Context())
   105  			sseConn, jobChan, jobResultChan, err := newServerSSEConnection(ctx, c.ConnectionID())
   106  			if err != nil {
   107  				writer.WriteHeader(http.StatusInternalServerError)
   108  				return
   109  			}
   110  			flusher, ok := writer.(http.Flusher)
   111  			if !ok {
   112  				writer.WriteHeader(http.StatusInternalServerError)
   113  				return
   114  			}
   115  			// Connection is negotiated but not initiated
   116  			// We compose http and send it over sse
   117  			writer.Header().Set("Content-Type", "text/event-stream")
   118  			writer.Header().Set("Connection", "keep-alive")
   119  			writer.Header().Set("Cache-Control", "no-cache")
   120  			writer.WriteHeader(http.StatusOK)
   121  			// End this Server Sent Event (yes, your response now is one and the client will wait for this initial event to end)
   122  			_, _ = fmt.Fprint(writer, ":\r\n\r\n")
   123  			writer.(http.Flusher).Flush()
   124  			go func() {
   125  				// We can't WriteHeader 500 if we get an error as we already wrote the header, so ignore it.
   126  				_ = h.serveConnection(sseConn)
   127  			}()
   128  			// Loop for write jobs from the sseServerConnection
   129  			for buf := range jobChan {
   130  				n, err := writer.Write(buf)
   131  				if err == nil {
   132  					flusher.Flush()
   133  				}
   134  				jobResultChan <- RWJobResult{n: n, err: err}
   135  			}
   136  			close(jobResultChan)
   137  		} else {
   138  			// connectionID in use
   139  			writer.WriteHeader(http.StatusConflict)
   140  		}
   141  	} else {
   142  		writer.WriteHeader(http.StatusNotFound)
   143  	}
   144  }
   145  
   146  func (h *httpMux) handleWebsocket(writer http.ResponseWriter, request *http.Request) {
   147  	accOptions := &websocket.AcceptOptions{
   148  		CompressionMode:    websocket.CompressionContextTakeover,
   149  		InsecureSkipVerify: h.server.insecureSkipVerify(),
   150  		OriginPatterns:     h.server.originPatterns(),
   151  	}
   152  	websocketConn, err := websocket.Accept(writer, request, accOptions)
   153  	if err != nil {
   154  		_, debug := h.server.loggers()
   155  		_ = debug.Log(evt, "handleWebsocket", msg, "error accepting websockets", "error", err)
   156  		// don't need to write an error header here as websocket.Accept has already used http.Error
   157  		return
   158  	}
   159  	websocketConn.SetReadLimit(int64(h.server.maximumReceiveMessageSize()))
   160  	connectionMapKey := request.URL.Query().Get("id")
   161  	if connectionMapKey == "" {
   162  		// Support websocket connection without negotiate
   163  		connectionMapKey = newConnectionID()
   164  		h.mx.Lock()
   165  		h.connectionMap[connectionMapKey] = &negotiateConnection{
   166  			ConnectionBase{connectionID: connectionMapKey},
   167  		}
   168  		h.mx.Unlock()
   169  	}
   170  	h.mx.RLock()
   171  	c, ok := h.connectionMap[connectionMapKey]
   172  	h.mx.RUnlock()
   173  	if ok {
   174  		if _, ok := c.(*negotiateConnection); ok {
   175  			// Connection is negotiated but not initiated
   176  			ctx, _ := onecontext.Merge(h.server.context(), request.Context())
   177  			err = h.serveConnection(newWebSocketConnection(ctx, c.ConnectionID(), websocketConn))
   178  			if err != nil {
   179  				_ = websocketConn.Close(1005, err.Error())
   180  			}
   181  		} else {
   182  			// Already initiated
   183  			_ = websocketConn.Close(1002, "Bad request")
   184  		}
   185  	} else {
   186  		// Not negotiated
   187  		_ = websocketConn.Close(1002, "Not found")
   188  	}
   189  }
   190  
   191  func (h *httpMux) negotiate(w http.ResponseWriter, req *http.Request) {
   192  	if req.Method != "POST" {
   193  		w.WriteHeader(http.StatusBadRequest)
   194  	} else {
   195  		connectionID := newConnectionID()
   196  		connectionMapKey := connectionID
   197  		negotiateVersion, err := strconv.Atoi(req.Header.Get("negotiateVersion"))
   198  		if err != nil {
   199  			negotiateVersion = 0
   200  		}
   201  		connectionToken := ""
   202  		if negotiateVersion == 1 {
   203  			connectionToken = newConnectionID()
   204  			connectionMapKey = connectionToken
   205  		}
   206  		h.mx.Lock()
   207  		h.connectionMap[connectionMapKey] = &negotiateConnection{
   208  			ConnectionBase{connectionID: connectionID},
   209  		}
   210  		h.mx.Unlock()
   211  		var availableTransports []availableTransport
   212  		for _, transport := range h.server.availableTransports() {
   213  			switch transport {
   214  			case TransportServerSentEvents:
   215  				availableTransports = append(availableTransports,
   216  					availableTransport{
   217  						Transport:       string(TransportServerSentEvents),
   218  						TransferFormats: []string{string(TransferFormatText)},
   219  					})
   220  			case TransportWebSockets:
   221  				availableTransports = append(availableTransports,
   222  					availableTransport{
   223  						Transport:       string(TransportWebSockets),
   224  						TransferFormats: []string{string(TransferFormatText), string(TransferFormatBinary)},
   225  					})
   226  			}
   227  		}
   228  		response := negotiateResponse{
   229  			ConnectionToken:     connectionToken,
   230  			ConnectionID:        connectionID,
   231  			NegotiateVersion:    negotiateVersion,
   232  			AvailableTransports: availableTransports,
   233  		}
   234  
   235  		w.WriteHeader(http.StatusOK)
   236  		_ = json.NewEncoder(w).Encode(response) // Can't imagine an error when encoding
   237  	}
   238  }
   239  
   240  func (h *httpMux) serveConnection(c Connection) error {
   241  	h.mx.Lock()
   242  	h.connectionMap[c.ConnectionID()] = c
   243  	h.mx.Unlock()
   244  	return h.server.Serve(c)
   245  }
   246  
   247  func newConnectionID() string {
   248  	bytes := make([]byte, 16)
   249  	// rand.Read only fails when the systems random number generator fails. Rare case, ignore
   250  	_, _ = rand.Read(bytes)
   251  	// Important: Use URLEncoding. StdEncoding contains "/" which will be randomly part of the connectionID and cause parsing problems
   252  	return base64.URLEncoding.EncodeToString(bytes)
   253  }
   254  
   255  type negotiateConnection struct {
   256  	ConnectionBase
   257  }
   258  
   259  func (n *negotiateConnection) Read([]byte) (int, error) {
   260  	return 0, nil
   261  }
   262  
   263  func (n *negotiateConnection) Write([]byte) (int, error) {
   264  	return 0, nil
   265  }