github.com/diamondburned/arikawa@v1.3.14/utils/wsutil/conn.go (about) 1 package wsutil 2 3 import ( 4 "bytes" 5 "compress/zlib" 6 "context" 7 "io" 8 "net/http" 9 "strings" 10 "time" 11 12 "github.com/gorilla/websocket" 13 "github.com/pkg/errors" 14 ) 15 16 // CopyBufferSize is used for the initial size of the internal WS' buffer. Its 17 // size is 4KB. 18 var CopyBufferSize = 4096 19 20 // MaxCapUntilReset determines the maximum capacity before the bytes buffer is 21 // re-allocated. It is roughly 16KB, quadruple CopyBufferSize. 22 var MaxCapUntilReset = CopyBufferSize * 4 23 24 // CloseDeadline controls the deadline to wait for sending the Close frame. 25 var CloseDeadline = time.Second 26 27 // ErrWebsocketClosed is returned if the websocket is already closed. 28 var ErrWebsocketClosed = errors.New("websocket is closed") 29 30 // Connection is an interface that abstracts around a generic Websocket driver. 31 // This connection expects the driver to handle compression by itself, including 32 // modifying the connection URL. The implementation doesn't have to be safe for 33 // concurrent use. 34 type Connection interface { 35 // Dial dials the address (string). Context needs to be passed in for 36 // timeout. This method should also be re-usable after Close is called. 37 Dial(context.Context, string) error 38 39 // Listen returns an event channel that sends over events constantly. It can 40 // return nil if there isn't an ongoing connection. 41 Listen() <-chan Event 42 43 // Send allows the caller to send bytes. It does not need to clean itself 44 // up on errors, as the Websocket wrapper will do that. 45 Send(context.Context, []byte) error 46 47 // Close should close the websocket connection. The underlying connection 48 // may be reused, but this Connection instance will be reused with Dial. The 49 // Connection must still be reusable even if Close returns an error. 50 Close() error 51 } 52 53 // Conn is the default Websocket connection. It tries to compresses all payloads 54 // using zlib. 55 type Conn struct { 56 Dialer websocket.Dialer 57 Header http.Header 58 Conn *websocket.Conn 59 events chan Event 60 } 61 62 var _ Connection = (*Conn)(nil) 63 64 // NewConn creates a new default websocket connection with a default dialer. 65 func NewConn() *Conn { 66 return NewConnWithDialer(websocket.Dialer{ 67 Proxy: http.ProxyFromEnvironment, 68 HandshakeTimeout: WSTimeout, 69 ReadBufferSize: CopyBufferSize, 70 WriteBufferSize: CopyBufferSize, 71 EnableCompression: true, 72 }) 73 } 74 75 // NewConn creates a new default websocket connection with a custom dialer. 76 func NewConnWithDialer(dialer websocket.Dialer) *Conn { 77 return &Conn{ 78 Dialer: dialer, 79 Header: http.Header{ 80 "Accept-Encoding": {"zlib"}, 81 }, 82 } 83 } 84 85 func (c *Conn) Dial(ctx context.Context, addr string) (err error) { 86 // BUG which prevents stream compression. 87 // See https://github.com/golang/go/issues/31514. 88 89 c.Conn, _, err = c.Dialer.DialContext(ctx, addr, c.Header) 90 if err != nil { 91 return errors.Wrap(err, "failed to dial WS") 92 } 93 94 // Reset the deadline. 95 c.Conn.SetWriteDeadline(resetDeadline) 96 97 c.events = make(chan Event, WSBuffer) 98 go startReadLoop(c.Conn, c.events) 99 100 return err 101 } 102 103 // Listen returns an event channel if there is a connection associated with it. 104 // It returns nil if there is none. 105 func (c *Conn) Listen() <-chan Event { 106 return c.events 107 } 108 109 // resetDeadline is used to reset the write deadline after using the context's. 110 var resetDeadline = time.Time{} 111 112 func (c *Conn) Send(ctx context.Context, b []byte) error { 113 d, ok := ctx.Deadline() 114 if ok { 115 c.Conn.SetWriteDeadline(d) 116 defer c.Conn.SetWriteDeadline(resetDeadline) 117 } 118 119 if err := c.Conn.WriteMessage(websocket.TextMessage, b); err != nil { 120 return err 121 } 122 123 return nil 124 } 125 126 func (c *Conn) Close() error { 127 WSDebug("Conn: Close is called; shutting down the Websocket connection.") 128 129 // Have a deadline before closing. 130 var deadline = time.Now().Add(5 * time.Second) 131 c.Conn.SetWriteDeadline(deadline) 132 133 // Close the WS. 134 err := c.Conn.Close() 135 136 c.Conn.SetWriteDeadline(resetDeadline) 137 138 WSDebug("Conn: Websocket closed; error:", err) 139 WSDebug("Conn: Flusing events...") 140 141 // Flush all events before closing the channel. This will return as soon as 142 // c.events is closed, or after closed. 143 for range c.events { 144 } 145 146 WSDebug("Flushed events.") 147 148 return err 149 } 150 151 // loopState is a thread-unsafe disposable state container for the read loop. 152 // It's made to completely separate the read loop of any synchronization that 153 // doesn't involve the websocket connection itself. 154 type loopState struct { 155 conn *websocket.Conn 156 zlib io.ReadCloser 157 buf bytes.Buffer 158 } 159 160 func startReadLoop(conn *websocket.Conn, eventCh chan<- Event) { 161 // Clean up the events channel in the end. 162 defer close(eventCh) 163 164 // Allocate the read loop its own private resources. 165 state := loopState{conn: conn} 166 state.buf.Grow(CopyBufferSize) 167 168 for { 169 b, err := state.handle() 170 if err != nil { 171 WSDebug("Conn: Read error:", err) 172 173 // Is the error an EOF? 174 if errors.Is(err, io.EOF) { 175 // Yes it is, exit. 176 return 177 } 178 179 // Is the error an intentional close call? Go 1.16 exposes 180 // ErrClosing, but we have to do this for now. 181 if strings.HasSuffix(err.Error(), "use of closed network connection") { 182 return 183 } 184 185 // Check if the error is a normal one: 186 if websocket.IsCloseError(err, websocket.CloseNormalClosure) { 187 return 188 } 189 190 // Unusual error; log and exit: 191 eventCh <- Event{nil, errors.Wrap(err, "WS error")} 192 return 193 } 194 195 // If the payload length is 0, skip it. 196 if len(b) == 0 { 197 continue 198 } 199 200 eventCh <- Event{b, nil} 201 } 202 } 203 204 func (state *loopState) handle() ([]byte, error) { 205 // skip message type 206 t, r, err := state.conn.NextReader() 207 if err != nil { 208 return nil, err 209 } 210 211 if t == websocket.BinaryMessage { 212 // Probably a zlib payload. 213 214 if state.zlib == nil { 215 z, err := zlib.NewReader(r) 216 if err != nil { 217 return nil, errors.Wrap(err, "failed to create a zlib reader") 218 } 219 state.zlib = z 220 } else { 221 if err := state.zlib.(zlib.Resetter).Reset(r, nil); err != nil { 222 return nil, errors.Wrap(err, "failed to reset zlib reader") 223 } 224 } 225 226 defer state.zlib.Close() 227 r = state.zlib 228 } 229 230 return state.readAll(r) 231 } 232 233 // readAll reads bytes into an existing buffer, copy it over, then wipe the old 234 // buffer. 235 func (state *loopState) readAll(r io.Reader) ([]byte, error) { 236 defer state.buf.Reset() 237 238 if _, err := state.buf.ReadFrom(r); err != nil { 239 return nil, err 240 } 241 242 // Copy the bytes so we could empty the buffer for reuse. 243 cpy := make([]byte, state.buf.Len()) 244 copy(cpy, state.buf.Bytes()) 245 246 // If the buffer's capacity is over the limit, then re-allocate a new one. 247 if state.buf.Cap() > MaxCapUntilReset { 248 state.buf = bytes.Buffer{} 249 state.buf.Grow(CopyBufferSize) 250 } 251 252 return cpy, nil 253 }