github.com/diamondburned/arikawa/v2@v2.1.0/utils/wsutil/ws.go (about)

     1  // Package wsutil provides abstractions around the Websocket, including rate
     2  // limits.
     3  package wsutil
     4  
     5  import (
     6  	"context"
     7  	"log"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/pkg/errors"
    12  	"golang.org/x/time/rate"
    13  )
    14  
    15  var (
    16  	// WSTimeout is the timeout for connecting and writing to the Websocket,
    17  	// before Gateway cancels and fails.
    18  	WSTimeout = 30 * time.Second
    19  	// WSBuffer is the size of the Event channel. This has to be at least 1 to
    20  	// make space for the first Event: Ready or Resumed.
    21  	WSBuffer = 10
    22  	// WSError is the default error handler
    23  	WSError = func(err error) { log.Println("Gateway error:", err) }
    24  	// WSDebug is used for extra debug logging. This is expected to behave
    25  	// similarly to log.Println().
    26  	WSDebug = func(v ...interface{}) {}
    27  )
    28  
    29  type Event struct {
    30  	Data []byte
    31  
    32  	// Error is non-nil if Data is nil.
    33  	Error error
    34  }
    35  
    36  // Websocket is a wrapper around a websocket Conn with thread safety and rate
    37  // limiting for sending and throttling.
    38  type Websocket struct {
    39  	mutex  sync.Mutex
    40  	conn   Connection
    41  	addr   string
    42  	closed bool
    43  
    44  	sendLimiter *rate.Limiter
    45  	dialLimiter *rate.Limiter
    46  
    47  	// Constants. These must not be changed after the Websocket instance is used
    48  	// once, as they are not thread-safe.
    49  
    50  	// Timeout for connecting and writing to the Websocket, uses default
    51  	// WSTimeout (global).
    52  	Timeout time.Duration
    53  }
    54  
    55  // New creates a default Websocket with the given address.
    56  func New(addr string) *Websocket {
    57  	return NewCustom(NewConn(), addr)
    58  }
    59  
    60  // NewCustom creates a new undialed Websocket.
    61  func NewCustom(conn Connection, addr string) *Websocket {
    62  	return &Websocket{
    63  		conn:   conn,
    64  		addr:   addr,
    65  		closed: true,
    66  
    67  		sendLimiter: NewSendLimiter(),
    68  		dialLimiter: NewDialLimiter(),
    69  
    70  		Timeout: WSTimeout,
    71  	}
    72  }
    73  
    74  // Dial waits until the rate limiter allows then dials the websocket.
    75  func (ws *Websocket) Dial(ctx context.Context) error {
    76  	if ws.Timeout > 0 {
    77  		tctx, cancel := context.WithTimeout(ctx, ws.Timeout)
    78  		defer cancel()
    79  
    80  		ctx = tctx
    81  	}
    82  
    83  	if err := ws.dialLimiter.Wait(ctx); err != nil {
    84  		// Expired, fatal error
    85  		return errors.Wrap(err, "failed to wait")
    86  	}
    87  
    88  	ws.mutex.Lock()
    89  	defer ws.mutex.Unlock()
    90  
    91  	if !ws.closed {
    92  		WSDebug("Old connection not yet closed while dialog; closing it.")
    93  		ws.conn.Close()
    94  	}
    95  
    96  	if err := ws.conn.Dial(ctx, ws.addr); err != nil {
    97  		return errors.Wrap(err, "failed to dial")
    98  	}
    99  
   100  	ws.closed = false
   101  
   102  	// Reset the send limiter.
   103  	ws.sendLimiter = NewSendLimiter()
   104  
   105  	return nil
   106  }
   107  
   108  // Listen returns the inner event channel or nil if the Websocket connection is
   109  // not alive.
   110  func (ws *Websocket) Listen() <-chan Event {
   111  	ws.mutex.Lock()
   112  	defer ws.mutex.Unlock()
   113  
   114  	if ws.closed {
   115  		return nil
   116  	}
   117  
   118  	return ws.conn.Listen()
   119  }
   120  
   121  // Send sends b over the Websocket without a timeout.
   122  func (ws *Websocket) Send(b []byte) error {
   123  	return ws.SendCtx(context.Background(), b)
   124  }
   125  
   126  // SendCtx sends b over the Websocket with a deadline. It closes the internal
   127  // Websocket if the Send method errors out.
   128  func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error {
   129  	WSDebug("Waiting for the send rate limiter...")
   130  
   131  	if err := ws.sendLimiter.Wait(ctx); err != nil {
   132  		WSDebug("Send rate limiter timed out.")
   133  		return errors.Wrap(err, "SendLimiter failed")
   134  	}
   135  
   136  	WSDebug("Send is passed the rate limiting. Waiting on mutex.")
   137  
   138  	ws.mutex.Lock()
   139  	defer ws.mutex.Unlock()
   140  
   141  	WSDebug("Mutex lock acquired.")
   142  
   143  	if ws.closed {
   144  		return ErrWebsocketClosed
   145  	}
   146  
   147  	if err := ws.conn.Send(ctx, b); err != nil {
   148  		// We need to clean up ourselves if things are erroring out.
   149  		WSDebug("Conn: Error while sending; closing the connection. Error:", err)
   150  		ws.close(false)
   151  		return err
   152  	}
   153  
   154  	return nil
   155  }
   156  
   157  // Close closes the websocket connection. It assumes that the Websocket is
   158  // closed even when it returns an error. If the Websocket was already closed
   159  // before, ErrWebsocketClosed will be returned.
   160  func (ws *Websocket) Close() error { return ws.close(false) }
   161  
   162  func (ws *Websocket) CloseGracefully() error { return ws.close(true) }
   163  
   164  // close closes the Websocket without acquiring the mutex. Refer to Close for
   165  // more information.
   166  func (ws *Websocket) close(graceful bool) error {
   167  	WSDebug("Conn: Acquiring mutex lock to close...")
   168  
   169  	ws.mutex.Lock()
   170  	defer ws.mutex.Unlock()
   171  
   172  	WSDebug("Conn: Write mutex acquired")
   173  
   174  	if ws.closed {
   175  		WSDebug("Conn: Websocket is already closed.")
   176  		return ErrWebsocketClosed
   177  	}
   178  
   179  	ws.closed = true
   180  
   181  	if graceful {
   182  		if gc, ok := ws.conn.(GracefulCloser); ok {
   183  			WSDebug("Conn: Closing gracefully")
   184  			return gc.CloseGracefully()
   185  		}
   186  
   187  		WSDebug("Conn: The Websocket's Connection does not support graceful closure.")
   188  	}
   189  
   190  	WSDebug("Conn: Closing")
   191  	return ws.conn.Close()
   192  }