github.com/diamondburned/arikawa/v2@v2.1.0/utils/wsutil/conn.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"bytes"
     5  	"compress/zlib"
     6  	"context"
     7  	"io"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/gorilla/websocket"
    13  	"github.com/pkg/errors"
    14  )
    15  
    16  // CopyBufferSize is used for the initial size of the internal WS' buffer. Its
    17  // size is 4KB.
    18  var CopyBufferSize = 4096
    19  
    20  // MaxCapUntilReset determines the maximum capacity before the bytes buffer is
    21  // re-allocated. It is roughly 16KB, quadruple CopyBufferSize.
    22  var MaxCapUntilReset = CopyBufferSize * 4
    23  
    24  // CloseDeadline controls the deadline to wait for sending the Close frame.
    25  var CloseDeadline = time.Second
    26  
    27  // ErrWebsocketClosed is returned if the websocket is already closed.
    28  var ErrWebsocketClosed = errors.New("websocket is closed")
    29  
    30  // Connection is an interface that abstracts around a generic Websocket driver.
    31  // This connection expects the driver to handle compression by itself, including
    32  // modifying the connection URL. The implementation doesn't have to be safe for
    33  // concurrent use.
    34  type Connection interface {
    35  	// Dial dials the address (string). Context needs to be passed in for
    36  	// timeout. This method should also be re-usable after Close is called.
    37  	Dial(context.Context, string) error
    38  
    39  	// Listen returns an event channel that sends over events constantly. It can
    40  	// return nil if there isn't an ongoing connection.
    41  	Listen() <-chan Event
    42  
    43  	// Send allows the caller to send bytes. It does not need to clean itself
    44  	// up on errors, as the Websocket wrapper will do that.
    45  	//
    46  	// If the data is nil, it should send a close frame
    47  	Send(context.Context, []byte) error
    48  
    49  	// Close should close the websocket connection. The underlying connection
    50  	// may be reused, but this Connection instance will be reused with Dial. The
    51  	// Connection must still be reusable even if Close returns an error.
    52  	Close() error
    53  }
    54  
    55  // GracefulCloser is an interface used by Connections that support graceful
    56  // closure of their websocket connection.
    57  type GracefulCloser interface {
    58  	// CloseGracefully sends a close frame and then closes the websocket
    59  	// connection.
    60  	CloseGracefully() error
    61  }
    62  
// Conn is the default Websocket connection, built on gorilla/websocket.
//
// It does not compress outgoing data: Send writes the given bytes as a text
// message unchanged. Conns created through NewConnWithDialer (and therefore
// NewConn) send "Accept-Encoding: zlib" with the handshake, and the read loop
// started by Dial decodes incoming binary messages as zlib payloads.
type Conn struct {
	// Dialer is used by Dial to open the connection.
	Dialer websocket.Dialer
	// Header is sent with the handshake request in Dial.
	Header http.Header
	// Conn is the underlying websocket connection, assigned by Dial.
	Conn   *websocket.Conn
	// events is created by Dial, fed by the read loop Dial starts, and closed
	// by that loop when it exits. Listen exposes it to callers.
	events chan Event
}

