github.com/simonmittag/ws@v1.1.0-rc.5.0.20210419231947-82b846128245/wsutil/writer.go (about)

     1  package wsutil
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  
     7  	"github.com/gobwas/pool"
     8  	"github.com/gobwas/pool/pbytes"
     9  	"github.com/simonmittag/ws"
    10  )
    11  
    12  // DefaultWriteBuffer contains size of Writer's default buffer. It used by
    13  // Writer constructor functions.
    14  var DefaultWriteBuffer = 4096
    15  
    16  var (
    17  	// ErrNotEmpty is returned by Writer.WriteThrough() to indicate that buffer is
    18  	// not empty and write through could not be done. That is, caller should call
    19  	// Writer.FlushFragment() to make buffer empty.
    20  	ErrNotEmpty = fmt.Errorf("writer not empty")
    21  
    22  	// ErrControlOverflow is returned by ControlWriter.Write() to indicate that
    23  	// no more data could be written to the underlying io.Writer because
    24  	// MaxControlFramePayloadSize limit is reached.
    25  	ErrControlOverflow = fmt.Errorf("control frame payload overflow")
    26  )
    27  
    28  // Constants which are represent frame length ranges.
    29  const (
    30  	len7  = int64(125) // 126 and 127 are reserved values
    31  	len16 = int64(^uint16(0))
    32  	len64 = int64((^uint64(0)) >> 1)
    33  )
    34  
    35  // ControlWriter is a wrapper around Writer that contains some guards for
    36  // buffered writes of control frames.
    37  type ControlWriter struct {
    38  	w     *Writer
    39  	limit int
    40  	n     int
    41  }
    42  
    43  // NewControlWriter contains ControlWriter with Writer inside whose buffer size
    44  // is at most ws.MaxControlFramePayloadSize + ws.MaxHeaderSize.
    45  func NewControlWriter(dest io.Writer, state ws.State, op ws.OpCode) *ControlWriter {
    46  	return &ControlWriter{
    47  		w:     NewWriterSize(dest, state, op, ws.MaxControlFramePayloadSize),
    48  		limit: ws.MaxControlFramePayloadSize,
    49  	}
    50  }
    51  
    52  // NewControlWriterBuffer returns a new ControlWriter with buf as a buffer.
    53  //
    54  // Note that it reserves x bytes of buf for header data, where x could be
    55  // ws.MinHeaderSize or ws.MinHeaderSize+4 (depending on state). At most
    56  // (ws.MaxControlFramePayloadSize + x) bytes of buf will be used.
    57  //
    58  // It panics if len(buf) <= ws.MinHeaderSize + x.
    59  func NewControlWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *ControlWriter {
    60  	max := ws.MaxControlFramePayloadSize + headerSize(state, ws.MaxControlFramePayloadSize)
    61  	if len(buf) > max {
    62  		buf = buf[:max]
    63  	}
    64  
    65  	w := NewWriterBuffer(dest, state, op, buf)
    66  
    67  	return &ControlWriter{
    68  		w:     w,
    69  		limit: len(w.buf),
    70  	}
    71  }
    72  
    73  // Write implements io.Writer. It writes to the underlying Writer until it
    74  // returns error or until ControlWriter write limit will be exceeded.
    75  func (c *ControlWriter) Write(p []byte) (n int, err error) {
    76  	if c.n+len(p) > c.limit {
    77  		return 0, ErrControlOverflow
    78  	}
    79  	return c.w.Write(p)
    80  }
    81  
    82  // Flush flushes all buffered data to the underlying io.Writer.
    83  func (c *ControlWriter) Flush() error {
    84  	return c.w.Flush()
    85  }
    86  
    87  var writers = pool.New(128, 65536)
    88  
    89  // GetWriter tries to reuse Writer getting it from the pool.
    90  //
    91  // This function is intended for memory consumption optimizations, because
    92  // NewWriter*() functions make allocations for inner buffer.
    93  //
    94  // Note the it ceils n to the power of two.
    95  //
    96  // If you have your own bytes buffer pool you could use NewWriterBuffer to use
    97  // pooled bytes in writer.
    98  func GetWriter(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
    99  	x, m := writers.Get(n)
   100  	if x != nil {
   101  		w := x.(*Writer)
   102  		w.Reset(dest, state, op)
   103  		return w
   104  	}
   105  	// NOTE: we use m instead of n, because m is an attempt to reuse w of such
   106  	// size in the future.
   107  	return NewWriterBufferSize(dest, state, op, m)
   108  }
   109  
   110  // PutWriter puts w for future reuse by GetWriter().
   111  func PutWriter(w *Writer) {
   112  	w.Reset(nil, 0, 0)
   113  	writers.Put(w, w.Size())
   114  }
   115  
