github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/sse.go (about)

     1  package znet
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"net/http"
     9  	"strconv"
    10  	"time"
    11  
    12  	"github.com/sohaha/zlsgo/zstring"
    13  )
    14  
// SSE is a Server-Sent Events stream bound to a single request. Create it
// with NewSSE, queue events with Send/SendByte, and call Push from the handler
// to deliver them until the stream is stopped.
type SSE struct {
	// ctx is cancelled by Stop; its Done channel ends the Push loop.
	ctx context.Context
	// events carries queued events to the Push loop.
	events chan *sseEvent
	// net is the request context the stream reads headers from and writes to.
	net *Context
	// option holds the retry/heartbeat settings resolved by NewSSE.
	option *SSEOption
	// ctxCancel cancels ctx.
	ctxCancel context.CancelFunc
	// flush pushes buffered response bytes to the client (installed by NewSSE).
	flush func()
	// lastID is the Last-Event-ID request header captured by NewSSE.
	lastID string
	// NOTE(review): method is neither read nor written in this file — confirm it is still needed.
	method string
	// NOTE(review): Comment is exported but unused in this file; its intended role is unclear.
	Comment []byte
}
    26  
// sseEvent is one queued message waiting to be encoded and written by Push.
type sseEvent struct {
	// ID is written as the "id:" field; only emitted when Data is non-empty.
	ID string
	// Event is written as the "event:" field; only emitted when Data is non-empty.
	Event string
	// Comment, when non-empty, is written as a ": <Comment>" comment line.
	Comment string
	// Data is the payload, written as "data:" line(s); a payload that already
	// begins with ':' is written verbatim instead.
	Data []byte
}
    33  
    34  func (s *SSE) LastEventID() string {
    35  	return s.lastID
    36  }
    37  
    38  func (s *SSE) Done() <-chan struct{} {
    39  	return s.ctx.Done()
    40  }
    41  
    42  func (s *SSE) Stop() {
    43  	s.ctxCancel()
    44  }
    45  
    46  func (s *SSE) sendComment() {
    47  	s.events <- &sseEvent{
    48  		Comment: "ping",
    49  	}
    50  }
    51  
    52  func (s *SSE) Send(id string, data string, event ...string) error {
    53  	return s.SendByte(id, zstring.String2Bytes(data), event...)
    54  }
    55  
    56  func (s *SSE) Push() {
    57  	w := s.net.Writer
    58  	r := s.net.Request
    59  
    60  	s.net.Abort(http.StatusOK)
    61  	s.net.write()
    62  	s.flush()
    63  
    64  	heartbeatsTime := s.option.HeartbeatsTime
    65  	if heartbeatsTime == 0 {
    66  		heartbeatsTime = 15000
    67  	}
    68  	ticker := time.NewTicker(time.Duration(heartbeatsTime) * time.Millisecond)
    69  
    70  	defer ticker.Stop()
    71  
    72  	b := zstring.Buffer(7)
    73  sseFor:
    74  	for {
    75  		select {
    76  		case <-ticker.C:
    77  			go s.sendComment()
    78  		case <-r.Context().Done():
    79  			s.ctxCancel()
    80  			break sseFor
    81  		case <-s.ctx.Done():
    82  			break sseFor
    83  		case ev := <-s.events:
    84  			if len(ev.Data) > 0 {
    85  				if ev.ID != "" {
    86  					b.WriteString("id: ")
    87  					b.WriteString(ev.ID)
    88  					b.WriteString("\n")
    89  				}
    90  
    91  				if bytes.HasPrefix(ev.Data, []byte(":")) {
    92  					b.Write(ev.Data)
    93  					b.WriteString("\n")
    94  				} else {
    95  					if bytes.IndexByte(ev.Data, '\n') > 0 {
    96  						for _, v := range bytes.Split(ev.Data, []byte("\n")) {
    97  							b.WriteString("data: ")
    98  							b.Write(v)
    99  							b.WriteString("\n")
   100  						}
   101  					} else {
   102  						b.WriteString("data: ")
   103  						b.Write(ev.Data)
   104  						b.WriteString("\n")
   105  					}
   106  				}
   107  
   108  				if len(ev.Event) > 0 {
   109  					b.WriteString("event: ")
   110  					b.WriteString(ev.Event)
   111  					b.WriteString("\n")
   112  				}
   113  
   114  				if s.option.RetryTime > 0 {
   115  					b.WriteString("retry: ")
   116  					b.WriteString(strconv.Itoa(s.option.RetryTime))
   117  					b.WriteString("\n")
   118  				}
   119  			}
   120  
   121  			if len(ev.Comment) > 0 {
   122  				b.WriteString(": ")
   123  				b.WriteString(ev.Comment)
   124  				b.WriteString("\n")
   125  			}
   126  
   127  			b.WriteString("\n")
   128  
   129  			data := zstring.String2Bytes(b.String())
   130  			_, _ = w.Write(data)
   131  			s.flush()
   132  
   133  			b.Reset()
   134  			b.Grow(7)
   135  		}
   136  	}
   137  }
   138  
   139  func (s *SSE) SendByte(id string, data []byte, event ...string) error {
   140  	if s.ctx.Err() != nil {
   141  		return errors.New("client has been closed")
   142  	}
   143  
   144  	ev := &sseEvent{
   145  		ID:   id,
   146  		Data: data,
   147  	}
   148  
   149  	if len(event) > 0 {
   150  		ev.Event = event[0]
   151  	}
   152  
   153  	s.events <- ev
   154  
   155  	return nil
   156  }
   157  