// Compile-time check that *Conn satisfies Connection.
var _ Connection = (*Conn)(nil)
    73  
    74  // NewConn creates a new default websocket connection with a default dialer.
    75  func NewConn() *Conn {
    76  	return NewConnWithDialer(websocket.Dialer{
    77  		Proxy:             http.ProxyFromEnvironment,
    78  		HandshakeTimeout:  WSTimeout,
    79  		ReadBufferSize:    CopyBufferSize,
    80  		WriteBufferSize:   CopyBufferSize,
    81  		EnableCompression: true,
    82  	})
    83  }
    84  
    85  // NewConnWithDialer creates a new default websocket connection with a custom
    86  // dialer.
    87  func NewConnWithDialer(dialer websocket.Dialer) *Conn {
    88  	return &Conn{
    89  		Dialer: dialer,
    90  		Header: http.Header{
    91  			"Accept-Encoding": {"zlib"},
    92  		},
    93  	}
    94  }
    95  
    96  func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
    97  	// BUG which prevents stream compression.
    98  	// See https://github.com/golang/go/issues/31514.
    99  
   100  	c.Conn, _, err = c.Dialer.DialContext(ctx, addr, c.Header)
   101  	if err != nil {
   102  		return errors.Wrap(err, "failed to dial WS")
   103  	}
   104  
   105  	// Reset the deadline.
   106  	c.Conn.SetWriteDeadline(resetDeadline)
   107  
   108  	c.events = make(chan Event, WSBuffer)
   109  	go startReadLoop(c.Conn, c.events)
   110  
   111  	return err
   112  }
   113  
   114  // Listen returns an event channel if there is a connection associated with it.
   115  // It returns nil if there is none.
   116  func (c *Conn) Listen() <-chan Event {
   117  	return c.events
   118  }
   119  
   120  // resetDeadline is used to reset the write deadline after using the context's.
   121  var resetDeadline = time.Time{}
   122  
   123  func (c *Conn) Send(ctx context.Context, b []byte) error {
   124  	d, ok := ctx.Deadline()
   125  	if ok {
   126  		c.Conn.SetWriteDeadline(d)
   127  		defer c.Conn.SetWriteDeadline(resetDeadline)
   128  	}
   129  
   130  	if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
   131  		return err
   132  	}
   133  
   134  	return nil
   135  }
   136  
   137  func (c *Conn) Close() error {
   138  	WSDebug("Conn: Close is called; shutting down the Websocket connection.")
   139  
   140  	// Have a deadline before closing.
   141  	var deadline = time.Now().Add(5 * time.Second)
   142  	c.Conn.SetWriteDeadline(deadline)
   143  
   144  	// Close the WS.
   145  	err := c.Conn.Close()
   146  
   147  	c.Conn.SetWriteDeadline(resetDeadline)
   148  
   149  	WSDebug("Conn: Websocket closed; error:", err)
   150  	WSDebug("Conn: Flushing events...")
   151  
   152  	// Flush all events before closing the channel. This will return as soon as
   153  	// c.events is closed, or after closed.
   154  	for range c.events {
   155  	}
   156  
   157  	WSDebug("Flushed events.")
   158  
   159  	return err
   160  }
   161  
   162  func (c *Conn) CloseGracefully() error {
   163  	WSDebug("Conn: CloseGracefully is called; sending close frame.")
   164  
   165  	c.Conn.SetWriteDeadline(time.Now().Add(CloseDeadline))
   166  
   167  	err := c.Conn.WriteMessage(websocket.CloseMessage,
   168  		websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
   169  	if err != nil {
   170  		WSError(err)
   171  	}
   172  
   173  	WSDebug("Conn: Close frame sent; error:", err)
   174  
   175  	return c.Close()
   176  }
   177  
// loopState holds the private, throwaway resources of one read loop.
// startReadLoop creates one per connection and only touches it from its own
// goroutine. It is therefore not synchronized: the read loop needs no
// synchronization apart from whatever the websocket connection itself does.
type loopState struct {
	// conn is the connection the loop reads messages from.
	conn *websocket.Conn
	// zlib stays nil until handle first builds a zlib reader. After that it is
	// Reset onto each new binary message instead of being re-created.
	zlib io.ReadCloser
	// buf is scratch space reused by readAll for every message.
	buf  bytes.Buffer
}
   186  
   187  func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
   188  	// Clean up the events channel in the end.
   189  	defer close(eventCh)
   190  
   191  	// Allocate the read loop its own private resources.
   192  	state := loopState{conn: conn}
   193  	state.buf.Grow(CopyBufferSize)
   194  
   195  	for {
   196  		b, err := state.handle()
   197  		if err != nil {
   198  			WSDebug("Conn: Read error:", err)
   199  
   200  			// Is the error an EOF?
   201  			if errors.Is(err, io.EOF) {
   202  				// Yes it is, exit.
   203  				return
   204  			}
   205  
   206  			// Is the error an intentional close call? Go 1.16 exposes
   207  			// ErrClosing, but we have to do this for now.
   208  			if strings.HasSuffix(err.Error(), "use of closed network connection") {
   209  				return
   210  			}
   211  
   212  			// Unusual error; log and exit:
   213  			eventCh <- Event{nil, errors.Wrap(err, "WS error")}
   214  			return
   215  		}
   216  
   217  		// If the payload length is 0, skip it.
   218  		if len(b) == 0 {
   219  			continue
   220  		}
   221  
   222  		eventCh <- Event{b, nil}
   223  	}
   224  }
   225  
   226  func (state *loopState) handle() ([]byte, error) {
   227  	// skip message type
   228  	t, r, err := state.conn.NextReader()
   229  	if err != nil {
   230  		return nil, err
   231  	}
   232  
   233  	if t == websocket.BinaryMessage {
   234  		// Probably a zlib payload.
   235  
   236  		if state.zlib == nil {
   237  			z, err := zlib.NewReader(r)
   238  			if err != nil {
   239  				return nil, errors.Wrap(err, "failed to create a zlib reader")
   240  			}
   241  			state.zlib = z
   242  		} else {
   243  			if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
   244  				return nil, errors.Wrap(err, "failed to reset zlib reader")
   245  			}
   246  		}
   247  
   248  		defer state.zlib.Close()
   249  		r = state.zlib
   250  	}
   251  
   252  	return state.readAll(r)
   253  }
   254  
   255  // readAll reads bytes into an existing buffer, copy it over, then wipe the old
   256  // buffer.
   257  func (state *loopState) readAll(r io.Reader) ([]byte, error) {
   258  	defer state.buf.Reset()
   259  
   260  	if _, err := state.buf.ReadFrom(r); err != nil {
   261  		return nil, err
   262  	}
   263  
   264  	// Copy the bytes so we could empty the buffer for reuse.
   265  	cpy := make([]byte, state.buf.Len())
   266  	copy(cpy, state.buf.Bytes())
   267  
   268  	// If the buffer's capacity is over the limit, then re-allocate a new one.
   269  	if state.buf.Cap() > MaxCapUntilReset {
   270  		state.buf = bytes.Buffer{}
   271  		state.buf.Grow(CopyBufferSize)
   272  	}
   273  
   274  	return cpy, nil
   275  }