github.com/anycable/anycable-go@v1.5.1/sse/sse.go (about)

     1  package sse
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"io"
     8  	"net/http"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/anycable/anycable-go/common"
    13  	"github.com/anycable/anycable-go/node"
    14  	"github.com/anycable/anycable-go/server"
    15  	"github.com/anycable/anycable-go/utils"
    16  	"github.com/joomcode/errorx"
    17  )
    18  
const (
	// signedStreamParam is the query parameter carrying a signed stream name;
	// its value is subscribed via signedStreamChannel.
	signedStreamParam = "signed_stream"
	// publicStreamParam is the query parameter carrying a plain (unsigned)
	// stream name; its value is subscribed via signedStreamChannel.
	publicStreamParam = "stream"
	// signedStreamChannel is the channel name placed into the identifier
	// built from publicStreamParam or signedStreamParam.
	signedStreamChannel = "$pubsub"
	// turboStreamsParam is the query parameter carrying a signed Turbo Streams
	// stream name; its value is subscribed via turboStreamsChannel.
	turboStreamsParam = "turbo_signed_stream_name"
	// turboStreamsChannel is the channel name placed into the identifier
	// built from turboStreamsParam.
	turboStreamsChannel = "Turbo::StreamsChannel"
	// historySinceParam is the query parameter parsed as a base-10 int64 and
	// stored in the subscribe command's History.Since field.
	historySinceParam = "history_since"
)
    27  
    28  func NewSSESession(n *node.Node, w http.ResponseWriter, r *http.Request, info *server.RequestInfo) (*node.Session, error) {
    29  	conn := NewConnection(w)
    30  
    31  	unwrapData := r.Method == http.MethodGet
    32  
    33  	session := node.NewSession(n, conn, info.URL, info.Headers, info.UID, node.WithEncoder(&Encoder{unwrapData}))
    34  	res, err := n.Authenticate(session)
    35  
    36  	if err != nil {
    37  		return nil, err
    38  	}
    39  
    40  	if res.Status == common.SUCCESS {
    41  		return session, nil
    42  	} else {
    43  		return nil, nil
    44  	}
    45  }
    46  
    47  // Extract channel identifier or name from the request and build a subscribe command payload
    48  func subscribeCommandsFromRequest(r *http.Request) ([]*common.Message, error) {
    49  	if r.Method == http.MethodGet {
    50  		cmd, err := subscribeCommandFromGetRequest(r)
    51  
    52  		if err != nil {
    53  			return nil, err
    54  		}
    55  
    56  		if cmd == nil {
    57  			return nil, errors.New("no channel provided")
    58  		}
    59  
    60  		return []*common.Message{cmd}, nil
    61  
    62  	} else {
    63  		return subscribeCommandFromPostRequest(r)
    64  	}
    65  }
    66  
    67  func subscribeCommandFromGetRequest(r *http.Request) (*common.Message, error) {
    68  	msg := &common.Message{
    69  		Command: "subscribe",
    70  	}
    71  
    72  	// First, check if identifier is provided
    73  	identifier := r.URL.Query().Get("identifier")
    74  
    75  	if identifier == "" {
    76  		channel := r.URL.Query().Get("channel")
    77  
    78  		if channel != "" {
    79  			identifier = string(utils.ToJSON(map[string]string{"channel": channel}))
    80  		}
    81  	}
    82  
    83  	// Check for public stream name
    84  	if identifier == "" {
    85  		stream := r.URL.Query().Get(publicStreamParam)
    86  
    87  		if stream != "" {
    88  			identifier = string(utils.ToJSON(map[string]string{
    89  				"channel":     signedStreamChannel,
    90  				"stream_name": stream,
    91  			}))
    92  		}
    93  	}
    94  
    95  	// Check for signed stream name
    96  	if identifier == "" {
    97  		stream := r.URL.Query().Get(signedStreamParam)
    98  
    99  		if stream != "" {
   100  			identifier = string(utils.ToJSON(map[string]string{
   101  				"channel":            signedStreamChannel,
   102  				"signed_stream_name": stream,
   103  			}))
   104  		}
   105  	}
   106  
   107  	// Then, check for Turbo Streams name
   108  	if identifier == "" {
   109  		stream := r.URL.Query().Get(turboStreamsParam)
   110  
   111  		if stream != "" {
   112  			identifier = string(utils.ToJSON(map[string]string{
   113  				"channel":            turboStreamsChannel,
   114  				"signed_stream_name": stream,
   115  			}))
   116  		}
   117  	}
   118  
   119  	if identifier == "" {
   120  		return nil, nil
   121  	}
   122  
   123  	msg.Identifier = identifier
   124  
   125  	if lastId := r.Header.Get("last-event-id"); lastId != "" {
   126  		offsetParts := strings.SplitN(lastId, lastIdDelimeter, 3)
   127  
   128  		if len(offsetParts) == 3 {
   129  			offset, err := strconv.ParseUint(offsetParts[0], 10, 64)
   130  
   131  			if err != nil {
   132  				return nil, errorx.Decorate(err, "failed to parse last event id: %s", lastId)
   133  			}
   134  
   135  			epoch := offsetParts[1]
   136  			stream := offsetParts[2]
   137  
   138  			streams := make(map[string]common.HistoryPosition)
   139  
   140  			streams[stream] = common.HistoryPosition{Offset: offset, Epoch: epoch}
   141  
   142  			msg.History = common.HistoryRequest{
   143  				Streams: streams,
   144  			}
   145  		}
   146  	}
   147  
   148  	if since := r.URL.Query().Get(historySinceParam); since != "" {
   149  		sinceInt, err := strconv.ParseInt(since, 10, 64)
   150  		if err != nil {
   151  			return nil, errorx.Decorate(err, "failed to parse history since value: %s", since)
   152  		}
   153  
   154  		msg.History.Since = sinceInt
   155  	}
   156  
   157  	return msg, nil
   158  }
   159  
   160  func subscribeCommandFromPostRequest(r *http.Request) ([]*common.Message, error) {
   161  	var cmds []*common.Message
   162  
   163  	// Read commands (if any)
   164  	if r.Body != nil {
   165  		r.Body = http.MaxBytesReader(nil, r.Body, int64(defaultMaxBodySize))
   166  		requestData, err := io.ReadAll(r.Body)
   167  
   168  		if err != nil {
   169  			return nil, err
   170  		}
   171  
   172  		if len(requestData) > 0 {
   173  			lines := bytes.Split(requestData, []byte("\n"))
   174  
   175  			for _, line := range lines {
   176  				if len(line) > 0 {
   177  					var command common.Message
   178  					err := json.Unmarshal(line, &command)
   179  
   180  					if err != nil {
   181  						return nil, errorx.Decorate(err, "failed to parse command: %v", command)
   182  					}
   183  
   184  					cmds = append(cmds, &command)
   185  				}
   186  			}
   187  		}
   188  	}
   189  
   190  	return cmds, nil
   191  }