// Writer contains logic of buffering output data into WebSocket fragments.
// It is much the same as bufio.Writer, except that it works with WebSocket
// frames, not the raw data.
//
// Writer writes frames with specified OpCode.
// It uses ws.State to decide whether the output frames must be masked.
//
// Note that it does not check control frame size or other RFC rules.
// That is, it must be used with special care to write control frames without
// violation of RFC. You could use ControlWriter that wraps Writer and contains
// some guards for writing control frames.
//
// If an error occurs writing to a Writer, no more data will be accepted and
// all subsequent writes will return the error.
//
// After all data has been written, the client should call the Flush() method
// to guarantee all data has been forwarded to the underlying io.Writer.
type Writer struct {
	// dest specifies a destination of buffer flushes.
	dest io.Writer

	// op specifies the WebSocket operation code used in flushed frames.
	// Only the first frame of a message carries it; later fragments are sent
	// with ws.OpContinuation (see opCode).
	op ws.OpCode

	// state specifies the state of the Writer. Its ClientSide() value decides
	// whether frames are masked and how many header bytes are reserved.
	state ws.State

	// extensions is a list of negotiated extensions for writer Dest.
	// It is used to meet the specs and set appropriate bits in fragment
	// header RSV segment.
	extensions []SendExtension

	// noFlush reports whether buffer must grow instead of being flushed.
	noFlush bool

	// Raw representation of the buffer, including reserved header bytes.
	raw []byte

	// Writeable part of buffer, without reserved header bytes.
	// Resetting this to nil will not result in reallocation if raw is not nil.
	// And vice versa: if buf is not nil, then Writer is assumed as ready and
	// initialized.
	buf []byte

	// Buffered bytes counter.
	n int

	// dirty is set by Write, WriteThrough and ReadFrom (on EOF), and cleared
	// by Flush, Reset and ResetOp. When dirty is set, Flush sends a final
	// frame even if nothing is buffered.
	//
	// fseq counts the frames already sent for the current message. A non-zero
	// value makes subsequent frames use ws.OpContinuation.
	//
	// err holds the error from a failed flush or write-through. Once it is set,
	// Write, WriteThrough, Flush and FlushFragment return it without writing.
	dirty bool
	fseq  int
	err   error
}
   167  
   168  // NewWriter returns a new Writer whose buffer has the DefaultWriteBuffer size.
   169  func NewWriter(dest io.Writer, state ws.State, op ws.OpCode) *Writer {
   170  	return NewWriterBufferSize(dest, state, op, 0)
   171  }
   172  
   173  // NewWriterSize returns a new Writer whose buffer size is at most n + ws.MaxHeaderSize.
   174  // That is, output frames payload length could be up to n, except the case when
   175  // Write() is called on empty Writer with len(p) > n.
   176  //
   177  // If n <= 0 then the default buffer size is used as Writer's buffer size.
   178  func NewWriterSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
   179  	if n > 0 {
   180  		n += headerSize(state, n)
   181  	}
   182  	return NewWriterBufferSize(dest, state, op, n)
   183  }
   184  
   185  // NewWriterBufferSize returns a new Writer whose buffer size is equal to n.
   186  // If n <= ws.MinHeaderSize then the default buffer size is used.
   187  //
   188  // Note that Writer will reserve x bytes for header data, where x is in range
   189  // [ws.MinHeaderSize,ws.MaxHeaderSize]. That is, frames flushed by Writer
   190  // will not have payload length equal to n, except the case when Write() is
   191  // called on empty Writer with len(p) > n.
   192  func NewWriterBufferSize(dest io.Writer, state ws.State, op ws.OpCode, n int) *Writer {
   193  	if n <= ws.MinHeaderSize {
   194  		n = DefaultWriteBuffer
   195  	}
   196  	return NewWriterBuffer(dest, state, op, make([]byte, n))
   197  }
   198  
   199  // NewWriterBuffer returns a new Writer with buf as a buffer.
   200  //
   201  // Note that it reserves x bytes of buf for header data, where x is in range
   202  // [ws.MinHeaderSize,ws.MaxHeaderSize] (depending on state and buf size).
   203  //
   204  // You could use ws.HeaderSize() to calculate number of bytes needed to store
   205  // header data.
   206  //
   207  // It panics if len(buf) is too small to fit header and payload data.
   208  func NewWriterBuffer(dest io.Writer, state ws.State, op ws.OpCode, buf []byte) *Writer {
   209  	w := &Writer{
   210  		dest:  dest,
   211  		state: state,
   212  		op:    op,
   213  		raw:   buf,
   214  	}
   215  	w.initBuf()
   216  	return w
   217  }
   218  
   219  func (w *Writer) initBuf() {
   220  	offset := reserve(w.state, len(w.raw))
   221  	if len(w.raw) <= offset {
   222  		panic("wsutil: writer buffer is too small")
   223  	}
   224  	w.buf = w.raw[offset:]
   225  }
   226  
   227  // Reset resets Writer as it was created by New() methods.
   228  // Note that Reset does reset extensions and other options was set after
   229  // Writer initialization.
   230  func (w *Writer) Reset(dest io.Writer, state ws.State, op ws.OpCode) {
   231  	w.dest = dest
   232  	w.state = state
   233  	w.op = op
   234  
   235  	w.initBuf()
   236  
   237  	w.n = 0
   238  	w.dirty = false
   239  	w.fseq = 0
   240  	w.extensions = w.extensions[:0]
   241  	w.noFlush = false
   242  }
   243  
   244  // ResetOp is an quick version of Reset().
   245  // ResetOp does reset unwritten fragments and does not reset results of
   246  // SetExtensions() or DisableFlush() methods.
   247  func (w *Writer) ResetOp(op ws.OpCode) {
   248  	w.op = op
   249  	w.n = 0
   250  	w.dirty = false
   251  	w.fseq = 0
   252  }
   253  
   254  // SetExtensions adds xs as extensions to be used during writes.
   255  func (w *Writer) SetExtensions(xs ...SendExtension) {
   256  	w.extensions = xs
   257  }
   258  
   259  // DisableFlush denies Writer to write fragments.
   260  func (w *Writer) DisableFlush() {
   261  	w.noFlush = true
   262  }
   263  
   264  // Size returns the size of the underlying buffer in bytes (not including
   265  // WebSocket header bytes).
   266  func (w *Writer) Size() int {
   267  	return len(w.buf)
   268  }
   269  
   270  // Available returns how many bytes are unused in the buffer.
   271  func (w *Writer) Available() int {
   272  	return len(w.buf) - w.n
   273  }
   274  
   275  // Buffered returns the number of bytes that have been written into the current
   276  // buffer.
   277  func (w *Writer) Buffered() int {
   278  	return w.n
   279  }
   280  
   281  // Write implements io.Writer.
   282  //
   283  // Note that even if the Writer was created to have N-sized buffer, Write()
   284  // with payload of N bytes will not fit into that buffer. Writer reserves some
   285  // space to fit WebSocket header data.
   286  func (w *Writer) Write(p []byte) (n int, err error) {
   287  	// Even empty p may make a sense.
   288  	w.dirty = true
   289  
   290  	var nn int
   291  	for len(p) > w.Available() && w.err == nil {
   292  		if w.noFlush {
   293  			w.Grow(len(p) - w.Available())
   294  			continue
   295  		}
   296  		if w.Buffered() == 0 {
   297  			// Large write, empty buffer. Write directly from p to avoid copy.
   298  			// Trade off here is that we make additional Write() to underlying
   299  			// io.Writer when writing frame header.
   300  			//
   301  			// On large buffers additional write is better than copying.
   302  			nn, _ = w.WriteThrough(p)
   303  		} else {
   304  			nn = copy(w.buf[w.n:], p)
   305  			w.n += nn
   306  			w.FlushFragment()
   307  		}
   308  		n += nn
   309  		p = p[nn:]
   310  	}
   311  	if w.err != nil {
   312  		return n, w.err
   313  	}
   314  	nn = copy(w.buf[w.n:], p)
   315  	w.n += nn
   316  	n += nn
   317  
   318  	// Even if w.Available() == 0 we will not flush buffer preventively because
   319  	// this could bring unwanted fragmentation. That is, user could create
   320  	// buffer with size that fits exactly all further Write() call, and then
   321  	// call Flush(), excepting that single and not fragmented frame will be
   322  	// sent. With preemptive flush this case will produce two frames – last one
   323  	// will be empty and just to set fin = true.
   324  
   325  	return n, w.err
   326  }
   327  
   328  func ceilPowerOfTwo(n int) int {
   329  	n |= n >> 1
   330  	n |= n >> 2
   331  	n |= n >> 4
   332  	n |= n >> 8
   333  	n |= n >> 16
   334  	n |= n >> 32
   335  	n++
   336  	return n
   337  }
   338  
   339  func (w *Writer) Grow(n int) {
   340  	var (
   341  		offset = len(w.raw) - len(w.buf)
   342  		size   = ceilPowerOfTwo(offset + w.n + n)
   343  	)
   344  	if size <= len(w.raw) {
   345  		panic("wsutil: buffer grow leads to its reduce")
   346  	}
   347  	p := make([]byte, size)
   348  	copy(p, w.raw[:offset+w.n])
   349  	w.raw = p
   350  	w.buf = w.raw[offset:]
   351  }
   352  
   353  // WriteThrough writes data bypassing the buffer.
   354  // Note that Writer's buffer must be empty before calling WriteThrough().
   355  func (w *Writer) WriteThrough(p []byte) (n int, err error) {
   356  	if w.err != nil {
   357  		return 0, w.err
   358  	}
   359  	if w.Buffered() != 0 {
   360  		return 0, ErrNotEmpty
   361  	}
   362  
   363  	var frame ws.Frame
   364  	frame.Header = ws.Header{
   365  		OpCode: w.opCode(),
   366  		Fin:    false,
   367  		Length: int64(len(p)),
   368  	}
   369  	for _, x := range w.extensions {
   370  		frame.Header, err = x.SetBits(frame.Header)
   371  		if err != nil {
   372  			return 0, err
   373  		}
   374  	}
   375  	if w.state.ClientSide() {
   376  		// Should copy bytes to prevent corruption of caller data.
   377  		payload := pbytes.GetLen(len(p))
   378  		defer pbytes.Put(payload)
   379  		copy(payload, p)
   380  
   381  		frame.Payload = payload
   382  		frame = ws.MaskFrameInPlace(frame)
   383  	} else {
   384  		frame.Payload = p
   385  	}
   386  
   387  	w.err = ws.WriteFrame(w.dest, frame)
   388  	if w.err == nil {
   389  		n = len(p)
   390  	}
   391  
   392  	w.dirty = true
   393  	w.fseq++
   394  
   395  	return n, w.err
   396  }
   397  
   398  // ReadFrom implements io.ReaderFrom.
   399  func (w *Writer) ReadFrom(src io.Reader) (n int64, err error) {
   400  	var nn int
   401  	for err == nil {
   402  		if w.Available() == 0 {
   403  			if w.noFlush {
   404  				w.Grow(w.Buffered()) // Twice bigger.
   405  			} else {
   406  				err = w.FlushFragment()
   407  			}
   408  			continue
   409  		}
   410  
   411  		// We copy the behavior of bufio.Writer here.
   412  		// Also, from the docs on io.ReaderFrom:
   413  		//   ReadFrom reads data from r until EOF or error.
   414  		//
   415  		// See https://codereview.appspot.com/76400048/#ps1
   416  		const maxEmptyReads = 100
   417  		var nr int
   418  		for nr < maxEmptyReads {
   419  			nn, err = src.Read(w.buf[w.n:])
   420  			if nn != 0 || err != nil {
   421  				break
   422  			}
   423  			nr++
   424  		}
   425  		if nr == maxEmptyReads {
   426  			return n, io.ErrNoProgress
   427  		}
   428  
   429  		w.n += nn
   430  		n += int64(nn)
   431  	}
   432  	if err == io.EOF {
   433  		// NOTE: Do not flush preemptively.
   434  		// See the Write() sources for more info.
   435  		err = nil
   436  		w.dirty = true
   437  	}
   438  	return n, err
   439  }
   440  
   441  // Flush writes any buffered data to the underlying io.Writer.
   442  // It sends the frame with "fin" flag set to true.
   443  //
   444  // If no Write() or ReadFrom() was made, then Flush() does nothing.
   445  func (w *Writer) Flush() error {
   446  	if (!w.dirty && w.Buffered() == 0) || w.err != nil {
   447  		return w.err
   448  	}
   449  
   450  	w.err = w.flushFragment(true)
   451  	w.n = 0
   452  	w.dirty = false
   453  	w.fseq = 0
   454  
   455  	return w.err
   456  }
   457  
   458  // FlushFragment writes any buffered data to the underlying io.Writer.
   459  // It sends the frame with "fin" flag set to false.
   460  func (w *Writer) FlushFragment() error {
   461  	if w.Buffered() == 0 || w.err != nil {
   462  		return w.err
   463  	}
   464  
   465  	w.err = w.flushFragment(false)
   466  	w.n = 0
   467  	w.fseq++
   468  
   469  	return w.err
   470  }
   471  
