github.com/anycable/anycable-go@v1.5.1/sse/connection.go (about)

     1  package sse
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"net"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/anycable/anycable-go/node"
    13  )
    14  
    15  type Connection struct {
    16  	writer http.ResponseWriter
    17  
    18  	ctx      context.Context
    19  	cancelFn context.CancelFunc
    20  
    21  	done        bool
    22  	established bool
    23  	// Backlog is used to store messages sent to client before connection is established
    24  	backlog *bytes.Buffer
    25  
    26  	mu sync.Mutex
    27  }
    28  
// Compile-time assertion that *Connection satisfies node.Connection.
var _ node.Connection = (*Connection)(nil)
    30  
    31  // NewConnection creates a new long-polling connection wrapper
    32  func NewConnection(w http.ResponseWriter) *Connection {
    33  	ctx, cancel := context.WithCancel(context.Background())
    34  	return &Connection{
    35  		writer:   w,
    36  		backlog:  bytes.NewBuffer(nil),
    37  		ctx:      ctx,
    38  		cancelFn: cancel,
    39  	}
    40  }
    41  
    42  func (c *Connection) Read() ([]byte, error) {
    43  	return nil, errors.New("unsupported")
    44  }
    45  
    46  func (c *Connection) Write(msg []byte, deadline time.Time) error {
    47  	c.mu.Lock()
    48  	defer c.mu.Unlock()
    49  
    50  	if c.done {
    51  		return nil
    52  	}
    53  
    54  	if !c.established {
    55  		c.backlog.Write(msg)
    56  		c.backlog.Write([]byte("\n\n"))
    57  		return nil
    58  	}
    59  
    60  	_, err := c.writer.Write(msg)
    61  
    62  	if err != nil {
    63  		return err
    64  	}
    65  
    66  	_, err = c.writer.Write([]byte("\n\n"))
    67  
    68  	if err != nil {
    69  		return err
    70  	}
    71  
    72  	c.writer.(http.Flusher).Flush()
    73  
    74  	return nil
    75  }
    76  
    77  func (c *Connection) WriteBinary(msg []byte, deadline time.Time) error {
    78  	return errors.New("unsupported")
    79  }
    80  
    81  func (c *Connection) Context() context.Context {
    82  	return c.ctx
    83  }
    84  
    85  func (c *Connection) Close(code int, reason string) {
    86  	c.mu.Lock()
    87  	defer c.mu.Unlock()
    88  
    89  	if c.done {
    90  		return
    91  	}
    92  
    93  	c.done = true
    94  
    95  	c.cancelFn()
    96  }
    97  
    98  // Mark as closed to avoid writing to closed connection
    99  func (c *Connection) Established() {
   100  	c.mu.Lock()
   101  	defer c.mu.Unlock()
   102  
   103  	c.established = true
   104  
   105  	if c.backlog.Len() > 0 {
   106  		c.writer.Write(c.backlog.Bytes()) // nolint: errcheck
   107  		c.writer.(http.Flusher).Flush()
   108  		c.backlog.Reset()
   109  	}
   110  }
   111  
   112  func (c *Connection) Descriptor() net.Conn {
   113  	return nil
   114  }