github.com/diamondburned/arikawa@v1.3.14/utils/wsutil/conn.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"bytes"
     5  	"compress/zlib"
     6  	"context"
     7  	"io"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/gorilla/websocket"
    13  	"github.com/pkg/errors"
    14  )
    15  
    16  // CopyBufferSize is used for the initial size of the internal WS' buffer. Its
    17  // size is 4KB.
    18  var CopyBufferSize = 4096
    19  
    20  // MaxCapUntilReset determines the maximum capacity before the bytes buffer is
    21  // re-allocated. It is roughly 16KB, quadruple CopyBufferSize.
    22  var MaxCapUntilReset = CopyBufferSize * 4
    23  
    24  // CloseDeadline controls the deadline to wait for sending the Close frame.
    25  var CloseDeadline = time.Second
    26  
    27  // ErrWebsocketClosed is returned if the websocket is already closed.
    28  var ErrWebsocketClosed = errors.New("websocket is closed")
    29  
    30  // Connection is an interface that abstracts around a generic Websocket driver.
    31  // This connection expects the driver to handle compression by itself, including
    32  // modifying the connection URL. The implementation doesn't have to be safe for
    33  // concurrent use.
    34  type Connection interface {
    35  	// Dial dials the address (string). Context needs to be passed in for
    36  	// timeout. This method should also be re-usable after Close is called.
    37  	Dial(context.Context, string) error
    38  
    39  	// Listen returns an event channel that sends over events constantly. It can
    40  	// return nil if there isn't an ongoing connection.
    41  	Listen() <-chan Event
    42  
    43  	// Send allows the caller to send bytes. It does not need to clean itself
    44  	// up on errors, as the Websocket wrapper will do that.
    45  	Send(context.Context, []byte) error
    46  
    47  	// Close should close the websocket connection. The underlying connection
    48  	// may be reused, but this Connection instance will be reused with Dial. The
    49  	// Connection must still be reusable even if Close returns an error.
    50  	Close() error
    51  }
    52  
    53  // Conn is the default Websocket connection. It tries to compresses all payloads
    54  // using zlib.
    55  type Conn struct {
    56  	Dialer websocket.Dialer
    57  	Header http.Header
    58  	Conn   *websocket.Conn
    59  	events chan Event
    60  }
    61  
    62  var _ Connection = (*Conn)(nil)
    63  
    64  // NewConn creates a new default websocket connection with a default dialer.
    65  func NewConn() *Conn {
    66  	return NewConnWithDialer(websocket.Dialer{
    67  		Proxy:             http.ProxyFromEnvironment,
    68  		HandshakeTimeout:  WSTimeout,
    69  		ReadBufferSize:    CopyBufferSize,
    70  		WriteBufferSize:   CopyBufferSize,
    71  		EnableCompression: true,
    72  	})
    73  }
    74  
    75  // NewConn creates a new default websocket connection with a custom dialer.
    76  func NewConnWithDialer(dialer websocket.Dialer) *Conn {
    77  	return &Conn{
    78  		Dialer: dialer,
    79  		Header: http.Header{
    80  			"Accept-Encoding": {"zlib"},
    81  		},
    82  	}
    83  }
    84  
    85  func (c *Conn) Dial(ctx context.Context, addr string) (err error) {
    86  	// BUG which prevents stream compression.
    87  	// See https://github.com/golang/go/issues/31514.
    88  
    89  	c.Conn, _, err = c.Dialer.DialContext(ctx, addr, c.Header)
    90  	if err != nil {
    91  		return errors.Wrap(err, "failed to dial WS")
    92  	}
    93  
    94  	// Reset the deadline.
    95  	c.Conn.SetWriteDeadline(resetDeadline)
    96  
    97  	c.events = make(chan Event, WSBuffer)
    98  	go startReadLoop(c.Conn, c.events)
    99  
   100  	return err
   101  }
   102  
   103  // Listen returns an event channel if there is a connection associated with it.
   104  // It returns nil if there is none.
   105  func (c *Conn) Listen() <-chan Event {
   106  	return c.events
   107  }
   108  
   109  // resetDeadline is used to reset the write deadline after using the context's.
   110  var resetDeadline = time.Time{}
   111  
   112  func (c *Conn) Send(ctx context.Context, b []byte) error {
   113  	d, ok := ctx.Deadline()
   114  	if ok {
   115  		c.Conn.SetWriteDeadline(d)
   116  		defer c.Conn.SetWriteDeadline(resetDeadline)
   117  	}
   118  
   119  	if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil {
   120  		return err
   121  	}
   122  
   123  	return nil
   124  }
   125  
   126  func (c *Conn) Close() error {
   127  	WSDebug("Conn: Close is called; shutting down the Websocket connection.")
   128  
   129  	// Have a deadline before closing.
   130  	var deadline = time.Now().Add(5 * time.Second)
   131  	c.Conn.SetWriteDeadline(deadline)
   132  
   133  	// Close the WS.
   134  	err := c.Conn.Close()
   135  
   136  	c.Conn.SetWriteDeadline(resetDeadline)
   137  
   138  	WSDebug("Conn: Websocket closed; error:", err)
   139  	WSDebug("Conn: Flusing events...")
   140  
   141  	// Flush all events before closing the channel. This will return as soon as
   142  	// c.events is closed, or after closed.
   143  	for range c.events {
   144  	}
   145  
   146  	WSDebug("Flushed events.")
   147  
   148  	return err
   149  }
   150  
   151  // loopState is a thread-unsafe disposable state container for the read loop.
   152  // It's made to completely separate the read loop of any synchronization that
   153  // doesn't involve the websocket connection itself.
   154  type loopState struct {
   155  	conn *websocket.Conn
   156  	zlib io.ReadCloser
   157  	buf  bytes.Buffer
   158  }
   159  
   160  func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) {
   161  	// Clean up the events channel in the end.
   162  	defer close(eventCh)
   163  
   164  	// Allocate the read loop its own private resources.
   165  	state := loopState{conn: conn}
   166  	state.buf.Grow(CopyBufferSize)
   167  
   168  	for {
   169  		b, err := state.handle()
   170  		if err != nil {
   171  			WSDebug("Conn: Read error:", err)
   172  
   173  			// Is the error an EOF?
   174  			if errors.Is(err, io.EOF) {
   175  				// Yes it is, exit.
   176  				return
   177  			}
   178  
   179  			// Is the error an intentional close call? Go 1.16 exposes
   180  			// ErrClosing, but we have to do this for now.
   181  			if strings.HasSuffix(err.Error(), "use of closed network connection") {
   182  				return
   183  			}
   184  
   185  			// Check if the error is a normal one:
   186  			if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   187  				return
   188  			}
   189  
   190  			// Unusual error; log and exit:
   191  			eventCh <- Event{nil, errors.Wrap(err, "WS error")}
   192  			return
   193  		}
   194  
   195  		// If the payload length is 0, skip it.
   196  		if len(b) == 0 {
   197  			continue
   198  		}
   199  
   200  		eventCh <- Event{b, nil}
   201  	}
   202  }
   203  
   204  func (state *loopState) handle() ([]byte, error) {
   205  	// skip message type
   206  	t, r, err := state.conn.NextReader()
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	if t == websocket.BinaryMessage {
   212  		// Probably a zlib payload.
   213  
   214  		if state.zlib == nil {
   215  			z, err := zlib.NewReader(r)
   216  			if err != nil {
   217  				return nil, errors.Wrap(err, "failed to create a zlib reader")
   218  			}
   219  			state.zlib = z
   220  		} else {
   221  			if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil {
   222  				return nil, errors.Wrap(err, "failed to reset zlib reader")
   223  			}
   224  		}
   225  
   226  		defer state.zlib.Close()
   227  		r = state.zlib
   228  	}
   229  
   230  	return state.readAll(r)
   231  }
   232  
   233  // readAll reads bytes into an existing buffer, copy it over, then wipe the old
   234  // buffer.
   235  func (state *loopState) readAll(r io.Reader) ([]byte, error) {
   236  	defer state.buf.Reset()
   237  
   238  	if _, err := state.buf.ReadFrom(r); err != nil {
   239  		return nil, err
   240  	}
   241  
   242  	// Copy the bytes so we could empty the buffer for reuse.
   243  	cpy := make([]byte, state.buf.Len())
   244  	copy(cpy, state.buf.Bytes())
   245  
   246  	// If the buffer's capacity is over the limit, then re-allocate a new one.
   247  	if state.buf.Cap() > MaxCapUntilReset {
   248  		state.buf = bytes.Buffer{}
   249  		state.buf.Grow(CopyBufferSize)
   250  	}
   251  
   252  	return cpy, nil
   253  }