github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/conn_notjs.go

//go:build !js
// +build !js

package websocket

import (
	"bufio"
	"context"
	"errors"
	"fmt"
	"io"
	"runtime"
	"strconv"
	"sync"
	"sync/atomic"
)

// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
type Conn struct {
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions
	flateThreshold int
	br             *bufio.Reader
	bw             *bufio.Writer

	// timeoutLoop receives contexts on these channels to arm and disarm
	// the read and write deadlines.
	readTimeout  chan context.Context
	writeTimeout chan context.Context

	// Read state.
	readMu            *mu
	readHeaderBuf     [8]byte
	readControlBuf    [maxControlPayload]byte
	msgReader         *msgReader
	readCloseFrameErr error

	// Write state.
	msgWriterState *msgWriterState
	writeFrameMu   *mu
	writeBuf       []byte
	writeHeaderBuf [8]byte
	writeHeader    header

	closed     chan struct{}
	closeMu    sync.Mutex
	closeErr   error
	wroteClose bool

	pingCounter   int32
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}
}
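
// The function below is an illustrative sketch added for documentation and
// is not part of the original file. It shows the read-loop discipline the
// Conn doc comment above requires, using the package's exported Reader,
// Write and Close methods: echo every message back until the connection
// errors.
func exampleEchoLoop(ctx context.Context, c *Conn) error {
	defer c.Close(StatusInternalError, "echo loop exited")

	for {
		// Reader must keep being called so that control frames
		// (pings, pongs, close) are processed.
		typ, r, err := c.Reader(ctx)
		if err != nil {
			return err
		}
		msg, err := io.ReadAll(r)
		if err != nil {
			return err
		}
		if err := c.Write(ctx, typ, msg); err != nil {
			return err
		}
	}
}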

type connConfig struct {
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions
	flateThreshold int

	br *bufio.Reader
	bw *bufio.Writer
}

func newConn(cfg connConfig) *Conn {
	c := &Conn{
		subprotocol:    cfg.subprotocol,
		rwc:            cfg.rwc,
		client:         cfg.client,
		copts:          cfg.copts,
		flateThreshold: cfg.flateThreshold,

		br: cfg.br,
		bw: cfg.bw,

		readTimeout:  make(chan context.Context),
		writeTimeout: make(chan context.Context),

		closed:      make(chan struct{}),
		activePings: make(map[string]chan<- struct{}),
	}

	c.readMu = newMu(c)
	c.writeFrameMu = newMu(c)

	c.msgReader = newMsgReader(c)

	c.msgWriterState = newMsgWriterState(c)
	if c.client {
		// Only the client needs a separate write buffer, to mask
		// frames in place.
		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
	}

	// Pick a default flate threshold: 128 bytes with context takeover,
	// 512 bytes without.
	if c.flate() && c.flateThreshold == 0 {
		c.flateThreshold = 128
		if !c.msgWriterState.flateContextTakeover() {
			c.flateThreshold = 512
		}
	}

	// Close leaked connections so their resources are released.
	runtime.SetFinalizer(c, func(c *Conn) {
		c.close(errors.New("connection garbage collected"))
	})

	go c.timeoutLoop()

	return c
}
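
// The function below is an illustrative sketch added for documentation and
// is not part of the original file. It shows how a handshake path would be
// expected to hand an upgraded connection to newConn; the real wiring lives
// in this package's Accept and Dial code, and the field values here are
// placeholders.
func sketchNewServerConn(rwc io.ReadWriteCloser, br *bufio.Reader, bw *bufio.Writer) *Conn {
	return newConn(connConfig{
		subprotocol:    "",    // the negotiated Sec-WebSocket-Protocol, if any
		rwc:            rwc,   // the underlying network connection
		client:         false, // server side, so frames are written unmasked
		copts:          nil,   // non-nil only if permessage-deflate was negotiated
		flateThreshold: 0,     // 0 lets newConn choose a default threshold
		br:             br,
		bw:             bw,
	})
}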

// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
	return c.subprotocol
}

func (c *Conn) close(err error) {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()

	if c.isClosed() {
		return
	}
	c.setCloseErrLocked(err)
	close(c.closed)
	runtime.SetFinalizer(c, nil)

	// The underlying connection must be closed after c.closed is closed so
	// that any goroutine woken by the connection closing also sees that
	// c.closed is closed and returns closeErr.
	c.rwc.Close()

	go func() {
		c.msgWriterState.close()

		c.msgReader.close()
	}()
}

// timeoutLoop arms and disarms the read and write deadlines. Methods doing
// I/O send their context on readTimeout or writeTimeout before starting and
// send context.Background() when done; if an armed context expires first,
// the connection is failed.
func (c *Conn) timeoutLoop() {
	readCtx := context.Background()
	writeCtx := context.Background()

	for {
		select {
		case <-c.closed:
			return

		case writeCtx = <-c.writeTimeout:
		case readCtx = <-c.readTimeout:

		case <-readCtx.Done():
			c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
			go c.writeError(StatusPolicyViolation, errors.New("timed out"))
		case <-writeCtx.Done():
			c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
			return
		}
	}
}
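
// The function below is an illustrative sketch added for documentation and
// is not part of the original file. It shows the deadline protocol that
// timeoutLoop expects from I/O methods: arm the deadline with the caller's
// context before the blocking call and disarm it with context.Background()
// afterwards. The real plumbing lives in this package's read and write
// paths.
func sketchWithDeadline(ctx context.Context, c *Conn, op func() error) error {
	select {
	case <-c.closed:
		return c.closeErr
	case c.writeTimeout <- ctx: // arm: timeoutLoop now watches ctx.Done()
	}

	err := op()

	select {
	case <-c.closed:
		return c.closeErr
	case c.writeTimeout <- context.Background(): // disarm the deadline
	}
	return err
}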

// flate reports whether the negotiated permessage-deflate extension is in use.
func (c *Conn) flate() bool {
	return c.copts != nil
}

// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or to check that the peer is responsive.
//
// Ping must be called concurrently with Reader: it does not read from the
// connection itself, so a concurrent Reader (or CloseRead) call is needed
// to read the pong.
//
// TCP keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
	p := atomic.AddInt32(&c.pingCounter, 1)

	err := c.ping(ctx, strconv.Itoa(int(p)))
	if err != nil {
		return fmt.Errorf("failed to ping: %w", err)
	}
	return nil
}
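
// The function below is an illustrative sketch added for documentation and
// is not part of the original file. Because Ping needs a concurrent reader,
// a heartbeat loop typically pairs Ping with CloseRead (or a Reader loop).
// The tick channel is a placeholder for any timer source the caller owns,
// e.g. driven by a time.Ticker.
func sketchHeartbeat(ctx context.Context, c *Conn, tick <-chan struct{}) error {
	// CloseRead starts a goroutine that keeps reading so control frames,
	// including our pongs, are processed.
	ctx = c.CloseRead(ctx)

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-tick:
			if err := c.Ping(ctx); err != nil {
				return err
			}
		}
	}
}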

func (c *Conn) ping(ctx context.Context, p string) error {
	// Signalled by the control-frame reader when the matching pong arrives.
	pong := make(chan struct{}, 1)

	c.activePingsMu.Lock()
	c.activePings[p] = pong
	c.activePingsMu.Unlock()

	defer func() {
		c.activePingsMu.Lock()
		delete(c.activePings, p)
		c.activePingsMu.Unlock()
	}()

	err := c.writeControl(ctx, opPing, []byte(p))
	if err != nil {
		return err
	}

	select {
	case <-c.closed:
		return c.closeErr
	case <-ctx.Done():
		err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
		c.close(err)
		return err
	case <-pong:
		return nil
	}
}

// mu is a mutex built on a channel so that acquisition can respect both
// context cancellation and connection closure.
type mu struct {
	c  *Conn
	ch chan struct{}
}

func newMu(c *Conn) *mu {
	return &mu{
		c:  c,
		ch: make(chan struct{}, 1),
	}
}

// forceLock acquires the mutex unconditionally, without watching for
// cancellation or close.
func (m *mu) forceLock() {
	m.ch <- struct{}{}
}

func (m *mu) lock(ctx context.Context) error {
	select {
	case <-m.c.closed:
		return m.c.closeErr
	case <-ctx.Done():
		err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
		m.c.close(err)
		return err
	case m.ch <- struct{}{}:
		// Re-check that the connection is still alive: the send on m.ch
		// may have been selected even though the receive on closed was
		// also ready.
		select {
		case <-m.c.closed:
			// Make sure to release the lock.
			m.unlock()
			return m.c.closeErr
		default:
		}
		return nil
	}
}

// unlock releases the mutex; it is a no-op if the mutex is not held.
func (m *mu) unlock() {
	select {
	case <-m.ch:
	default:
	}
}
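
// The function below is an illustrative sketch added for documentation and
// is not part of the original file. It shows the intended mu discipline:
// acquire under the caller's context, always release, and let a failed lock
// surface the connection's close error.
func sketchWithWriteFrameLock(ctx context.Context, c *Conn, op func() error) error {
	if err := c.writeFrameMu.lock(ctx); err != nil {
		// The connection was closed or ctx expired; lock has already
		// reported (and possibly caused) the close.
		return err
	}
	defer c.writeFrameMu.unlock()
	return op()
}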