github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/wsutil/writer.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  
     7  	"github.com/ezoic/pool"
     8  	"github.com/ezoic/pool/pbytes"
     9  	"github.com/ezoic/ws"
    10  )
    11  
// DefaultWriteBuffer is the size in bytes of the buffer allocated for a Writer
// when the requested size is too small to be usable. It is used by the Writer
// constructor functions.
var DefaultWriteBuffer = 4096
    15  
var (
	// ErrNotEmpty is returned by Writer.WriteThrough() to indicate that the
	// buffer is not empty and the write-through could not be done. That is,
	// the caller should call Writer.FlushFragment() first to empty the buffer.
	ErrNotEmpty = fmt.Errorf("writer not empty")

	// ErrControlOverflow is returned by ControlWriter.Write() to indicate that
	// no more data can be written because the control frame payload limit
	// would be exceeded.
	ErrControlOverflow = fmt.Errorf("control frame payload overflow")
)
    27  
// Constants which represent the upper bounds of the frame payload length
// ranges. They are used to pick how many header bytes must be reserved.
const (
	len7  = int64(125)               // 126 and 127 are reserved values
	len16 = int64(^uint16(0))        // largest value that fits into uint16
	len64 = int64((^uint64(0)) >> 1) // largest value that fits into int64
)
    34  
// ControlWriter is a wrapper around Writer that contains some guards for
// buffered writes of control frames.
type ControlWriter struct {
	w     *Writer // Underlying buffered frame writer.
	limit int     // Maximum number of payload bytes accepted by Write().
	n     int     // Byte counter; NOTE(review): no code visible in this file ever changes it.
}
    42  
    43  // NewControlWriter contains ControlWriter with Writer inside whose buffer size
    44  // is at most ws.MaxControlFramePayloadSize + ws.MaxHeaderSize.
    45  func NewControlWriter(dest io.Writer, state ws.State, op ws.OpCode) *ControlWriter {
    46  	return &ControlWriter{
    47  		w:     NewWriterSize(dest, state, op, ws.MaxControlFramePayloadSize),
    48  		limit: ws.MaxControlFramePayloadSize,
    49  	}
    50  }
    51  
    52  // NewControlWriterBuffer returns a new ControlWriter with buf as a buffer.
    53  //
    54  // Note that it reserves x bytes of buf for header data, where x could be
    55  // ws.MinHeaderSize or ws.MinHeaderSize+4 (depending on state). At most
    56  // (ws.MaxControlFramePayloadSize + x) bytes of buf will be used.
    57  //
    58  // It panics if len(buf) <= ws.MinHeaderSize + x.
    59  func NewControlWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *ControlWriter {
    60  	max := ws.MaxControlFramePayloadSize + headerSize(state, ws.MaxControlFramePayloadSize)
    61  	if len(buf) > max {
    62  		buf = buf[:max]
    63  	}
    64  
    65  	w := NewWriterBuffer(dest, state, op, buf)
    66  
    67  	return &ControlWriter{
    68  		w:     w,
    69  		limit: len(w.buf),
    70  	}
    71  }
    72  
    73  // Write implements io.Writer. It writes to the underlying Writer until it
    74  // returns error or until ControlWriter write limit will be exceeded.
    75  func (c *ControlWriter) Write(p []byte) (n int, err error) {
    76  	if c.n+len(p) > c.limit {
    77  		return 0, ErrControlOverflow
    78  	}
    79  	return c.w.Write(p)
    80  }
    81  
    82  // Flush flushes all buffered data to the underlying io.Writer.
    83  func (c *ControlWriter) Flush() error {
    84  	return c.w.Flush()
    85  }
    86  
