github.com/ronaksoft/rony@v0.16.26-0.20230807065236-1743dbfe6959/internal/gateway/tcp/util/writer.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  
     7  	"github.com/ronaksoft/rony/pools"
     8  
     9  	"github.com/gobwas/pool"
    10  	"github.com/gobwas/ws"
    11  )
    12  
    13  // DefaultWriteBuffer contains size of Writer's default buffer. It used by
    14  // Writer constructor functions.
    15  var DefaultWriteBuffer = 4096
    16  
    17  var (
    18  	// ErrNotEmpty is returned by Writer.WriteThrough() to indicate that buffer is
    19  	// not empty and write through could not be done. That is, caller should call
    20  	// Writer.FlushFragment() to make buffer empty.
    21  	ErrNotEmpty = fmt.Errorf("writer not empty")
    22  
    23  	// ErrControlOverflow is returned by ControlWriter.Write() to indicate that
    24  	// no more data could be written to the underlying io.Writer because
    25  	// MaxControlFramePayloadSize limit is reached.
    26  	ErrControlOverflow = fmt.Errorf("control frame payload overflow")
    27  )
    28  
    29  // Constants which are represent frame length ranges.
    30  const (
    31  	len7  = int64(125) // 126 and 127 are reserved values
    32  	len16 = int64(^uint16(0))
    33  )
    34  
    35  // ControlWriter is a wrapper around Writer that contains some guards for
    36  // buffered writes of control frames.
    37  type ControlWriter struct {
    38  	w     *Writer
    39  	limit int
    40  	n     int
    41  }
    42  
    43  // NewControlWriter contains ControlWriter with Writer inside whose buffer size
    44  // is at most ws.MaxControlFramePayloadSize + ws.MaxHeaderSize.
    45  func NewControlWriter(dest io.Writer, state ws.State, op ws.OpCode) *ControlWriter {
    46  	return &ControlWriter{
    47  		w:     NewWriterSize(dest, state, op, ws.MaxControlFramePayloadSize),
    48  		limit: ws.MaxControlFramePayloadSize,
    49  	}
    50  }
    51  
    52  // NewControlWriterBuffer returns a new ControlWriter with buf as a buffer.
    53  //
    54  // Note that it reserves x bytes of buf for header data, where x could be
    55  // ws.MinHeaderSize or ws.MinHeaderSize+4 (depending on state). At most
    56  // (ws.MaxControlFramePayloadSize + x) bytes of buf will be used.
    57  //
    58  // It panics if len(buf) <= ws.MinHeaderSize + x.
    59  func NewControlWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *ControlWriter {
    60  	max := ws.MaxControlFramePayloadSize + headerSize(state, ws.MaxControlFramePayloadSize)
    61  	if len(buf) > max {
    62  		buf = buf[:max]
    63  	}
    64  
    65  	w := NewWriterBuffer(dest, state, op, buf)
    66  
    67  	return &ControlWriter{
    68  		w:     w,
    69  		limit: len(w.buf),
    70  	}
    71  }
    72  
    73  // Write implements io.Writer. It writes to the underlying Writer until it
    74  // returns error or until ControlWriter write limit will be exceeded.
    75  func (c *ControlWriter) Write(p []byte) (n int, err error) {
    76  	if c.n+len(p) > c.limit {
    77  		return 0, ErrControlOverflow
    78  	}
    79  
    80  	return c.w.Write(p)
    81  }
    82  
    83  // Flush flushes all buffered data to the underlying io.Writer.
    84  func (c *ControlWriter) Flush() error {
    85  	return c.w.Flush()
    86  }
    87  
    88  // Writer contains logic of buffering output data into a WebSocket fragments.
    89  // It is much the same as bufio.Writer, except the thing that it works with
    90  // WebSocket frames, not the raw data.
    91  //
    92  // Writer writes frames with specified OpCode.
    93  // It uses ws.State to decide whether the output frames must be masked.
    94  //
    95  // Note that it does not check control frame size or other RFC rules.
    96  // That is, it must be used with special care to write control frames without
    97  // violation of RFC. You could use ControlWriter that wraps Writer and contains
    98  // some guards for writing control frames.
    99  //
   100  // If an error occurs writing to a Writer, no more data will be accepted and
   101  // all subsequent writes will return the error.
   102  // After all data has been written, the client should call the Flush() method
   103  // to guarantee all data has been forwarded to the underlying io.Writer.
   104  type Writer struct {
   105  	dest io.Writer
   106  
   107  	n   int    // Buffered bytes counter.
   108  	raw []byte // Raw representation of buffer, including reserved header bytes.
   109  	buf []byte // Writeable part of buffer, without reserved header bytes.
   110  
   111  	op    ws.OpCode
   112  	state ws.State
   113  
   114  	dirty      bool
   115  	fragmented bool
   116  
   117  	err error
   118  }
   119  
   120  var writers = pool.New(128, 65536)
   121  
   122  // GetWriter tries to reuse Writer getting it from the pool.
   123  //
   124  // This function is intended for memory consumption optimizations, because
   125  // NewWriter*() functions make allocations for inner buffer.
   126  //
   127  // Note the it ceils n to the power of two.
   128  //
   129  // If you have your own bytes buffer pool you could use NewWriterBuffer to use
   130  // pooled bytes in writer.
   131  func GetWriter(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
   132  	x, m := writers.Get(n)
   133  	if x != nil {
   134  		w := x.(*Writer)
   135  		w.Reset(dest, state, op)
   136  
   137  		return w
   138  	}
   139  	// NOTE: we use m instead of n, because m is an attempt to reuse w of such
   140  	// size in the future.
   141  	return NewWriterBufferSize(dest, state, op, m)
   142  }
   143  
   144  // PutWriter puts w for future reuse by GetWriter().
   145  func PutWriter(w *Writer) {
   146  	w.Reset(nil, 0, 0)
   147  	writers.Put(w, w.Size())
   148  }
   149  
   150  // NewWriter returns a new Writer whose buffer has the DefaultWriteBuffer size.
   151  func NewWriter(dest io.Writer, state ws.State, op ws.OpCode) *Writer {
   152  	return NewWriterBufferSize(dest, state, op, 0)
   153  }
   154  
   155  // NewWriterSize returns a new Writer whose buffer size is at most n + ws.MaxHeaderSize.
   156  // That is, output frames payload length could be up to n, except the case when
   157  // Write() is called on empty Writer with len(p) > n.
   158  //
   159  // If n <= 0 then the default buffer size is used as Writer's buffer size.
   160  func NewWriterSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
   161  	if n > 0 {
   162  		n += headerSize(state, n)
   163  	}
   164  
   165  	return NewWriterBufferSize(dest, state, op, n)
   166  }
   167  
   168  // NewWriterBufferSize returns a new Writer whose buffer size is equal to n.
   169  // If n <= ws.MinHeaderSize then the default buffer size is used.
   170  //
   171  // Note that Writer will reserve x bytes for header data, where x is in range
   172  // [ws.MinHeaderSize,ws.MaxHeaderSize]. That is, frames flushed by Writer
   173  // will not have payload length equal to n, except the case when Write() is
   174  // called on empty Writer with len(p) > n.
   175  func NewWriterBufferSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
   176  	if n <= ws.MinHeaderSize {
   177  		n = DefaultWriteBuffer
   178  	}
   179  
   180  	return NewWriterBuffer(dest, state, op, make([]byte, n))
   181  }
   182  
   183  // NewWriterBuffer returns a new Writer with buf as a buffer.
   184  //
   185  // Note that it reserves x bytes of buf for header data, where x is in range
   186  // [ws.MinHeaderSize,ws.MaxHeaderSize] (depending on state and buf size).
   187  //
   188  // You could use ws.HeaderSize() to calculate number of bytes needed to store
   189  // header data.
   190  //
   191  // It panics if len(buf) is too small to fit header and payload data.
   192  func NewWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *Writer {
   193  	offset := reserve(state, len(buf))
   194  	if len(buf) <= offset {
   195  		panic("buffer too small")
   196  	}
   197  
   198  	return &Writer{
   199  		dest:  dest,
   200  		raw:   buf,
   201  		buf:   buf[offset:],
   202  		state: state,
   203  		op:    op,
   204  	}
   205  }
   206  
   207  func reserve(state ws.State, n int) (offset int) {
   208  	var mask int
   209  	if state.ClientSide() {
   210  		mask = 4
   211  	}
   212  
   213  	switch {
   214  	case n <= int(len7)+mask+2:
   215  		return mask + 2
   216  	case n <= int(len16)+mask+4:
   217  		return mask + 4
   218  	default:
   219  		return mask + 10
   220  	}
   221  }
   222  
   223  // headerSize returns number of bytes needed to encode header of a frame with
   224  // given state and length.
   225  func headerSize(s ws.State, n int) int {
   226  	return ws.HeaderSize(ws.Header{
   227  		Length: int64(n),
   228  		Masked: s.ClientSide(),
   229  	})
   230  }
   231  
   232  // Reset discards any buffered data, clears error, and resets w to have given
   233  // state and write frames with given OpCode to dest.
   234  func (w *Writer) Reset(dest io.Writer, state ws.State, op ws.OpCode) {
   235  	w.n = 0
   236  	w.dirty = false
   237  	w.fragmented = false
   238  	w.dest = dest
   239  	w.state = state
   240  	w.op = op
   241  }
   242  
   243  // Size returns the size of the underlying buffer in bytes.
   244  func (w *Writer) Size() int {
   245  	return len(w.buf)
   246  }
   247  
   248  // Available returns how many bytes are unused in the buffer.
   249  func (w *Writer) Available() int {
   250  	return len(w.buf) - w.n
   251  }
   252  
   253  // Buffered returns the number of bytes that have been written into the current
   254  // buffer.
   255  func (w *Writer) Buffered() int {
   256  	return w.n
   257  }
   258  
   259  // Write implements io.Writer.
   260  //
   261  // Note that even if the Writer was created to have N-sized buffer, Write()
   262  // with payload of N bytes will not fit into that buffer. Writer reserves some
   263  // space to fit WebSocket header data.
   264  func (w *Writer) Write(p []byte) (n int, err error) {
   265  	// Even empty p may make a sense.
   266  	w.dirty = true
   267  
   268  	var nn int
   269  	for len(p) > w.Available() && w.err == nil {
   270  		if w.Buffered() == 0 {
   271  			// Large write, empty buffer. Write directly from p to avoid copy.
   272  			// Trade off here is that we make additional Write() to underlying
   273  			// io.Writer when writing frame header.
   274  			//
   275  			// On large buffers additional write is better than copying.
   276  			nn, _ = w.WriteThrough(p)
   277  		} else {
   278  			nn = copy(w.buf[w.n:], p)
   279  			w.n += nn
   280  			_ = w.FlushFragment()
   281  		}
   282  		n += nn
   283  		p = p[nn:]
   284  	}
   285  	if w.err != nil {
   286  		return n, w.err
   287  	}
   288  	nn = copy(w.buf[w.n:], p)
   289  	w.n += nn
   290  	n += nn
   291  
   292  	// Even if w.Available() == 0 we will not flush buffer preventively because
   293  	// this could bring unwanted fragmentation. That is, user could create
   294  	// buffer with size that fits exactly all further Write() call, and then
   295  	// call Flush(), excepting that single and not fragmented frame will be
   296  	// sent. With preemptive flush this case will produce two frames – last one
   297  	// will be empty and just to set fin = true.
   298  
   299  	return n, w.err
   300  }
   301  
   302  // WriteThrough writes data bypassing the buffer.
   303  // Note that Writer's buffer must be empty before calling WriteThrough().
   304  func (w *Writer) WriteThrough(p []byte) (n int, err error) {
   305  	if w.err != nil {
   306  		return 0, w.err
   307  	}
   308  	if w.Buffered() != 0 {
   309  		return 0, ErrNotEmpty
   310  	}
   311  
   312  	w.err = writeFrame(w.dest, w.state, w.opCode(), false, p)
   313  	if w.err == nil {
   314  		n = len(p)
   315  	}
   316  
   317  	w.dirty = true
   318  	w.fragmented = true
   319  
   320  	return n, w.err
   321  }
   322  
   323  // ReadFrom implements io.ReaderFrom.
   324  func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) {
   325  	var nn int
   326  	for err == nil {
   327  		if w.Available() == 0 {
   328  			err = w.FlushFragment()
   329  
   330  			continue
   331  		}
   332  
   333  		// We copy the behavior of bufio.Writer here.
   334  		// Also, from the docs on io.ReaderFrom:
   335  		//   ReadFrom reads data from r until EOF or error.
   336  		//
   337  		// See https://codereview.appspot.com/76400048/#ps1
   338  		const maxEmptyReads = 100
   339  		var nr int
   340  		for nr < maxEmptyReads {
   341  			nn, err = src.Read(w.buf[w.n:])
   342  			if nn != 0 || err != nil {
   343  				break
   344  			}
   345  			nr++
   346  		}
   347  		if nr == maxEmptyReads {
   348  			return n, io.ErrNoProgress
   349  		}
   350  
   351  		w.n += nn
   352  		n += int64(nn)
   353  	}
   354  	if err == io.EOF {
   355  		// NOTE: Do not flush preemptively.
   356  		// See the Write() sources for more info.
   357  		err = nil
   358  		w.dirty = true
   359  	}
   360  
   361  	return n, err
   362  }
   363  
   364  // Flush writes any buffered data to the underlying io.Writer.
   365  // It sends the frame with "fin" flag set to true.
   366  //
   367  // If no Write() or ReadFrom() was made, then Flush() does nothing.
   368  func (w *Writer) Flush() error {
   369  	if (!w.dirty && w.Buffered() == 0) || w.err != nil {
   370  		return w.err
   371  	}
   372  
   373  	w.err = w.flushFragment(true)
   374  	w.n = 0
   375  	w.dirty = false
   376  	w.fragmented = false
   377  
   378  	return w.err
   379  }
   380  
   381  // FlushFragment writes any buffered data to the underlying io.Writer.
   382  // It sends the frame with "fin" flag set to false.
   383  func (w *Writer) FlushFragment() error {
   384  	if w.Buffered() == 0 || w.err != nil {
   385  		return w.err
   386  	}
   387  
   388  	w.err = w.flushFragment(false)
   389  	w.n = 0
   390  	w.fragmented = true
   391  
   392  	return w.err
   393  }
   394  
   395  func (w *Writer) flushFragment(fin bool) error {
   396  	frame := ws.NewFrame(w.opCode(), fin, w.buf[:w.n])
   397  	if w.state.ClientSide() {
   398  		frame = ws.MaskFrameInPlace(frame)
   399  	}
   400  
   401  	// Write header to the header segment of the raw buffer.
   402  	head := len(w.raw) - len(w.buf)
   403  	offset := head - ws.HeaderSize(frame.Header)
   404  	buf := bytesWriter{
   405  		buf: w.raw[offset:head],
   406  	}
   407  	if err := ws.WriteHeader(&buf, frame.Header); err != nil {
   408  		// Must never be reached.
   409  		panic("dump header error: " + err.Error())
   410  	}
   411  
   412  	_, err := w.dest.Write(w.raw[offset : head+w.n])
   413  
   414  	return err
   415  }
   416  
   417  func (w *Writer) opCode() ws.OpCode {
   418  	if w.fragmented {
   419  		return ws.OpContinuation
   420  	}
   421  
   422  	return w.op
   423  }
   424  
   425  var errNoSpace = fmt.Errorf("not enough buffer space")
   426  
   427  type bytesWriter struct {
   428  	buf []byte
   429  	pos int
   430  }
   431  
   432  func (w *bytesWriter) Write(p []byte) (int, error) {
   433  	n := copy(w.buf[w.pos:], p)
   434  	w.pos += n
   435  	if n != len(p) {
   436  		return n, errNoSpace
   437  	}
   438  
   439  	return n, nil
   440  }
   441  
   442  func writeFrame(w io.Writer, s ws.State, op ws.OpCode, fin bool, p []byte) error {
   443  	var frame ws.Frame
   444  	if s.ClientSide() {
   445  		// Should copy bytes to prevent corruption of caller data.
   446  		buf := pools.Buffer.GetLen(len(p))
   447  		defer pools.Buffer.Put(buf)
   448  		buf.CopyFrom(p)
   449  		frame = ws.NewFrame(op, fin, *buf.Bytes())
   450  		frame = ws.MaskFrameInPlace(frame)
   451  	} else {
   452  		frame = ws.NewFrame(op, fin, p)
   453  	}
   454  
   455  	return ws.WriteFrame(w, frame)
   456  }