github.com/diamondburned/arikawa/v2@v2.1.0/utils/wsutil/ws.go (about) 1 // Package wsutil provides abstractions around the Websocket, including rate 2 // limits. 3 package wsutil 4 5 import ( 6 "context" 7 "log" 8 "sync" 9 "time" 10 11 "github.com/pkg/errors" 12 "golang.org/x/time/rate" 13 ) 14 15 var ( 16 // WSTimeout is the timeout for connecting and writing to the Websocket, 17 // before Gateway cancels and fails. 18 WSTimeout = 30 * time.Second 19 // WSBuffer is the size of the Event channel. This has to be at least 1 to 20 // make space for the first Event: Ready or Resumed. 21 WSBuffer = 10 22 // WSError is the default error handler 23 WSError = func(err error) { log.Println("Gateway error:", err) } 24 // WSDebug is used for extra debug logging. This is expected to behave 25 // similarly to log.Println(). 26 WSDebug = func(v ...interface{}) {} 27 ) 28 29 type Event struct { 30 Data []byte 31 32 // Error is non-nil if Data is nil. 33 Error error 34 } 35 36 // Websocket is a wrapper around a websocket Conn with thread safety and rate 37 // limiting for sending and throttling. 38 type Websocket struct { 39 mutex sync.Mutex 40 conn Connection 41 addr string 42 closed bool 43 44 sendLimiter *rate.Limiter 45 dialLimiter *rate.Limiter 46 47 // Constants. These must not be changed after the Websocket instance is used 48 // once, as they are not thread-safe. 49 50 // Timeout for connecting and writing to the Websocket, uses default 51 // WSTimeout (global). 52 Timeout time.Duration 53 } 54 55 // New creates a default Websocket with the given address. 56 func New(addr string) *Websocket { 57 return NewCustom(NewConn(), addr) 58 } 59 60 // NewCustom creates a new undialed Websocket. 61 func NewCustom(conn Connection, addr string) *Websocket { 62 return &Websocket{ 63 conn: conn, 64 addr: addr, 65 closed: true, 66 67 sendLimiter: NewSendLimiter(), 68 dialLimiter: NewDialLimiter(), 69 70 Timeout: WSTimeout, 71 } 72 } 73 74 // Dial waits until the rate limiter allows then dials the websocket. 75 func (ws *Websocket) Dial(ctx context.Context) error { 76 if ws.Timeout > 0 { 77 tctx, cancel := context.WithTimeout(ctx, ws.Timeout) 78 defer cancel() 79 80 ctx = tctx 81 } 82 83 if err := ws.dialLimiter.Wait(ctx); err != nil { 84 // Expired, fatal error 85 return errors.Wrap(err, "failed to wait") 86 } 87 88 ws.mutex.Lock() 89 defer ws.mutex.Unlock() 90 91 if !ws.closed { 92 WSDebug("Old connection not yet closed while dialog; closing it.") 93 ws.conn.Close() 94 } 95 96 if err := ws.conn.Dial(ctx, ws.addr); err != nil { 97 return errors.Wrap(err, "failed to dial") 98 } 99 100 ws.closed = false 101 102 // Reset the send limiter. 103 ws.sendLimiter = NewSendLimiter() 104 105 return nil 106 } 107 108 // Listen returns the inner event channel or nil if the Websocket connection is 109 // not alive. 110 func (ws *Websocket) Listen() <-chan Event { 111 ws.mutex.Lock() 112 defer ws.mutex.Unlock() 113 114 if ws.closed { 115 return nil 116 } 117 118 return ws.conn.Listen() 119 } 120 121 // Send sends b over the Websocket without a timeout. 122 func (ws *Websocket) Send(b []byte) error { 123 return ws.SendCtx(context.Background(), b) 124 } 125 126 // SendCtx sends b over the Websocket with a deadline. It closes the internal 127 // Websocket if the Send method errors out. 128 func (ws *Websocket) SendCtx(ctx context.Context, b []byte) error { 129 WSDebug("Waiting for the send rate limiter...") 130 131 if err := ws.sendLimiter.Wait(ctx); err != nil { 132 WSDebug("Send rate limiter timed out.") 133 return errors.Wrap(err, "SendLimiter failed") 134 } 135 136 WSDebug("Send is passed the rate limiting. Waiting on mutex.") 137 138 ws.mutex.Lock() 139 defer ws.mutex.Unlock() 140 141 WSDebug("Mutex lock acquired.") 142 143 if ws.closed { 144 return ErrWebsocketClosed 145 } 146 147 if err := ws.conn.Send(ctx, b); err != nil { 148 // We need to clean up ourselves if things are erroring out. 149 WSDebug("Conn: Error while sending; closing the connection. Error:", err) 150 ws.close(false) 151 return err 152 } 153 154 return nil 155 } 156 157 // Close closes the websocket connection. It assumes that the Websocket is 158 // closed even when it returns an error. If the Websocket was already closed 159 // before, ErrWebsocketClosed will be returned. 160 func (ws *Websocket) Close() error { return ws.close(false) } 161 162 func (ws *Websocket) CloseGracefully() error { return ws.close(true) } 163 164 // close closes the Websocket without acquiring the mutex. Refer to Close for 165 // more information. 166 func (ws *Websocket) close(graceful bool) error { 167 WSDebug("Conn: Acquiring mutex lock to close...") 168 169 ws.mutex.Lock() 170 defer ws.mutex.Unlock() 171 172 WSDebug("Conn: Write mutex acquired") 173 174 if ws.closed { 175 WSDebug("Conn: Websocket is already closed.") 176 return ErrWebsocketClosed 177 } 178 179 ws.closed = true 180 181 if graceful { 182 if gc, ok := ws.conn.(GracefulCloser); ok { 183 WSDebug("Conn: Closing gracefully") 184 return gc.CloseGracefully() 185 } 186 187 WSDebug("Conn: The Websocket's Connection does not support graceful closure.") 188 } 189 190 WSDebug("Conn: Closing") 191 return ws.conn.Close() 192 }