// flushFragment sends the buffered payload w.buf[:w.n] to w.dest as a single
// frame with the given fin flag.
//
// The frame header is serialized into the bytes reserved in front of w.buf,
// aligned so that it ends exactly where the payload begins. Header and
// payload are therefore contiguous in w.raw and go out with one dest.Write
// call.
//
// Extensions may adjust the header. An error from one of them is returned
// before anything is masked or written. On the client side the payload is
// masked in place, which modifies the buffered bytes.
//
// It does not reset w.n or w.fseq; Flush and FlushFragment do that.
func (w *Writer) flushFragment(fin bool) (err error) {
	var (
		payload = w.buf[:w.n]
		header  = ws.Header{
			OpCode: w.opCode(),
			Fin:    fin,
			Length: int64(len(payload)),
		}
	)
	for _, ext := range w.extensions {
		header, err = ext.SetBits(header)
		if err != nil {
			return err
		}
	}
	if w.state.ClientSide() {
		// Client-side frames are masked; the mask is applied directly to the
		// buffered payload.
		header.Masked = true
		header.Mask = ws.NewMask()
		ws.Cipher(payload, header.Mask, 0)
	}
	// Write header to the header segment of the raw buffer.
	// offset is the size of the reservation in front of w.buf. skip is where
	// the header must start so that it ends at offset. This relies on the
	// reservation (see reserve/initBuf) being at least ws.HeaderSize(header).
	var (
		offset = len(w.raw) - len(w.buf)
		skip   = offset - ws.HeaderSize(header)
	)
	buf := bytesWriter{
		buf: w.raw[skip:offset],
	}
	if err := ws.WriteHeader(&buf, header); err != nil {
		// Must never be reached.
		panic("dump header error: " + err.Error())
	}
	// Header and payload in a single write.
	_, err = w.dest.Write(w.raw[skip : offset+w.n])
	return err
}
   507  
   508  func (w *Writer) opCode() ws.OpCode {
   509  	if w.fseq > 0 {
   510  		return ws.OpContinuation
   511  	}
   512  	return w.op
   513  }
   514  
   515  var errNoSpace = fmt.Errorf("not enough buffer space")
   516  
   517  type bytesWriter struct {
   518  	buf []byte
   519  	pos int
   520  }
   521  
   522  func (w *bytesWriter) Write(p []byte) (int, error) {
   523  	n := copy(w.buf[w.pos:], p)
   524  	w.pos += n
   525  	if n != len(p) {
   526  		return n, errNoSpace
   527  	}
   528  	return n, nil
   529  }
   530  
   531  func writeFrame(w io.Writer, s ws.State, op ws.OpCode, fin bool, p []byte) error {
   532  	var frame ws.Frame
   533  	if s.ClientSide() {
   534  		// Should copy bytes to prevent corruption of caller data.
   535  		payload := pbytes.GetLen(len(p))
   536  		defer pbytes.Put(payload)
   537  
   538  		copy(payload, p)
   539  
   540  		frame = ws.NewFrame(op, fin, payload)
   541  		frame = ws.MaskFrameInPlace(frame)
   542  	} else {
   543  		frame = ws.NewFrame(op, fin, p)
   544  	}
   545  
   546  	return ws.WriteFrame(w, frame)
   547  }
   548  
   549  func reserve(state ws.State, n int) (offset int) {
   550  	var mask int
   551  	if state.ClientSide() {
   552  		mask = 4
   553  	}
   554  
   555  	switch {
   556  	case n <= int(len7)+mask+2:
   557  		return mask + 2
   558  	case n <= int(len16)+mask+4:
   559  		return mask + 4
   560  	default:
   561  		return mask + 10
   562  	}
   563  }
   564  
   565  // headerSize returns number of bytes needed to encode header of a frame with
   566  // given state and length.
   567  func headerSize(s ws.State, n int) int {
   568  	return ws.HeaderSize(ws.Header{
   569  		Length: int64(n),
   570  		Masked: s.ClientSide(),
   571  	})
   572  }