github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/zhttp/sse.go (about)

     1  package zhttp
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"net/http"
     7  	"strconv"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/sohaha/zlsgo/zerror"
    12  	"github.com/sohaha/zlsgo/zstring"
    13  )
    14  
    15  type (
    16  	SSEEngine struct {
    17  		ctx          context.Context
    18  		eventCh      chan *SSEEvent
    19  		errCh        chan error
    20  		ctxCancel    context.CancelFunc
    21  		verifyHeader func(http.Header) bool
    22  		option       SSEOption
    23  		readyState   int
    24  	}
    25  
    26  	SSEEvent struct {
    27  		ID        string
    28  		Event     string
    29  		Undefined []byte
    30  		Data      []byte
    31  	}
    32  )
    33  
    34  var (
    35  	delim   = []byte{':'} // []byte{':', ' '}
    36  	ping    = []byte("ping")
    37  	dataEnd = byte('\n')
    38  )
    39  
    40  func (sse *SSEEngine) Event() <-chan *SSEEvent {
    41  	return sse.eventCh
    42  }
    43  
    44  func (sse *SSEEngine) Close() {
    45  	sse.ctxCancel()
    46  }
    47  
    48  func (sse *SSEEngine) Done() <-chan struct{} {
    49  	return sse.ctx.Done()
    50  }
    51  
    52  func (sse *SSEEngine) Error() <-chan error {
    53  	return sse.errCh
    54  }
    55  
    56  func (sse *SSEEngine) VerifyHeader(fn func(http.Header) bool) {
    57  	sse.verifyHeader = fn
    58  }
    59  
    60  func (sse *SSEEngine) OnMessage(fn func(*SSEEvent)) (<-chan struct{}, error) {
    61  	done := make(chan struct{}, 1)
    62  	select {
    63  	case <-sse.Done():
    64  		done <- struct{}{}
    65  		return done, nil
    66  	case e := <-sse.Error():
    67  		done <- struct{}{}
    68  		return done, e
    69  	case v := <-sse.Event():
    70  		go func() {
    71  			fn(v)
    72  			for {
    73  				select {
    74  				case <-sse.Done():
    75  					done <- struct{}{}
    76  					return
    77  				case <-sse.Error():
    78  					done <- struct{}{}
    79  					return
    80  				case v := <-sse.Event():
    81  					fn(v)
    82  				}
    83  			}
    84  		}()
    85  
    86  		return done, nil
    87  	}
    88  }
    89  
    90  func SSE(url string, v ...interface{}) *SSEEngine {
    91  	sse, err := std.SSE(url, nil, v...)
    92  	if err != nil {
    93  		sse.errCh <- err
    94  	}
    95  	return sse
    96  }
    97  
    98  func (e *Engine) sseReq(method, url string, v ...interface{}) (*Res, error) {
    99  	r, err := e.Do(method, url, v...)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	statusCode := r.resp.StatusCode
   104  	if statusCode == http.StatusNoContent {
   105  		return nil, nil
   106  	}
   107  
   108  	if statusCode != http.StatusOK {
   109  		return nil, zerror.With(zerror.New(zerror.ErrCode(statusCode), r.String()), "status code is "+strconv.Itoa(statusCode))
   110  	}
   111  	return r, nil
   112  }
   113  
   114  type SSEOption struct {
   115  	Method   string
   116  	RetryNum int
   117  }
   118  
   119  func (e *Engine) SSE(url string, opt func(*SSEOption), v ...interface{}) (*SSEEngine, error) {
   120  	var (
   121  		retry     = 3000
   122  		currEvent = &SSEEvent{}
   123  	)
   124  	o := SSEOption{
   125  		Method:   "POST",
   126  		RetryNum: -1,
   127  	}
   128  	if opt != nil {
   129  		opt(&o)
   130  	}
   131  	ctx, cancel := context.WithCancel(context.TODO())
   132  	sse := &SSEEngine{
   133  		readyState: 0,
   134  		ctx:        ctx,
   135  		option:     o,
   136  		ctxCancel:  cancel,
   137  		eventCh:    make(chan *SSEEvent),
   138  		errCh:      make(chan error),
   139  		verifyHeader: func(h http.Header) bool {
   140  			return strings.Contains(h.Get("Content-Type"), "text/event-stream")
   141  		},
   142  	}
   143  
   144  	lastID := ""
   145  	data := append(v, Header{"Accept": "text/event-stream", "Connection": "keep-alive"}, sse.ctx)
   146  
   147  	r, err := e.sseReq(sse.option.Method, url, data...)
   148  	if err != nil {
   149  		return sse, err
   150  	}
   151  
   152  	go func() {
   153  		for {
   154  			if sse.ctx.Err() != nil {
   155  				break
   156  			}
   157  			if err == nil {
   158  				if r != nil {
   159  					if sse.verifyHeader != nil && !sse.verifyHeader(r.Response().Header) {
   160  						sse.eventCh <- &SSEEvent{
   161  							Undefined: r.Bytes(),
   162  						}
   163  						r = nil
   164  					}
   165  				}
   166  
   167  				if r == nil {
   168  					sse.readyState = 2
   169  					cancel()
   170  					return
   171  				}
   172  
   173  				sse.readyState = 1
   174  
   175  				isPing := false
   176  				_ = r.Stream(func(line []byte, eof bool) error {
   177  					i := len(line)
   178  					if i == 1 && line[0] == dataEnd {
   179  						if !isPing {
   180  							sse.eventCh <- currEvent
   181  							currEvent = &SSEEvent{}
   182  							isPing = false
   183  						} else {
   184  							currEvent = &SSEEvent{}
   185  						}
   186  
   187  						return nil
   188  					}
   189  
   190  					if i < 2 {
   191  						return nil
   192  					}
   193  
   194  					spl := bytes.SplitN(line, delim, 2)
   195  					if len(spl) < 2 {
   196  						currEvent.Undefined = line
   197  						return nil
   198  					}
   199  
   200  					if len(spl[0]) == 0 {
   201  						isPing = bytes.Equal(ping, bytes.TrimSpace(spl[1]))
   202  						if !isPing {
   203  							currEvent.Undefined = spl[1]
   204  						}
   205  						return nil
   206  					}
   207  
   208  					val := bytes.TrimSuffix(spl[1], []byte{'\n'})
   209  					val = bytes.TrimPrefix(val, []byte{' '})
   210  
   211  					switch zstring.Bytes2String(spl[0]) {
   212  					case "id":
   213  						lastID = zstring.Bytes2String(val)
   214  						currEvent.ID = lastID
   215  					case "event":
   216  						currEvent.Event = zstring.Bytes2String(val)
   217  					case "data":
   218  						if len(currEvent.Data) > 0 {
   219  							sse.eventCh <- currEvent
   220  							currEvent = &SSEEvent{}
   221  							isPing = false
   222  						}
   223  						currEvent.Data = append(currEvent.Data, val...)
   224  					case "retry":
   225  						if t, err := strconv.Atoi(zstring.Bytes2String(val)); err == nil {
   226  							retry = t
   227  						}
   228  					}
   229  					if eof && !isPing {
   230  						sse.eventCh <- currEvent
   231  						currEvent = &SSEEvent{}
   232  					}
   233  					return nil
   234  				})
   235  
   236  				if sse.option.RetryNum >= 0 {
   237  					if sse.option.RetryNum == 0 {
   238  						cancel()
   239  						return
   240  					}
   241  					sse.option.RetryNum--
   242  				}
   243  			} else {
   244  				sse.errCh <- err
   245  			}
   246  
   247  			sse.readyState = 0
   248  			time.Sleep(time.Millisecond * time.Duration(retry))
   249  			ndata := data
   250  			if lastID != "" {
   251  				ndata = append(ndata, Header{"Last-Event-ID": lastID})
   252  			}
   253  			r, err = e.sseReq(sse.option.Method, url, ndata...)
   254  		}
   255  	}()
   256  
   257  	return sse, nil
   258  }