github.com/philippseith/signalr@v0.6.3/serversseconnection.go (about)

     1  package signalr
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  )
    12  
// serverSSEConnection is the server side of a connection that uses
// server-sent events: messages to the client are handed out as SSE formatted
// jobs, messages from the client arrive as separate HTTP POST requests.
//
// Fields, as used by the methods in this file:
//   - mx guards postWriting and serializes sending on jobChan with the
//     closing of jobChan.
//   - postWriting is true while consumeRequest is processing a POST, so that
//     a concurrent POST can be rejected.
//   - postWriter and postReader are the two ends of an io.Pipe;
//     consumeRequest writes POST bodies into it, Read reads them out.
//   - jobChan carries the SSE formatted payloads produced by Write. It is
//     closed when the connection's context is done.
//   - jobResultChan delivers the outcome (n, err) of a job back to Write.
type serverSSEConnection struct {
	ConnectionBase
	mx            sync.Mutex
	postWriting   bool
	postWriter    io.Writer
	postReader    io.Reader
	jobChan       chan []byte
	jobResultChan chan RWJobResult
}
    22  
    23  func newServerSSEConnection(ctx context.Context, connectionID string) (*serverSSEConnection, <-chan []byte, chan RWJobResult, error) {
    24  	s := serverSSEConnection{
    25  		ConnectionBase: *NewConnectionBase(ctx, connectionID),
    26  		jobChan:        make(chan []byte, 1),
    27  		jobResultChan:  make(chan RWJobResult, 1),
    28  	}
    29  	s.postReader, s.postWriter = io.Pipe()
    30  	go func() {
    31  		<-s.Context().Done()
    32  		s.mx.Lock()
    33  		close(s.jobChan)
    34  		s.mx.Unlock()
    35  	}()
    36  	return &s, s.jobChan, s.jobResultChan, nil
    37  }
    38  
    39  func (s *serverSSEConnection) consumeRequest(request *http.Request) int {
    40  	if err := s.Context().Err(); err != nil {
    41  		return http.StatusGone // 410
    42  	}
    43  	s.mx.Lock()
    44  	if s.postWriting {
    45  		s.mx.Unlock()
    46  		return http.StatusConflict // 409
    47  	}
    48  	s.postWriting = true
    49  	s.mx.Unlock()
    50  	defer func() {
    51  		_ = request.Body.Close()
    52  	}()
    53  	body, err := io.ReadAll(request.Body)
    54  	if err != nil {
    55  		return http.StatusBadRequest // 400
    56  	} else if _, err := s.postWriter.Write(body); err != nil {
    57  		return http.StatusInternalServerError // 500
    58  	}
    59  	s.mx.Lock()
    60  	s.postWriting = false
    61  	s.mx.Unlock()
    62  	<-time.After(50 * time.Millisecond)
    63  	return http.StatusOK // 200
    64  }
    65  
    66  func (s *serverSSEConnection) Read(p []byte) (n int, err error) {
    67  	n, err = ReadWriteWithContext(s.Context(),
    68  		func() (int, error) { return s.postReader.Read(p) },
    69  		func() { _, _ = s.postWriter.Write([]byte("\n")) })
    70  	if err != nil {
    71  		err = fmt.Errorf("%T: %w", s, err)
    72  	}
    73  	return n, err
    74  }
    75  
    76  func (s *serverSSEConnection) Write(p []byte) (n int, err error) {
    77  	if err := s.Context().Err(); err != nil {
    78  		return 0, fmt.Errorf("%T: %w", s, s.Context().Err())
    79  	}
    80  	payload := ""
    81  	for _, line := range strings.Split(strings.TrimRight(string(p), "\n"), "\n") {
    82  		payload = payload + "data: " + line + "\n"
    83  	}
    84  	// prevent race with goroutine closing the jobChan
    85  	s.mx.Lock()
    86  	if s.Context().Err() == nil {
    87  		s.jobChan <- []byte(payload + "\n")
    88  	} else {
    89  		return 0, fmt.Errorf("%T: %w", s, s.Context().Err())
    90  	}
    91  	s.mx.Unlock()
    92  	select {
    93  	case <-s.Context().Done():
    94  		return 0, fmt.Errorf("%T: %w", s, s.Context().Err())
    95  	case r := <-s.jobResultChan:
    96  		return r.n, r.err
    97  	}
    98  
    99  }