github.com/philippseith/signalr@v0.6.3/serversseconnection.go (about) 1 package signalr 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "net/http" 8 "strings" 9 "sync" 10 "time" 11 ) 12 13 type serverSSEConnection struct { 14 ConnectionBase 15 mx sync.Mutex 16 postWriting bool 17 postWriter io.Writer 18 postReader io.Reader 19 jobChan chan []byte 20 jobResultChan chan RWJobResult 21 } 22 23 func newServerSSEConnection(ctx context.Context, connectionID string) (*serverSSEConnection, <-chan []byte, chan RWJobResult, error) { 24 s := serverSSEConnection{ 25 ConnectionBase: *NewConnectionBase(ctx, connectionID), 26 jobChan: make(chan []byte, 1), 27 jobResultChan: make(chan RWJobResult, 1), 28 } 29 s.postReader, s.postWriter = io.Pipe() 30 go func() { 31 <-s.Context().Done() 32 s.mx.Lock() 33 close(s.jobChan) 34 s.mx.Unlock() 35 }() 36 return &s, s.jobChan, s.jobResultChan, nil 37 } 38 39 func (s *serverSSEConnection) consumeRequest(request *http.Request) int { 40 if err := s.Context().Err(); err != nil { 41 return http.StatusGone // 410 42 } 43 s.mx.Lock() 44 if s.postWriting { 45 s.mx.Unlock() 46 return http.StatusConflict // 409 47 } 48 s.postWriting = true 49 s.mx.Unlock() 50 defer func() { 51 _ = request.Body.Close() 52 }() 53 body, err := io.ReadAll(request.Body) 54 if err != nil { 55 return http.StatusBadRequest // 400 56 } else if _, err := s.postWriter.Write(body); err != nil { 57 return http.StatusInternalServerError // 500 58 } 59 s.mx.Lock() 60 s.postWriting = false 61 s.mx.Unlock() 62 <-time.After(50 * time.Millisecond) 63 return http.StatusOK // 200 64 } 65 66 func (s *serverSSEConnection) Read(p []byte) (n int, err error) { 67 n, err = ReadWriteWithContext(s.Context(), 68 func() (int, error) { return s.postReader.Read(p) }, 69 func() { _, _ = s.postWriter.Write([]byte("\n")) }) 70 if err != nil { 71 err = fmt.Errorf("%T: %w", s, err) 72 } 73 return n, err 74 } 75 76 func (s *serverSSEConnection) Write(p []byte) (n int, err error) { 77 if err := s.Context().Err(); err != nil { 78 return 0, fmt.Errorf("%T: %w", s, s.Context().Err()) 79 } 80 payload := "" 81 for _, line := range strings.Split(strings.TrimRight(string(p), "\n"), "\n") { 82 payload = payload + "data: " + line + "\n" 83 } 84 // prevent race with goroutine closing the jobChan 85 s.mx.Lock() 86 if s.Context().Err() == nil { 87 s.jobChan <- []byte(payload + "\n") 88 } else { 89 return 0, fmt.Errorf("%T: %w", s, s.Context().Err()) 90 } 91 s.mx.Unlock() 92 select { 93 case <-s.Context().Done(): 94 return 0, fmt.Errorf("%T: %w", s, s.Context().Err()) 95 case r := <-s.jobResultChan: 96 return r.n, r.err 97 } 98 99 }