// Writer contains the logic of buffering output data into WebSocket
// fragments. It is much the same as bufio.Writer, except that it works with
// WebSocket frames, not the raw data.
//
// Writer writes frames with the specified OpCode.
// It uses ws.State to decide whether the output frames must be masked.
//
// Note that it does not check control frame size or other RFC rules.
// That is, it must be used with special care to write control frames without
// violation of the RFC. You could use ControlWriter, which wraps Writer and
// contains some guards for writing control frames.
//
// If an error occurs writing to a Writer, no more data will be accepted and
// all subsequent writes will return the error.
// After all data has been written, the client should call the Flush() method
// to guarantee all data has been forwarded to the underlying io.Writer.
type Writer struct {
	// dest receives the encoded frames.
	dest io.Writer

	n   int    // Buffered bytes counter.
	raw []byte // Raw representation of buffer, including reserved header bytes.
	buf []byte // Writeable part of buffer, without reserved header bytes.

	// op is the OpCode used for frames until the message becomes fragmented;
	// state decides whether frames are masked (client side).
	op    ws.OpCode
	state ws.State

	// dirty is set by Write(), WriteThrough() and ReadFrom() (on EOF) and
	// cleared by Flush() and Reset(). fragmented is set once a non-final
	// frame has been sent, so that following frames are continuations.
	dirty      bool
	fragmented bool

	// err is the sticky error of a failed write to dest.
	err error
}
   118  
