github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/write.go (about)

     1  // +build !js
     2  
     3  package websocket
     4  
     5  import (
     6  	"bufio"
     7  	"context"
     8  	"crypto/rand"
     9  	"encoding/binary"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"time"
    14  
    15  	"github.com/klauspost/compress/flate"
    16  
    17  	"nhooyr.io/websocket/internal/errd"
    18  )
    19  
    20  // Writer returns a writer bounded by the context that will write
    21  // a WebSocket message of type dataType to the connection.
    22  //
    23  // You must close the writer once you have written the entire message.
    24  //
    25  // Only one writer can be open at a time, multiple calls will block until the previous writer
    26  // is closed.
    27  func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
    28  	w, err := c.writer(ctx, typ)
    29  	if err != nil {
    30  		return nil, fmt.Errorf("failed to get writer: %w", err)
    31  	}
    32  	return w, nil
    33  }
    34  
    35  // Write writes a message to the connection.
    36  //
    37  // See the Writer method if you want to stream a message.
    38  //
    39  // If compression is disabled or the threshold is not met, then it
    40  // will write the message in a single frame.
    41  func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
    42  	_, err := c.write(ctx, typ, p)
    43  	if err != nil {
    44  		return fmt.Errorf("failed to write msg: %w", err)
    45  	}
    46  	return nil
    47  }
    48  
    49  type msgWriter struct {
    50  	mw     *msgWriterState
    51  	closed bool
    52  }
    53  
    54  func (mw *msgWriter) Write(p []byte) (int, error) {
    55  	if mw.closed {
    56  		return 0, errors.New("cannot use closed writer")
    57  	}
    58  	return mw.mw.Write(p)
    59  }
    60  
    61  func (mw *msgWriter) Close() error {
    62  	if mw.closed {
    63  		return errors.New("cannot use closed writer")
    64  	}
    65  	mw.closed = true
    66  	return mw.mw.Close()
    67  }
    68  
    69  type msgWriterState struct {
    70  	c *Conn
    71  
    72  	mu      *mu
    73  	writeMu *mu
    74  
    75  	ctx    context.Context
    76  	opcode opcode
    77  	flate  bool
    78  
    79  	trimWriter *trimLastFourBytesWriter
    80  	dict       slidingWindow
    81  }
    82  
    83  func newMsgWriterState(c *Conn) *msgWriterState {
    84  	mw := &msgWriterState{
    85  		c:       c,
    86  		mu:      newMu(c),
    87  		writeMu: newMu(c),
    88  	}
    89  	return mw
    90  }
    91  
    92  func (mw *msgWriterState) ensureFlate() {
    93  	if mw.trimWriter == nil {
    94  		mw.trimWriter = &trimLastFourBytesWriter{
    95  			w: writerFunc(mw.write),
    96  		}
    97  	}
    98  
    99  	mw.dict.init(8192)
   100  	mw.flate = true
   101  }
   102  
   103  func (mw *msgWriterState) flateContextTakeover() bool {
   104  	if mw.c.client {
   105  		return !mw.c.copts.clientNoContextTakeover
   106  	}
   107  	return !mw.c.copts.serverNoContextTakeover
   108  }
   109  
   110  func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
   111  	err := c.msgWriterState.reset(ctx, typ)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	return &msgWriter{
   116  		mw:     c.msgWriterState,
   117  		closed: false,
   118  	}, nil
   119  }
   120  
   121  func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
   122  	mw, err := c.writer(ctx, typ)
   123  	if err != nil {
   124  		return 0, err
   125  	}
   126  
   127  	if !c.flate() {
   128  		defer c.msgWriterState.mu.unlock()
   129  		return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p)
   130  	}
   131  
   132  	n, err := mw.Write(p)
   133  	if err != nil {
   134  		return n, err
   135  	}
   136  
   137  	err = mw.Close()
   138  	return n, err
   139  }
   140  
   141  func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
   142  	err := mw.mu.lock(ctx)
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	mw.ctx = ctx
   148  	mw.opcode = opcode(typ)
   149  	mw.flate = false
   150  
   151  	mw.trimWriter.reset()
   152  
   153  	return nil
   154  }
   155  
   156  // Write writes the given bytes to the WebSocket connection.
   157  func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
   158  	err = mw.writeMu.lock(mw.ctx)
   159  	if err != nil {
   160  		return 0, fmt.Errorf("failed to write: %w", err)
   161  	}
   162  	defer mw.writeMu.unlock()
   163  
   164  	defer func() {
   165  		if err != nil {
   166  			err = fmt.Errorf("failed to write: %w", err)
   167  			mw.c.close(err)
   168  		}
   169  	}()
   170  
   171  	if mw.c.flate() {
   172  		// Only enables flate if the length crosses the
   173  		// threshold on the first frame
   174  		if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
   175  			mw.ensureFlate()
   176  		}
   177  	}
   178  
   179  	if mw.flate {
   180  		err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf)
   181  		if err != nil {
   182  			return 0, err
   183  		}
   184  		mw.dict.write(p)
   185  		return len(p), nil
   186  	}
   187  
   188  	return mw.write(p)
   189  }
   190  
   191  func (mw *msgWriterState) write(p []byte) (int, error) {
   192  	n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
   193  	if err != nil {
   194  		return n, fmt.Errorf("failed to write data frame: %w", err)
   195  	}
   196  	mw.opcode = opContinuation
   197  	return n, nil
   198  }
   199  
   200  // Close flushes the frame to the connection.
   201  func (mw *msgWriterState) Close() (err error) {
   202  	defer errd.Wrap(&err, "failed to close writer")
   203  
   204  	err = mw.writeMu.lock(mw.ctx)
   205  	if err != nil {
   206  		return err
   207  	}
   208  	defer mw.writeMu.unlock()
   209  
   210  	_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
   211  	if err != nil {
   212  		return fmt.Errorf("failed to write fin frame: %w", err)
   213  	}
   214  
   215  	if mw.flate && !mw.flateContextTakeover() {
   216  		mw.dict.close()
   217  	}
   218  	mw.mu.unlock()
   219  	return nil
   220  }
   221  
   222  func (mw *msgWriterState) close() {
   223  	if mw.c.client {
   224  		mw.c.writeFrameMu.forceLock()
   225  		putBufioWriter(mw.c.bw)
   226  	}
   227  
   228  	mw.writeMu.forceLock()
   229  	mw.dict.close()
   230  }
   231  
   232  func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
   233  	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
   234  	defer cancel()
   235  
   236  	_, err := c.writeFrame(ctx, true, false, opcode, p)
   237  	if err != nil {
   238  		return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
   239  	}
   240  	return nil
   241  }
   242  
   243  // frame handles all writes to the connection.
   244  func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
   245  	err = c.writeFrameMu.lock(ctx)
   246  	if err != nil {
   247  		return 0, err
   248  	}
   249  	defer c.writeFrameMu.unlock()
   250  
   251  	// If the state says a close has already been written, we wait until
   252  	// the connection is closed and return that error.
   253  	//
   254  	// However, if the frame being written is a close, that means its the close from
   255  	// the state being set so we let it go through.
   256  	c.closeMu.Lock()
   257  	wroteClose := c.wroteClose
   258  	c.closeMu.Unlock()
   259  	if wroteClose && opcode != opClose {
   260  		select {
   261  		case <-ctx.Done():
   262  			return 0, ctx.Err()
   263  		case <-c.closed:
   264  			return 0, c.closeErr
   265  		}
   266  	}
   267  
   268  	select {
   269  	case <-c.closed:
   270  		return 0, c.closeErr
   271  	case c.writeTimeout <- ctx:
   272  	}
   273  
   274  	defer func() {
   275  		if err != nil {
   276  			select {
   277  			case <-c.closed:
   278  				err = c.closeErr
   279  			case <-ctx.Done():
   280  				err = ctx.Err()
   281  			}
   282  			c.close(err)
   283  			err = fmt.Errorf("failed to write frame: %w", err)
   284  		}
   285  	}()
   286  
   287  	c.writeHeader.fin = fin
   288  	c.writeHeader.opcode = opcode
   289  	c.writeHeader.payloadLength = int64(len(p))
   290  
   291  	if c.client {
   292  		c.writeHeader.masked = true
   293  		_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
   294  		if err != nil {
   295  			return 0, fmt.Errorf("failed to generate masking key: %w", err)
   296  		}
   297  		c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
   298  	}
   299  
   300  	c.writeHeader.rsv1 = false
   301  	if flate && (opcode == opText || opcode == opBinary) {
   302  		c.writeHeader.rsv1 = true
   303  	}
   304  
   305  	err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
   306  	if err != nil {
   307  		return 0, err
   308  	}
   309  
   310  	n, err := c.writeFramePayload(p)
   311  	if err != nil {
   312  		return n, err
   313  	}
   314  
   315  	if c.writeHeader.fin {
   316  		err = c.bw.Flush()
   317  		if err != nil {
   318  			return n, fmt.Errorf("failed to flush: %w", err)
   319  		}
   320  	}
   321  
   322  	select {
   323  	case <-c.closed:
   324  		return n, c.closeErr
   325  	case c.writeTimeout <- context.Background():
   326  	}
   327  
   328  	return n, nil
   329  }
   330  
   331  func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
   332  	defer errd.Wrap(&err, "failed to write frame payload")
   333  
   334  	if !c.writeHeader.masked {
   335  		return c.bw.Write(p)
   336  	}
   337  
   338  	maskKey := c.writeHeader.maskKey
   339  	for len(p) > 0 {
   340  		// If the buffer is full, we need to flush.
   341  		if c.bw.Available() == 0 {
   342  			err = c.bw.Flush()
   343  			if err != nil {
   344  				return n, err
   345  			}
   346  		}
   347  
   348  		// Start of next write in the buffer.
   349  		i := c.bw.Buffered()
   350  
   351  		j := len(p)
   352  		if j > c.bw.Available() {
   353  			j = c.bw.Available()
   354  		}
   355  
   356  		_, err := c.bw.Write(p[:j])
   357  		if err != nil {
   358  			return n, err
   359  		}
   360  
   361  		maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()])
   362  
   363  		p = p[j:]
   364  		n += j
   365  	}
   366  
   367  	return n, nil
   368  }
   369  
   370  type writerFunc func(p []byte) (int, error)
   371  
   372  func (f writerFunc) Write(p []byte) (int, error) {
   373  	return f(p)
   374  }
   375  
   376  // extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
   377  // and returns it.
   378  func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
   379  	var writeBuf []byte
   380  	bw.Reset(writerFunc(func(p2 []byte) (int, error) {
   381  		writeBuf = p2[:cap(p2)]
   382  		return len(p2), nil
   383  	}))
   384  
   385  	bw.WriteByte(0)
   386  	bw.Flush()
   387  
   388  	bw.Reset(w)
   389  
   390  	return writeBuf
   391  }
   392  
   393  func (c *Conn) writeError(code StatusCode, err error) {
   394  	c.setCloseErr(err)
   395  	c.writeClose(code, err.Error())
   396  	c.close(nil)
   397  }