github.com/anycable/anycable-go@v1.5.1/sse/handler.go (about) 1 package sse 2 3 import ( 4 "context" 5 "log/slog" 6 "net/http" 7 "strings" 8 9 "github.com/anycable/anycable-go/common" 10 "github.com/anycable/anycable-go/node" 11 "github.com/anycable/anycable-go/server" 12 "github.com/anycable/anycable-go/version" 13 "github.com/anycable/anycable-go/ws" 14 ) 15 16 // SSEHandler generates a new http handler for SSE connections 17 func SSEHandler(n *node.Node, shutdownCtx context.Context, headersExtractor server.HeadersExtractor, config *Config, l *slog.Logger) http.Handler { 18 var allowedHosts []string 19 20 if config.AllowedOrigins == "" { 21 allowedHosts = []string{} 22 } else { 23 allowedHosts = strings.Split(config.AllowedOrigins, ",") 24 } 25 26 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 27 // Write CORS headers 28 server.WriteCORSHeaders(w, r, allowedHosts) 29 30 // Respond to preflight requests 31 if r.Method == http.MethodOptions { 32 w.WriteHeader(http.StatusOK) 33 return 34 } 35 36 // SSE only supports GET and POST requests 37 if r.Method != http.MethodGet && r.Method != http.MethodPost { 38 w.WriteHeader(http.StatusMethodNotAllowed) 39 return 40 } 41 42 // Prepare common headers 43 w.Header().Set("X-AnyCable-Version", version.Version()) 44 if r.ProtoMajor == 1 { 45 // An endpoint MUST NOT generate an HTTP/2 message containing connection-specific header fields. 46 // Source: RFC7540. 47 w.Header().Set("Connection", "keep-alive") 48 } 49 w.Header().Set("Content-Type", "text/event-stream; charset=utf-8") 50 w.Header().Set("X-Content-Type-Options", "nosniff") 51 w.Header().Set("X-Accel-Buffering", "no") 52 w.Header().Set("Cache-Control", "private, no-cache, no-store, must-revalidate, max-age=0") // HTTP 1.1 53 w.Header().Set("Pragma", "no-cache") // HTTP 1.0 54 w.Header().Set("Expire", "0") 55 56 flusher, ok := w.(http.Flusher) 57 if !ok { 58 w.WriteHeader(http.StatusNotImplemented) 59 return 60 } 61 62 info, err := server.NewRequestInfo(r, headersExtractor) 63 if err != nil { 64 w.WriteHeader(http.StatusBadRequest) 65 return 66 } 67 68 sessionCtx := l.With("sid", info.UID).With("transport", "sse") 69 70 subscribeCmds, err := subscribeCommandsFromRequest(r) 71 72 if err != nil { 73 sessionCtx.Error("failed to build subscribe command", "error", err) 74 w.WriteHeader(http.StatusBadRequest) 75 return 76 } 77 78 // Finally, we can establish a session 79 session, err := NewSSESession(n, w, r, info) 80 81 if err != nil { 82 sessionCtx.Error("failed to establish sesssion", "error", err) 83 w.WriteHeader(http.StatusBadRequest) 84 return 85 } 86 87 if session == nil { 88 sessionCtx.Error("authentication failed") 89 w.WriteHeader(http.StatusUnauthorized) 90 return 91 } 92 93 // Make sure we remove the session from the node when we're done (especially if we return earlier due to rejected subscription) 94 defer session.Disconnect("Closed", ws.CloseNormalClosure) 95 96 conn := session.UnderlyingConn().(*Connection) 97 98 for _, subscribeCmd := range subscribeCmds { 99 // Subscribe to the channel 100 res, err := n.Subscribe(session, subscribeCmd) 101 102 if err != nil || res == nil { 103 sessionCtx.Error("failed to subscribe", "error", err) 104 w.WriteHeader(http.StatusBadRequest) 105 return 106 } 107 108 // Subscription rejected 109 if res.Status != common.SUCCESS { 110 sessionCtx.Debug("subscription rejected") 111 w.WriteHeader(http.StatusBadRequest) 112 return 113 } 114 } 115 116 w.WriteHeader(http.StatusOK) 117 flusher.Flush() 118 119 conn.Established() 120 sessionCtx.Debug("session established") 121 122 shutdownReceived := false 123 124 for { 125 select { 126 case <-shutdownCtx.Done(): 127 if !shutdownReceived { 128 shutdownReceived = true 129 sessionCtx.Debug("server shutdown") 130 session.DisconnectWithMessage( 131 &common.DisconnectMessage{Type: "disconnect", Reason: common.SERVER_RESTART_REASON, Reconnect: true}, 132 common.SERVER_RESTART_REASON, 133 ) 134 } 135 case <-r.Context().Done(): 136 sessionCtx.Debug("request terminated") 137 session.DisconnectNow("Closed", ws.CloseNormalClosure) 138 return 139 case <-conn.Context().Done(): 140 sessionCtx.Debug("session completed") 141 return 142 } 143 } 144 }) 145 }