// writers is the pool of reusable Writers shared by GetWriter() and
// PutWriter().
// NOTE(review): the arguments presumably bound the pooled buffer sizes to
// the range [128, 65536] — confirm against the pool package.
var writers = pool.New(128, 65536)
   120  
   121  // GetWriter tries to reuse Writer getting it from the pool.
   122  //
   123  // This function is intended for memory consumption optimizations, because
   124  // NewWriter*() functions make allocations for inner buffer.
   125  //
   126  // Note the it ceils n to the power of two.
   127  //
   128  // If you have your own bytes buffer pool you could use NewWriterBuffer to use
   129  // pooled bytes in writer.
   130  func GetWriter(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
   131  	x, m := writers.Get(n)
   132  	if x != nil {
   133  		w := x.(*Writer)
   134  		w.Reset(dest, state, op)
   135  		return w
   136  	}
   137  	// NOTE: we use m instead of n, because m is an attempt to reuse w of such
   138  	// size in the future.
   139  	return NewWriterBufferSize(dest, state, op, m)
   140  }
   141  
   142  // PutWriter puts w for future reuse by GetWriter().
   143  func PutWriter(w *Writer) {
   144  	w.Reset(nil, 0, 0)
   145  	writers.Put(w, w.Size())
   146  }
   147  
   148  // NewWriter returns a new Writer whose buffer has the DefaultWriteBuffer size.
   149  func NewWriter(dest io.Writer, state ws.State, op ws.OpCode) *Writer {
   150  	return NewWriterBufferSize(dest, state, op, 0)
   151  }
   152  
   153  // NewWriterSize returns a new Writer whose buffer size is at most n + ws.MaxHeaderSize.
   154  // That is, output frames payload length could be up to n, except the case when
   155  // Write() is called on empty Writer with len(p) > n.
   156  //
   157  // If n <= 0 then the default buffer size is used as Writer's buffer size.
   158  func NewWriterSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
   159  	if n > 0 {
   160  		n += headerSize(state, n)
   161  	}
   162  	return NewWriterBufferSize(dest, state, op, n)
   163  }
   164  
   165  // NewWriterBufferSize returns a new Writer whose buffer size is equal to n.
   166  // If n <= ws.MinHeaderSize then the default buffer size is used.
   167  //
   168  // Note that Writer will reserve x bytes for header data, where x is in range
   169  // [ws.MinHeaderSize,ws.MaxHeaderSize]. That is, frames flushed by Writer
   170  // will not have payload length equal to n, except the case when Write() is
   171  // called on empty Writer with len(p) > n.
   172  func NewWriterBufferSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
   173  	if n <= ws.MinHeaderSize {
   174  		n = DefaultWriteBuffer
   175  	}
   176  	return NewWriterBuffer(dest, state, op, make([]byte, n))
   177  }
   178  
   179  // NewWriterBuffer returns a new Writer with buf as a buffer.
   180  //
   181  // Note that it reserves x bytes of buf for header data, where x is in range
   182  // [ws.MinHeaderSize,ws.MaxHeaderSize] (depending on state and buf size).
   183  //
   184  // You could use ws.HeaderSize() to calculate number of bytes needed to store
   185  // header data.
   186  //
   187  // It panics if len(buf) is too small to fit header and payload data.
   188  func NewWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *Writer {
   189  	offset := reserve(state, len(buf))
   190  	if len(buf) <= offset {
   191  		panic("buffer too small")
   192  	}
   193  
   194  	return &Writer{
   195  		dest:  dest,
   196  		raw:   buf,
   197  		buf:   buf[offset:],
   198  		state: state,
   199  		op:    op,
   200  	}
   201  }
   202  
   203  func reserve(state ws.State, n int) (offset int) {
   204  	var mask int
   205  	if state.ClientSide() {
   206  		mask = 4
   207  	}
   208  
   209  	switch {
   210  	case n <= int(len7)+mask+2:
   211  		return mask + 2
   212  	case n <= int(len16)+mask+4:
   213  		return mask + 4
   214  	default:
   215  		return mask + 10
   216  	}
   217  }
   218  
   219  // headerSize returns number of bytes needed to encode header of a frame with
   220  // given state and length.
   221  func headerSize(s ws.State, n int) int {
   222  	return ws.HeaderSize(ws.Header{
   223  		Length: int64(n),
   224  		Masked: s.ClientSide(),
   225  	})
   226  }
   227  
   228  // Reset discards any buffered data, clears error, and resets w to have given
   229  // state and write frames with given OpCode to dest.
   230  func (w *Writer) Reset(dest io.Writer, state ws.State, op ws.OpCode) {
   231  	w.n = 0
   232  	w.dirty = false
   233  	w.fragmented = false
   234  	w.dest = dest
   235  	w.state = state
   236  	w.op = op
   237  }
   238  
   239  // Size returns the size of the underlying buffer in bytes.
   240  func (w *Writer) Size() int {
   241  	return len(w.buf)
   242  }
   243  
   244  // Available returns how many bytes are unused in the buffer.
   245  func (w *Writer) Available() int {
   246  	return len(w.buf) - w.n
   247  }
   248  
   249  // Buffered returns the number of bytes that have been written into the current
   250  // buffer.
   251  func (w *Writer) Buffered() int {
   252  	return w.n
   253  }
   254  
   255  // Write implements io.Writer.
   256  //
   257  // Note that even if the Writer was created to have N-sized buffer, Write()
   258  // with payload of N bytes will not fit into that buffer. Writer reserves some
   259  // space to fit WebSocket header data.
   260  func (w *Writer) Write(p []byte) (n int, err error) {
   261  	// Even empty p may make a sense.
   262  	w.dirty = true
   263  
   264  	var nn int
   265  	for len(p) > w.Available() && w.err == nil {
   266  		if w.Buffered() == 0 {
   267  			// Large write, empty buffer. Write directly from p to avoid copy.
   268  			// Trade off here is that we make additional Write() to underlying
   269  			// io.Writer when writing frame header.
   270  			//
   271  			// On large buffers additional write is better than copying.
   272  			nn, _ = w.WriteThrough(p)
   273  		} else {
   274  			nn = copy(w.buf[w.n:], p)
   275  			w.n += nn
   276  			w.FlushFragment()
   277  		}
   278  		n += nn
   279  		p = p[nn:]
   280  	}
   281  	if w.err != nil {
   282  		return n, w.err
   283  	}
   284  	nn = copy(w.buf[w.n:], p)
   285  	w.n += nn
   286  	n += nn
   287  
   288  	// Even if w.Available() == 0 we will not flush buffer preventively because
   289  	// this could bring unwanted fragmentation. That is, user could create
   290  	// buffer with size that fits exactly all further Write() call, and then
   291  	// call Flush(), excepting that single and not fragmented frame will be
   292  	// sent. With preemptive flush this case will produce two frames – last one
   293  	// will be empty and just to set fin = true.
   294  
   295  	return n, w.err
   296  }
   297  
   298  // WriteThrough writes data bypassing the buffer.
   299  // Note that Writer's buffer must be empty before calling WriteThrough().
   300  func (w *Writer) WriteThrough(p []byte) (n int, err error) {
   301  	if w.err != nil {
   302  		return 0, w.err
   303  	}
   304  	if w.Buffered() != 0 {
   305  		return 0, ErrNotEmpty
   306  	}
   307  
   308  	w.err = writeFrame(w.dest, w.state, w.opCode(), false, p)
   309  	if w.err == nil {
   310  		n = len(p)
   311  	}
   312  
   313  	w.dirty = true
   314  	w.fragmented = true
   315  
   316  	return n, w.err
   317  }
   318  
   319  // ReadFrom implements io.ReaderFrom.
   320  func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) {
   321  	var nn int
   322  	for err == nil {
   323  		if w.Available() == 0 {
   324  			err = w.FlushFragment()
   325  			continue
   326  		}
   327  
   328  		// We copy the behavior of bufio.Writer here.
   329  		// Also, from the docs on io.ReaderFrom:
   330  		//   ReadFrom reads data from r until EOF or error.
   331  		//
   332  		// See https://codereview.appspot.com/76400048/#ps1
   333  		const maxEmptyReads = 100
   334  		var nr int
   335  		for nr < maxEmptyReads {
   336  			nn, err = src.Read(w.buf[w.n:])
   337  			if nn != 0 || err != nil {
   338  				break
   339  			}
   340  			nr++
   341  		}
   342  		if nr == maxEmptyReads {
   343  			return n, io.ErrNoProgress
   344  		}
   345  
   346  		w.n += nn
   347  		n += int64(nn)
   348  	}
   349  	if err == io.EOF {
   350  		// NOTE: Do not flush preemptively.
   351  		// See the Write() sources for more info.
   352  		err = nil
   353  		w.dirty = true
   354  	}
   355  	return n, err
   356  }
   357  
   358  // Flush writes any buffered data to the underlying io.Writer.
   359  // It sends the frame with "fin" flag set to true.
   360  //
   361  // If no Write() or ReadFrom() was made, then Flush() does nothing.
   362  func (w *Writer) Flush() error {
   363  	if (!w.dirty && w.Buffered() == 0) || w.err != nil {
   364  		return w.err
   365  	}
   366  
   367  	w.err = w.flushFragment(true)
   368  	w.n = 0
   369  	w.dirty = false
   370  	w.fragmented = false
   371  
   372  	return w.err
   373  }
   374  
   375  // FlushFragment writes any buffered data to the underlying io.Writer.
   376  // It sends the frame with "fin" flag set to false.
   377  func (w *Writer) FlushFragment() error {
   378  	if w.Buffered() == 0 || w.err != nil {
   379  		return w.err
   380  	}
   381  
   382  	w.err = w.flushFragment(false)
   383  	w.n = 0
   384  	w.fragmented = true
   385  
   386  	return w.err
   387  }
   388  
   389  func (w *Writer) flushFragment(fin bool) error {
   390  	frame := ws.NewFrame(w.opCode(), fin, w.buf[:w.n])
   391  	if w.state.ClientSide() {
   392  		frame = ws.MaskFrameInPlace(frame)
   393  	}
   394  
   395  	// Write header to the header segment of the raw buffer.
   396  	head := len(w.raw) - len(w.buf)
   397  	offset := head - ws.HeaderSize(frame.Header)
   398  	buf := bytesWriter{
   399  		buf: w.raw[offset:head],
   400  	}
   401  	if err := ws.WriteHeader(&buf, frame.Header); err != nil {
   402  		// Must never be reached.
   403  		panic("dump header error: " + err.Error())
   404  	}
   405  
   406  	_, err := w.dest.Write(w.raw[offset : head+w.n])
   407  
   408  	return err
   409  }
   410  
   411  func (w *Writer) opCode() ws.OpCode {
   412  	if w.fragmented {
   413  		return ws.OpContinuation
   414  	}
   415  	return w.op
   416  }
   417  
   418  var errNoSpace = fmt.Errorf("not enough buffer space")
   419  
   420  type bytesWriter struct {
   421  	buf []byte
   422  	pos int
   423  }
   424  
   425  func (w *bytesWriter) Write(p []byte) (int, error) {
   426  	n := copy(w.buf[w.pos:], p)
   427  	w.pos += n
   428  	if n != len(p) {
   429  		return n, errNoSpace
   430  	}
   431  	return n, nil
   432  }
   433  
   434  func writeFrame(w io.Writer, s ws.State, op ws.OpCode, fin bool, p []byte) error {
   435  	var frame ws.Frame
   436  	if s.ClientSide() {
   437  		// Should copy bytes to prevent corruption of caller data.
   438  		payload := pbytes.GetLen(len(p))
   439  		defer pbytes.Put(payload)
   440  
   441  		copy(payload, p)
   442  
   443  		frame = ws.NewFrame(op, fin, payload)
   444  		frame = ws.MaskFrameInPlace(frame)
   445  	} else {
   446  		frame = ws.NewFrame(op, fin, p)
   447  	}
   448  
   449  	return ws.WriteFrame(w, frame)
   450  }