github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/read.go (about)

     1  // +build !js
     2  
     3  package websocket
     4  
     5  import (
     6  	"bufio"
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"io/ioutil"
    12  	"strings"
    13  	"time"
    14  
    15  	"nhooyr.io/websocket/internal/errd"
    16  	"nhooyr.io/websocket/internal/xsync"
    17  )
    18  
    19  // Reader reads from the connection until until there is a WebSocket
    20  // data message to be read. It will handle ping, pong and close frames as appropriate.
    21  //
    22  // It returns the type of the message and an io.Reader to read it.
    23  // The passed context will also bound the reader.
    24  // Ensure you read to EOF otherwise the connection will hang.
    25  //
    26  // Call CloseRead if you do not expect any data messages from the peer.
    27  //
    28  // Only one Reader may be open at a time.
    29  func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
    30  	return c.reader(ctx)
    31  }
    32  
    33  // Read is a convenience method around Reader to read a single message
    34  // from the connection.
    35  func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
    36  	typ, r, err := c.Reader(ctx)
    37  	if err != nil {
    38  		return 0, nil, err
    39  	}
    40  
    41  	b, err := ioutil.ReadAll(r)
    42  	return typ, b, err
    43  }
    44  
    45  // CloseRead starts a goroutine to read from the connection until it is closed
    46  // or a data message is received.
    47  //
    48  // Once CloseRead is called you cannot read any messages from the connection.
    49  // The returned context will be cancelled when the connection is closed.
    50  //
    51  // If a data message is received, the connection will be closed with StatusPolicyViolation.
    52  //
    53  // Call CloseRead when you do not expect to read any more messages.
    54  // Since it actively reads from the connection, it will ensure that ping, pong and close
    55  // frames are responded to. This means c.Ping and c.Close will still work as expected.
    56  func (c *Conn) CloseRead(ctx context.Context) context.Context {
    57  	ctx, cancel := context.WithCancel(ctx)
    58  	go func() {
    59  		defer cancel()
    60  		c.Reader(ctx)
    61  		c.Close(StatusPolicyViolation, "unexpected data message")
    62  	}()
    63  	return ctx
    64  }
    65  
    66  // SetReadLimit sets the max number of bytes to read for a single message.
    67  // It applies to the Reader and Read methods.
    68  //
    69  // By default, the connection has a message read limit of 32768 bytes.
    70  //
    71  // When the limit is hit, the connection will be closed with StatusMessageTooBig.
    72  func (c *Conn) SetReadLimit(n int64) {
    73  	// We add read one more byte than the limit in case
    74  	// there is a fin frame that needs to be read.
    75  	c.msgReader.limitReader.limit.Store(n + 1)
    76  }
    77  
    78  const defaultReadLimit = 32768
    79  
    80  func newMsgReader(c *Conn) *msgReader {
    81  	mr := &msgReader{
    82  		c:   c,
    83  		fin: true,
    84  	}
    85  	mr.readFunc = mr.read
    86  
    87  	mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
    88  	return mr
    89  }
    90  
    91  func (mr *msgReader) resetFlate() {
    92  	if mr.flateContextTakeover() {
    93  		mr.dict.init(32768)
    94  	}
    95  	if mr.flateBufio == nil {
    96  		mr.flateBufio = getBufioReader(mr.readFunc)
    97  	}
    98  
    99  	mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
   100  	mr.limitReader.r = mr.flateReader
   101  	mr.flateTail.Reset(deflateMessageTail)
   102  }
   103  
   104  func (mr *msgReader) putFlateReader() {
   105  	if mr.flateReader != nil {
   106  		putFlateReader(mr.flateReader)
   107  		mr.flateReader = nil
   108  	}
   109  }
   110  
   111  func (mr *msgReader) close() {
   112  	mr.c.readMu.forceLock()
   113  	mr.putFlateReader()
   114  	mr.dict.close()
   115  	if mr.flateBufio != nil {
   116  		putBufioReader(mr.flateBufio)
   117  	}
   118  
   119  	if mr.c.client {
   120  		putBufioReader(mr.c.br)
   121  		mr.c.br = nil
   122  	}
   123  }
   124  
   125  func (mr *msgReader) flateContextTakeover() bool {
   126  	if mr.c.client {
   127  		return !mr.c.copts.serverNoContextTakeover
   128  	}
   129  	return !mr.c.copts.clientNoContextTakeover
   130  }
   131  
   132  func (c *Conn) readRSV1Illegal(h header) bool {
   133  	// If compression is disabled, rsv1 is illegal.
   134  	if !c.flate() {
   135  		return true
   136  	}
   137  	// rsv1 is only allowed on data frames beginning messages.
   138  	if h.opcode != opText && h.opcode != opBinary {
   139  		return true
   140  	}
   141  	return false
   142  }
   143  
   144  func (c *Conn) readLoop(ctx context.Context) (header, error) {
   145  	for {
   146  		h, err := c.readFrameHeader(ctx)
   147  		if err != nil {
   148  			return header{}, err
   149  		}
   150  
   151  		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
   152  			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
   153  			c.writeError(StatusProtocolError, err)
   154  			return header{}, err
   155  		}
   156  
   157  		if !c.client && !h.masked {
   158  			return header{}, errors.New("received unmasked frame from client")
   159  		}
   160  
   161  		switch h.opcode {
   162  		case opClose, opPing, opPong:
   163  			err = c.handleControl(ctx, h)
   164  			if err != nil {
   165  				// Pass through CloseErrors when receiving a close frame.
   166  				if h.opcode == opClose && CloseStatus(err) != -1 {
   167  					return header{}, err
   168  				}
   169  				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
   170  			}
   171  		case opContinuation, opText, opBinary:
   172  			return h, nil
   173  		default:
   174  			err := fmt.Errorf("received unknown opcode %v", h.opcode)
   175  			c.writeError(StatusProtocolError, err)
   176  			return header{}, err
   177  		}
   178  	}
   179  }
   180  
   181  func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
   182  	select {
   183  	case <-c.closed:
   184  		return header{}, c.closeErr
   185  	case c.readTimeout <- ctx:
   186  	}
   187  
   188  	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
   189  	if err != nil {
   190  		select {
   191  		case <-c.closed:
   192  			return header{}, c.closeErr
   193  		case <-ctx.Done():
   194  			return header{}, ctx.Err()
   195  		default:
   196  			c.close(err)
   197  			return header{}, err
   198  		}
   199  	}
   200  
   201  	select {
   202  	case <-c.closed:
   203  		return header{}, c.closeErr
   204  	case c.readTimeout <- context.Background():
   205  	}
   206  
   207  	return h, nil
   208  }
   209  
   210  func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
   211  	select {
   212  	case <-c.closed:
   213  		return 0, c.closeErr
   214  	case c.readTimeout <- ctx:
   215  	}
   216  
   217  	n, err := io.ReadFull(c.br, p)
   218  	if err != nil {
   219  		select {
   220  		case <-c.closed:
   221  			return n, c.closeErr
   222  		case <-ctx.Done():
   223  			return n, ctx.Err()
   224  		default:
   225  			err = fmt.Errorf("failed to read frame payload: %w", err)
   226  			c.close(err)
   227  			return n, err
   228  		}
   229  	}
   230  
   231  	select {
   232  	case <-c.closed:
   233  		return n, c.closeErr
   234  	case c.readTimeout <- context.Background():
   235  	}
   236  
   237  	return n, err
   238  }
   239  
   240  func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
   241  	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
   242  		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
   243  		c.writeError(StatusProtocolError, err)
   244  		return err
   245  	}
   246  
   247  	if !h.fin {
   248  		err := errors.New("received fragmented control frame")
   249  		c.writeError(StatusProtocolError, err)
   250  		return err
   251  	}
   252  
   253  	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
   254  	defer cancel()
   255  
   256  	b := c.readControlBuf[:h.payloadLength]
   257  	_, err = c.readFramePayload(ctx, b)
   258  	if err != nil {
   259  		return err
   260  	}
   261  
   262  	if h.masked {
   263  		mask(h.maskKey, b)
   264  	}
   265  
   266  	switch h.opcode {
   267  	case opPing:
   268  		return c.writeControl(ctx, opPong, b)
   269  	case opPong:
   270  		c.activePingsMu.Lock()
   271  		pong, ok := c.activePings[string(b)]
   272  		c.activePingsMu.Unlock()
   273  		if ok {
   274  			select {
   275  			case pong <- struct{}{}:
   276  			default:
   277  			}
   278  		}
   279  		return nil
   280  	}
   281  
   282  	defer func() {
   283  		c.readCloseFrameErr = err
   284  	}()
   285  
   286  	ce, err := parseClosePayload(b)
   287  	if err != nil {
   288  		err = fmt.Errorf("received invalid close payload: %w", err)
   289  		c.writeError(StatusProtocolError, err)
   290  		return err
   291  	}
   292  
   293  	err = fmt.Errorf("received close frame: %w", ce)
   294  	c.setCloseErr(err)
   295  	c.writeClose(ce.Code, ce.Reason)
   296  	c.close(err)
   297  	return err
   298  }
   299  
   300  func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
   301  	defer errd.Wrap(&err, "failed to get reader")
   302  
   303  	err = c.readMu.lock(ctx)
   304  	if err != nil {
   305  		return 0, nil, err
   306  	}
   307  	defer c.readMu.unlock()
   308  
   309  	if !c.msgReader.fin {
   310  		err = errors.New("previous message not read to completion")
   311  		c.close(fmt.Errorf("failed to get reader: %w", err))
   312  		return 0, nil, err
   313  	}
   314  
   315  	h, err := c.readLoop(ctx)
   316  	if err != nil {
   317  		return 0, nil, err
   318  	}
   319  
   320  	if h.opcode == opContinuation {
   321  		err := errors.New("received continuation frame without text or binary frame")
   322  		c.writeError(StatusProtocolError, err)
   323  		return 0, nil, err
   324  	}
   325  
   326  	c.msgReader.reset(ctx, h)
   327  
   328  	return MessageType(h.opcode), c.msgReader, nil
   329  }
   330  
   331  type msgReader struct {
   332  	c *Conn
   333  
   334  	ctx         context.Context
   335  	flate       bool
   336  	flateReader io.Reader
   337  	flateBufio  *bufio.Reader
   338  	flateTail   strings.Reader
   339  	limitReader *limitReader
   340  	dict        slidingWindow
   341  
   342  	fin           bool
   343  	payloadLength int64
   344  	maskKey       uint32
   345  
   346  	// readerFunc(mr.Read) to avoid continuous allocations.
   347  	readFunc readerFunc
   348  }
   349  
   350  func (mr *msgReader) reset(ctx context.Context, h header) {
   351  	mr.ctx = ctx
   352  	mr.flate = h.rsv1
   353  	mr.limitReader.reset(mr.readFunc)
   354  
   355  	if mr.flate {
   356  		mr.resetFlate()
   357  	}
   358  
   359  	mr.setFrame(h)
   360  }
   361  
   362  func (mr *msgReader) setFrame(h header) {
   363  	mr.fin = h.fin
   364  	mr.payloadLength = h.payloadLength
   365  	mr.maskKey = h.maskKey
   366  }
   367  
   368  func (mr *msgReader) Read(p []byte) (n int, err error) {
   369  	err = mr.c.readMu.lock(mr.ctx)
   370  	if err != nil {
   371  		return 0, fmt.Errorf("failed to read: %w", err)
   372  	}
   373  	defer mr.c.readMu.unlock()
   374  
   375  	n, err = mr.limitReader.Read(p)
   376  	if mr.flate && mr.flateContextTakeover() {
   377  		p = p[:n]
   378  		mr.dict.write(p)
   379  	}
   380  	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
   381  		mr.putFlateReader()
   382  		return n, io.EOF
   383  	}
   384  	if err != nil {
   385  		err = fmt.Errorf("failed to read: %w", err)
   386  		mr.c.close(err)
   387  	}
   388  	return n, err
   389  }
   390  
   391  func (mr *msgReader) read(p []byte) (int, error) {
   392  	for {
   393  		if mr.payloadLength == 0 {
   394  			if mr.fin {
   395  				if mr.flate {
   396  					return mr.flateTail.Read(p)
   397  				}
   398  				return 0, io.EOF
   399  			}
   400  
   401  			h, err := mr.c.readLoop(mr.ctx)
   402  			if err != nil {
   403  				return 0, err
   404  			}
   405  			if h.opcode != opContinuation {
   406  				err := errors.New("received new data message without finishing the previous message")
   407  				mr.c.writeError(StatusProtocolError, err)
   408  				return 0, err
   409  			}
   410  			mr.setFrame(h)
   411  
   412  			continue
   413  		}
   414  
   415  		if int64(len(p)) > mr.payloadLength {
   416  			p = p[:mr.payloadLength]
   417  		}
   418  
   419  		n, err := mr.c.readFramePayload(mr.ctx, p)
   420  		if err != nil {
   421  			return n, err
   422  		}
   423  
   424  		mr.payloadLength -= int64(n)
   425  
   426  		if !mr.c.client {
   427  			mr.maskKey = mask(mr.maskKey, p)
   428  		}
   429  
   430  		return n, nil
   431  	}
   432  }
   433  
   434  type limitReader struct {
   435  	c     *Conn
   436  	r     io.Reader
   437  	limit xsync.Int64
   438  	n     int64
   439  }
   440  
   441  func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
   442  	lr := &limitReader{
   443  		c: c,
   444  	}
   445  	lr.limit.Store(limit)
   446  	lr.reset(r)
   447  	return lr
   448  }
   449  
   450  func (lr *limitReader) reset(r io.Reader) {
   451  	lr.n = lr.limit.Load()
   452  	lr.r = r
   453  }
   454  
   455  func (lr *limitReader) Read(p []byte) (int, error) {
   456  	if lr.n <= 0 {
   457  		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
   458  		lr.c.writeError(StatusMessageTooBig, err)
   459  		return 0, err
   460  	}
   461  
   462  	if int64(len(p)) > lr.n {
   463  		p = p[:lr.n]
   464  	}
   465  	n, err := lr.r.Read(p)
   466  	lr.n -= int64(n)
   467  	return n, err
   468  }
   469  
   470  type readerFunc func(p []byte) (int, error)
   471  
   472  func (f readerFunc) Read(p []byte) (int, error) {
   473  	return f(p)
   474  }