github.com/anycable/anycable-go@v1.5.1/sse/handler.go (about)

     1  package sse
     2  
     3  import (
     4  	"context"
     5  	"log/slog"
     6  	"net/http"
     7  	"strings"
     8  
     9  	"github.com/anycable/anycable-go/common"
    10  	"github.com/anycable/anycable-go/node"
    11  	"github.com/anycable/anycable-go/server"
    12  	"github.com/anycable/anycable-go/version"
    13  	"github.com/anycable/anycable-go/ws"
    14  )
    15  
    16  // SSEHandler generates a new http handler for SSE connections
    17  func SSEHandler(n *node.Node, shutdownCtx context.Context, headersExtractor server.HeadersExtractor, config *Config, l *slog.Logger) http.Handler {
    18  	var allowedHosts []string
    19  
    20  	if config.AllowedOrigins == "" {
    21  		allowedHosts = []string{}
    22  	} else {
    23  		allowedHosts = strings.Split(config.AllowedOrigins, ",")
    24  	}
    25  
    26  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    27  		// Write CORS headers
    28  		server.WriteCORSHeaders(w, r, allowedHosts)
    29  
    30  		// Respond to preflight requests
    31  		if r.Method == http.MethodOptions {
    32  			w.WriteHeader(http.StatusOK)
    33  			return
    34  		}
    35  
    36  		// SSE only supports GET and POST requests
    37  		if r.Method != http.MethodGet && r.Method != http.MethodPost {
    38  			w.WriteHeader(http.StatusMethodNotAllowed)
    39  			return
    40  		}
    41  
    42  		// Prepare common headers
    43  		w.Header().Set("X-AnyCable-Version", version.Version())
    44  		if r.ProtoMajor == 1 {
    45  			// An endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields.
    46  			// Source: RFC7540.
    47  			w.Header().Set("Connection", "keep-alive")
    48  		}
    49  		w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
    50  		w.Header().Set("X-Content-Type-Options", "nosniff")
    51  		w.Header().Set("X-Accel-Buffering", "no")
    52  		w.Header().Set("Cache-Control", "private, no-cache, no-store, must-revalidate, max-age=0") // HTTP 1.1
    53  		w.Header().Set("Pragma", "no-cache")                                                       // HTTP 1.0
    54  		w.Header().Set("Expire", "0")
    55  
    56  		flusher, ok := w.(http.Flusher)
    57  		if !ok {
    58  			w.WriteHeader(http.StatusNotImplemented)
    59  			return
    60  		}
    61  
    62  		info, err := server.NewRequestInfo(r, headersExtractor)
    63  		if err != nil {
    64  			w.WriteHeader(http.StatusBadRequest)
    65  			return
    66  		}
    67  
    68  		sessionCtx := l.With("sid", info.UID).With("transport", "sse")
    69  
    70  		subscribeCmds, err := subscribeCommandsFromRequest(r)
    71  
    72  		if err != nil {
    73  			sessionCtx.Error("failed to build subscribe command", "error", err)
    74  			w.WriteHeader(http.StatusBadRequest)
    75  			return
    76  		}
    77  
    78  		// Finally, we can establish a session
    79  		session, err := NewSSESession(n, w, r, info)
    80  
    81  		if err != nil {
    82  			sessionCtx.Error("failed to establish sesssion", "error", err)
    83  			w.WriteHeader(http.StatusBadRequest)
    84  			return
    85  		}
    86  
    87  		if session == nil {
    88  			sessionCtx.Error("authentication failed")
    89  			w.WriteHeader(http.StatusUnauthorized)
    90  			return
    91  		}
    92  
    93  		// Make sure we remove the session from the node when we're done (especially if we return earlier due to rejected subscription)
    94  		defer session.Disconnect("Closed", ws.CloseNormalClosure)
    95  
    96  		conn := session.UnderlyingConn().(*Connection)
    97  
    98  		for _, subscribeCmd := range subscribeCmds {
    99  			// Subscribe to the channel
   100  			res, err := n.Subscribe(session, subscribeCmd)
   101  
   102  			if err != nil || res == nil {
   103  				sessionCtx.Error("failed to subscribe", "error", err)
   104  				w.WriteHeader(http.StatusBadRequest)
   105  				return
   106  			}
   107  
   108  			// Subscription rejected
   109  			if res.Status != common.SUCCESS {
   110  				sessionCtx.Debug("subscription rejected")
   111  				w.WriteHeader(http.StatusBadRequest)
   112  				return
   113  			}
   114  		}
   115  
   116  		w.WriteHeader(http.StatusOK)
   117  		flusher.Flush()
   118  
   119  		conn.Established()
   120  		sessionCtx.Debug("session established")
   121  
   122  		shutdownReceived := false
   123  
   124  		for {
   125  			select {
   126  			case <-shutdownCtx.Done():
   127  				if !shutdownReceived {
   128  					shutdownReceived = true
   129  					sessionCtx.Debug("server shutdown")
   130  					session.DisconnectWithMessage(
   131  						&common.DisconnectMessage{Type: "disconnect", Reason: common.SERVER_RESTART_REASON, Reconnect: true},
   132  						common.SERVER_RESTART_REASON,
   133  					)
   134  				}
   135  			case <-r.Context().Done():
   136  				sessionCtx.Debug("request terminated")
   137  				session.DisconnectNow("Closed", ws.CloseNormalClosure)
   138  				return
   139  			case <-conn.Context().Done():
   140  				sessionCtx.Debug("session completed")
   141  				return
   142  			}
   143  		}
   144  	})
   145  }