github.com/anycable/anycable-go@v1.5.1/sse/sse.go (about) 1 package sse 2 3 import ( 4 "bytes" 5 "encoding/json" 6 "errors" 7 "io" 8 "net/http" 9 "strconv" 10 "strings" 11 12 "github.com/anycable/anycable-go/common" 13 "github.com/anycable/anycable-go/node" 14 "github.com/anycable/anycable-go/server" 15 "github.com/anycable/anycable-go/utils" 16 "github.com/joomcode/errorx" 17 ) 18 19 const ( 20 signedStreamParam = "signed_stream" 21 publicStreamParam = "stream" 22 signedStreamChannel = "$pubsub" 23 turboStreamsParam = "turbo_signed_stream_name" 24 turboStreamsChannel = "Turbo::StreamsChannel" 25 historySinceParam = "history_since" 26 ) 27 28 func NewSSESession(n *node.Node, w http.ResponseWriter, r *http.Request, info *server.RequestInfo) (*node.Session, error) { 29 conn := NewConnection(w) 30 31 unwrapData := r.Method == http.MethodGet 32 33 session := node.NewSession(n, conn, info.URL, info.Headers, info.UID, node.WithEncoder(&Encoder{unwrapData})) 34 res, err := n.Authenticate(session) 35 36 if err != nil { 37 return nil, err 38 } 39 40 if res.Status == common.SUCCESS { 41 return session, nil 42 } else { 43 return nil, nil 44 } 45 } 46 47 // Extract channel identifier or name from the request and build a subscribe command payload 48 func subscribeCommandsFromRequest(r *http.Request) ([]*common.Message, error) { 49 if r.Method == http.MethodGet { 50 cmd, err := subscribeCommandFromGetRequest(r) 51 52 if err != nil { 53 return nil, err 54 } 55 56 if cmd == nil { 57 return nil, errors.New("no channel provided") 58 } 59 60 return []*common.Message{cmd}, nil 61 62 } else { 63 return subscribeCommandFromPostRequest(r) 64 } 65 } 66 67 func subscribeCommandFromGetRequest(r *http.Request) (*common.Message, error) { 68 msg := &common.Message{ 69 Command: "subscribe", 70 } 71 72 // First, check if identifier is provided 73 identifier := r.URL.Query().Get("identifier") 74 75 if identifier == "" { 76 channel := r.URL.Query().Get("channel") 77 78 if channel != "" { 79 identifier = string(utils.ToJSON(map[string]string{"channel": channel})) 80 } 81 } 82 83 // Check for public stream name 84 if identifier == "" { 85 stream := r.URL.Query().Get(publicStreamParam) 86 87 if stream != "" { 88 identifier = string(utils.ToJSON(map[string]string{ 89 "channel": signedStreamChannel, 90 "stream_name": stream, 91 })) 92 } 93 } 94 95 // Check for signed stream name 96 if identifier == "" { 97 stream := r.URL.Query().Get(signedStreamParam) 98 99 if stream != "" { 100 identifier = string(utils.ToJSON(map[string]string{ 101 "channel": signedStreamChannel, 102 "signed_stream_name": stream, 103 })) 104 } 105 } 106 107 // Then, check for Turbo Streams name 108 if identifier == "" { 109 stream := r.URL.Query().Get(turboStreamsParam) 110 111 if stream != "" { 112 identifier = string(utils.ToJSON(map[string]string{ 113 "channel": turboStreamsChannel, 114 "signed_stream_name": stream, 115 })) 116 } 117 } 118 119 if identifier == "" { 120 return nil, nil 121 } 122 123 msg.Identifier = identifier 124 125 if lastId := r.Header.Get("last-event-id"); lastId != "" { 126 offsetParts := strings.SplitN(lastId, lastIdDelimeter, 3) 127 128 if len(offsetParts) == 3 { 129 offset, err := strconv.ParseUint(offsetParts[0], 10, 64) 130 131 if err != nil { 132 return nil, errorx.Decorate(err, "failed to parse last event id: %s", lastId) 133 } 134 135 epoch := offsetParts[1] 136 stream := offsetParts[2] 137 138 streams := make(map[string]common.HistoryPosition) 139 140 streams[stream] = common.HistoryPosition{Offset: offset, Epoch: epoch} 141 142 msg.History = common.HistoryRequest{ 143 Streams: streams, 144 } 145 } 146 } 147 148 if since := r.URL.Query().Get(historySinceParam); since != "" { 149 sinceInt, err := strconv.ParseInt(since, 10, 64) 150 if err != nil { 151 return nil, errorx.Decorate(err, "failed to parse history since value: %s", since) 152 } 153 154 msg.History.Since = sinceInt 155 } 156 157 return msg, nil 158 } 159 160 func subscribeCommandFromPostRequest(r *http.Request) ([]*common.Message, error) { 161 var cmds []*common.Message 162 163 // Read commands (if any) 164 if r.Body != nil { 165 r.Body = http.MaxBytesReader(nil, r.Body, int64(defaultMaxBodySize)) 166 requestData, err := io.ReadAll(r.Body) 167 168 if err != nil { 169 return nil, err 170 } 171 172 if len(requestData) > 0 { 173 lines := bytes.Split(requestData, []byte("\n")) 174 175 for _, line := range lines { 176 if len(line) > 0 { 177 var command common.Message 178 err := json.Unmarshal(line, &command) 179 180 if err != nil { 181 return nil, errorx.Decorate(err, "failed to parse command: %v", command) 182 } 183 184 cmds = append(cmds, &command) 185 } 186 } 187 } 188 } 189 190 return cmds, nil 191 }