github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/read.go (about) 1 // +build !js 2 3 package websocket 4 5 import ( 6 "bufio" 7 "context" 8 "errors" 9 "fmt" 10 "io" 11 "io/ioutil" 12 "strings" 13 "time" 14 15 "nhooyr.io/websocket/internal/errd" 16 "nhooyr.io/websocket/internal/xsync" 17 ) 18 19 // Reader reads from the connection until until there is a WebSocket 20 // data message to be read. It will handle ping, pong and close frames as appropriate. 21 // 22 // It returns the type of the message and an io.Reader to read it. 23 // The passed context will also bound the reader. 24 // Ensure you read to EOF otherwise the connection will hang. 25 // 26 // Call CloseRead if you do not expect any data messages from the peer. 27 // 28 // Only one Reader may be open at a time. 29 func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { 30 return c.reader(ctx) 31 } 32 33 // Read is a convenience method around Reader to read a single message 34 // from the connection. 35 func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { 36 typ, r, err := c.Reader(ctx) 37 if err != nil { 38 return 0, nil, err 39 } 40 41 b, err := ioutil.ReadAll(r) 42 return typ, b, err 43 } 44 45 // CloseRead starts a goroutine to read from the connection until it is closed 46 // or a data message is received. 47 // 48 // Once CloseRead is called you cannot read any messages from the connection. 49 // The returned context will be cancelled when the connection is closed. 50 // 51 // If a data message is received, the connection will be closed with StatusPolicyViolation. 52 // 53 // Call CloseRead when you do not expect to read any more messages. 54 // Since it actively reads from the connection, it will ensure that ping, pong and close 55 // frames are responded to. This means c.Ping and c.Close will still work as expected. 56 func (c *Conn) CloseRead(ctx context.Context) context.Context { 57 ctx, cancel := context.WithCancel(ctx) 58 go func() { 59 defer cancel() 60 c.Reader(ctx) 61 c.Close(StatusPolicyViolation, "unexpected data message") 62 }() 63 return ctx 64 } 65 66 // SetReadLimit sets the max number of bytes to read for a single message. 67 // It applies to the Reader and Read methods. 68 // 69 // By default, the connection has a message read limit of 32768 bytes. 70 // 71 // When the limit is hit, the connection will be closed with StatusMessageTooBig. 72 func (c *Conn) SetReadLimit(n int64) { 73 // We add read one more byte than the limit in case 74 // there is a fin frame that needs to be read. 75 c.msgReader.limitReader.limit.Store(n + 1) 76 } 77 78 const defaultReadLimit = 32768 79 80 func newMsgReader(c *Conn) *msgReader { 81 mr := &msgReader{ 82 c: c, 83 fin: true, 84 } 85 mr.readFunc = mr.read 86 87 mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) 88 return mr 89 } 90 91 func (mr *msgReader) resetFlate() { 92 if mr.flateContextTakeover() { 93 mr.dict.init(32768) 94 } 95 if mr.flateBufio == nil { 96 mr.flateBufio = getBufioReader(mr.readFunc) 97 } 98 99 mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) 100 mr.limitReader.r = mr.flateReader 101 mr.flateTail.Reset(deflateMessageTail) 102 } 103 104 func (mr *msgReader) putFlateReader() { 105 if mr.flateReader != nil { 106 putFlateReader(mr.flateReader) 107 mr.flateReader = nil 108 } 109 } 110 111 func (mr *msgReader) close() { 112 mr.c.readMu.forceLock() 113 mr.putFlateReader() 114 mr.dict.close() 115 if mr.flateBufio != nil { 116 putBufioReader(mr.flateBufio) 117 } 118 119 if mr.c.client { 120 putBufioReader(mr.c.br) 121 mr.c.br = nil 122 } 123 } 124 125 func (mr *msgReader) flateContextTakeover() bool { 126 if mr.c.client { 127 return !mr.c.copts.serverNoContextTakeover 128 } 129 return !mr.c.copts.clientNoContextTakeover 130 } 131 132 func (c *Conn) readRSV1Illegal(h header) bool { 133 // If compression is disabled, rsv1 is illegal. 134 if !c.flate() { 135 return true 136 } 137 // rsv1 is only allowed on data frames beginning messages. 138 if h.opcode != opText && h.opcode != opBinary { 139 return true 140 } 141 return false 142 } 143 144 func (c *Conn) readLoop(ctx context.Context) (header, error) { 145 for { 146 h, err := c.readFrameHeader(ctx) 147 if err != nil { 148 return header{}, err 149 } 150 151 if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { 152 err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) 153 c.writeError(StatusProtocolError, err) 154 return header{}, err 155 } 156 157 if !c.client && !h.masked { 158 return header{}, errors.New("received unmasked frame from client") 159 } 160 161 switch h.opcode { 162 case opClose, opPing, opPong: 163 err = c.handleControl(ctx, h) 164 if err != nil { 165 // Pass through CloseErrors when receiving a close frame. 166 if h.opcode == opClose && CloseStatus(err) != -1 { 167 return header{}, err 168 } 169 return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) 170 } 171 case opContinuation, opText, opBinary: 172 return h, nil 173 default: 174 err := fmt.Errorf("received unknown opcode %v", h.opcode) 175 c.writeError(StatusProtocolError, err) 176 return header{}, err 177 } 178 } 179 } 180 181 func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { 182 select { 183 case <-c.closed: 184 return header{}, c.closeErr 185 case c.readTimeout <- ctx: 186 } 187 188 h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) 189 if err != nil { 190 select { 191 case <-c.closed: 192 return header{}, c.closeErr 193 case <-ctx.Done(): 194 return header{}, ctx.Err() 195 default: 196 c.close(err) 197 return header{}, err 198 } 199 } 200 201 select { 202 case <-c.closed: 203 return header{}, c.closeErr 204 case c.readTimeout <- context.Background(): 205 } 206 207 return h, nil 208 } 209 210 func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { 211 select { 212 case <-c.closed: 213 return 0, c.closeErr 214 case c.readTimeout <- ctx: 215 } 216 217 n, err := io.ReadFull(c.br, p) 218 if err != nil { 219 select { 220 case <-c.closed: 221 return n, c.closeErr 222 case <-ctx.Done(): 223 return n, ctx.Err() 224 default: 225 err = fmt.Errorf("failed to read frame payload: %w", err) 226 c.close(err) 227 return n, err 228 } 229 } 230 231 select { 232 case <-c.closed: 233 return n, c.closeErr 234 case c.readTimeout <- context.Background(): 235 } 236 237 return n, err 238 } 239 240 func (c *Conn) handleControl(ctx context.Context, h header) (err error) { 241 if h.payloadLength < 0 || h.payloadLength > maxControlPayload { 242 err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) 243 c.writeError(StatusProtocolError, err) 244 return err 245 } 246 247 if !h.fin { 248 err := errors.New("received fragmented control frame") 249 c.writeError(StatusProtocolError, err) 250 return err 251 } 252 253 ctx, cancel := context.WithTimeout(ctx, time.Second*5) 254 defer cancel() 255 256 b := c.readControlBuf[:h.payloadLength] 257 _, err = c.readFramePayload(ctx, b) 258 if err != nil { 259 return err 260 } 261 262 if h.masked { 263 mask(h.maskKey, b) 264 } 265 266 switch h.opcode { 267 case opPing: 268 return c.writeControl(ctx, opPong, b) 269 case opPong: 270 c.activePingsMu.Lock() 271 pong, ok := c.activePings[string(b)] 272 c.activePingsMu.Unlock() 273 if ok { 274 select { 275 case pong <- struct{}{}: 276 default: 277 } 278 } 279 return nil 280 } 281 282 defer func() { 283 c.readCloseFrameErr = err 284 }() 285 286 ce, err := parseClosePayload(b) 287 if err != nil { 288 err = fmt.Errorf("received invalid close payload: %w", err) 289 c.writeError(StatusProtocolError, err) 290 return err 291 } 292 293 err = fmt.Errorf("received close frame: %w", ce) 294 c.setCloseErr(err) 295 c.writeClose(ce.Code, ce.Reason) 296 c.close(err) 297 return err 298 } 299 300 func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { 301 defer errd.Wrap(&err, "failed to get reader") 302 303 err = c.readMu.lock(ctx) 304 if err != nil { 305 return 0, nil, err 306 } 307 defer c.readMu.unlock() 308 309 if !c.msgReader.fin { 310 err = errors.New("previous message not read to completion") 311 c.close(fmt.Errorf("failed to get reader: %w", err)) 312 return 0, nil, err 313 } 314 315 h, err := c.readLoop(ctx) 316 if err != nil { 317 return 0, nil, err 318 } 319 320 if h.opcode == opContinuation { 321 err := errors.New("received continuation frame without text or binary frame") 322 c.writeError(StatusProtocolError, err) 323 return 0, nil, err 324 } 325 326 c.msgReader.reset(ctx, h) 327 328 return MessageType(h.opcode), c.msgReader, nil 329 } 330 331 type msgReader struct { 332 c *Conn 333 334 ctx context.Context 335 flate bool 336 flateReader io.Reader 337 flateBufio *bufio.Reader 338 flateTail strings.Reader 339 limitReader *limitReader 340 dict slidingWindow 341 342 fin bool 343 payloadLength int64 344 maskKey uint32 345 346 // readerFunc(mr.Read) to avoid continuous allocations. 347 readFunc readerFunc 348 } 349 350 func (mr *msgReader) reset(ctx context.Context, h header) { 351 mr.ctx = ctx 352 mr.flate = h.rsv1 353 mr.limitReader.reset(mr.readFunc) 354 355 if mr.flate { 356 mr.resetFlate() 357 } 358 359 mr.setFrame(h) 360 } 361 362 func (mr *msgReader) setFrame(h header) { 363 mr.fin = h.fin 364 mr.payloadLength = h.payloadLength 365 mr.maskKey = h.maskKey 366 } 367 368 func (mr *msgReader) Read(p []byte) (n int, err error) { 369 err = mr.c.readMu.lock(mr.ctx) 370 if err != nil { 371 return 0, fmt.Errorf("failed to read: %w", err) 372 } 373 defer mr.c.readMu.unlock() 374 375 n, err = mr.limitReader.Read(p) 376 if mr.flate && mr.flateContextTakeover() { 377 p = p[:n] 378 mr.dict.write(p) 379 } 380 if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { 381 mr.putFlateReader() 382 return n, io.EOF 383 } 384 if err != nil { 385 err = fmt.Errorf("failed to read: %w", err) 386 mr.c.close(err) 387 } 388 return n, err 389 } 390 391 func (mr *msgReader) read(p []byte) (int, error) { 392 for { 393 if mr.payloadLength == 0 { 394 if mr.fin { 395 if mr.flate { 396 return mr.flateTail.Read(p) 397 } 398 return 0, io.EOF 399 } 400 401 h, err := mr.c.readLoop(mr.ctx) 402 if err != nil { 403 return 0, err 404 } 405 if h.opcode != opContinuation { 406 err := errors.New("received new data message without finishing the previous message") 407 mr.c.writeError(StatusProtocolError, err) 408 return 0, err 409 } 410 mr.setFrame(h) 411 412 continue 413 } 414 415 if int64(len(p)) > mr.payloadLength { 416 p = p[:mr.payloadLength] 417 } 418 419 n, err := mr.c.readFramePayload(mr.ctx, p) 420 if err != nil { 421 return n, err 422 } 423 424 mr.payloadLength -= int64(n) 425 426 if !mr.c.client { 427 mr.maskKey = mask(mr.maskKey, p) 428 } 429 430 return n, nil 431 } 432 } 433 434 type limitReader struct { 435 c *Conn 436 r io.Reader 437 limit xsync.Int64 438 n int64 439 } 440 441 func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { 442 lr := &limitReader{ 443 c: c, 444 } 445 lr.limit.Store(limit) 446 lr.reset(r) 447 return lr 448 } 449 450 func (lr *limitReader) reset(r io.Reader) { 451 lr.n = lr.limit.Load() 452 lr.r = r 453 } 454 455 func (lr *limitReader) Read(p []byte) (int, error) { 456 if lr.n <= 0 { 457 err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) 458 lr.c.writeError(StatusMessageTooBig, err) 459 return 0, err 460 } 461 462 if int64(len(p)) > lr.n { 463 p = p[:lr.n] 464 } 465 n, err := lr.r.Read(p) 466 lr.n -= int64(n) 467 return n, err 468 } 469 470 type readerFunc func(p []byte) (int, error) 471 472 func (f readerFunc) Read(p []byte) (int, error) { 473 return f(p) 474 }