// SSEOption configures an SSE stream. NewSSE passes it, together with the
// client's Last-Event-ID, to each option func.
type SSEOption struct {
	// RetryTime, when > 0, is sent as the "retry:" field with every data
	// event; per the SSE spec clients read it as a reconnection delay in ms.
	RetryTime int
	// HeartbeatsTime is the interval in milliseconds between ": ping"
	// heartbeat comments. NewSSE defaults it to 15000.
	HeartbeatsTime int
}
   162  
   163  func NewSSE(c *Context, opts ...func(lastID string, opts *SSEOption)) *SSE {
   164  	id := c.GetHeader("Last-Event-ID")
   165  	ctx, cancel := context.WithCancel(context.TODO())
   166  	s := &SSE{
   167  		lastID:    id,
   168  		events:    make(chan *sseEvent, 1),
   169  		net:       c,
   170  		ctx:       ctx,
   171  		ctxCancel: cancel,
   172  		option: &SSEOption{
   173  			// RetryTime:      3000,
   174  			HeartbeatsTime: 15000,
   175  		},
   176  	}
   177  
   178  	for _, opt := range opts {
   179  		opt(id, s.option)
   180  	}
   181  
   182  	flusher, _ := s.net.Writer.(http.Flusher)
   183  
   184  	s.flush = func() {
   185  		if c.Request.Context().Err() != nil {
   186  			return
   187  		}
   188  		flusher.Flush()
   189  	}
   190  
   191  	s.net.SetHeader("Content-Type", "text/event-stream")
   192  	s.net.SetHeader("Cache-Control", "no-cache")
   193  	s.net.SetHeader("Connection", "keep-alive")
   194  	c.prevData.Code.Store(http.StatusNoContent)
   195  	s.net.Engine.shutdowns = append(s.net.Engine.shutdowns, func() {
   196  		s.Stop()
   197  	})
   198  	return s
   199  }
   200  
   201  func (c *Context) Stream(step func(w io.Writer) bool) bool {
   202  	w := c.Writer
   203  	flusher, _ := w.(http.Flusher)
   204  	c.write()
   205  	for {
   206  		if c.stopHandle.Load() {
   207  			return false
   208  		}
   209  		keepOpen := step(w)
   210  		flusher.Flush()
   211  		if !keepOpen {
   212  			return false
   213  		}
   214  	}
   215  }