github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/close_notjs.go (about)

     1  // +build !js
     2  
     3  package websocket
     4  
     5  import (
     6  	"context"
     7  	"encoding/binary"
     8  	"errors"
     9  	"fmt"
    10  	"log"
    11  	"time"
    12  
    13  	"nhooyr.io/websocket/internal/errd"
    14  )
    15  
    16  // Close performs the WebSocket close handshake with the given status code and reason.
    17  //
    18  // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
    19  // the peer to send a close frame.
    20  // All data messages received from the peer during the close handshake will be discarded.
    21  //
    22  // The connection can only be closed once. Additional calls to Close
    23  // are no-ops.
    24  //
    25  // The maximum length of reason must be 125 bytes. Avoid
    26  // sending a dynamic reason.
    27  //
    28  // Close will unblock all goroutines interacting with the connection once
    29  // complete.
    30  func (c *Conn) Close(code StatusCode, reason string) error {
    31  	return c.closeHandshake(code, reason)
    32  }
    33  
    34  func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
    35  	defer errd.Wrap(&err, "failed to close WebSocket")
    36  
    37  	writeErr := c.writeClose(code, reason)
    38  	closeHandshakeErr := c.waitCloseHandshake()
    39  
    40  	if writeErr != nil {
    41  		return writeErr
    42  	}
    43  
    44  	if CloseStatus(closeHandshakeErr) == -1 {
    45  		return closeHandshakeErr
    46  	}
    47  
    48  	return nil
    49  }
    50  
    51  var errAlreadyWroteClose = errors.New("already wrote close")
    52  
    53  func (c *Conn) writeClose(code StatusCode, reason string) error {
    54  	c.closeMu.Lock()
    55  	wroteClose := c.wroteClose
    56  	c.wroteClose = true
    57  	c.closeMu.Unlock()
    58  	if wroteClose {
    59  		return errAlreadyWroteClose
    60  	}
    61  
    62  	ce := CloseError{
    63  		Code:   code,
    64  		Reason: reason,
    65  	}
    66  
    67  	var p []byte
    68  	var marshalErr error
    69  	if ce.Code != StatusNoStatusRcvd {
    70  		p, marshalErr = ce.bytes()
    71  		if marshalErr != nil {
    72  			log.Printf("websocket: %v", marshalErr)
    73  		}
    74  	}
    75  
    76  	writeErr := c.writeControl(context.Background(), opClose, p)
    77  	if CloseStatus(writeErr) != -1 {
    78  		// Not a real error if it's due to a close frame being received.
    79  		writeErr = nil
    80  	}
    81  
    82  	// We do this after in case there was an error writing the close frame.
    83  	c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
    84  
    85  	if marshalErr != nil {
    86  		return marshalErr
    87  	}
    88  	return writeErr
    89  }
    90  
    91  func (c *Conn) waitCloseHandshake() error {
    92  	defer c.close(nil)
    93  
    94  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
    95  	defer cancel()
    96  
    97  	err := c.readMu.lock(ctx)
    98  	if err != nil {
    99  		return err
   100  	}
   101  	defer c.readMu.unlock()
   102  
   103  	if c.readCloseFrameErr != nil {
   104  		return c.readCloseFrameErr
   105  	}
   106  
   107  	for {
   108  		h, err := c.readLoop(ctx)
   109  		if err != nil {
   110  			return err
   111  		}
   112  
   113  		for i := int64(0); i < h.payloadLength; i++ {
   114  			_, err := c.br.ReadByte()
   115  			if err != nil {
   116  				return err
   117  			}
   118  		}
   119  	}
   120  }
   121  
   122  func parseClosePayload(p []byte) (CloseError, error) {
   123  	if len(p) == 0 {
   124  		return CloseError{
   125  			Code: StatusNoStatusRcvd,
   126  		}, nil
   127  	}
   128  
   129  	if len(p) < 2 {
   130  		return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
   131  	}
   132  
   133  	ce := CloseError{
   134  		Code:   StatusCode(binary.BigEndian.Uint16(p)),
   135  		Reason: string(p[2:]),
   136  	}
   137  
   138  	if !validWireCloseCode(ce.Code) {
   139  		return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
   140  	}
   141  
   142  	return ce, nil
   143  }
   144  
   145  // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
   146  // and https://tools.ietf.org/html/rfc6455#section-7.4.1
   147  func validWireCloseCode(code StatusCode) bool {
   148  	switch code {
   149  	case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
   150  		return false
   151  	}
   152  
   153  	if code >= StatusNormalClosure && code <= StatusBadGateway {
   154  		return true
   155  	}
   156  	if code >= 3000 && code <= 4999 {
   157  		return true
   158  	}
   159  
   160  	return false
   161  }
   162  
   163  func (ce CloseError) bytes() ([]byte, error) {
   164  	p, err := ce.bytesErr()
   165  	if err != nil {
   166  		err = fmt.Errorf("failed to marshal close frame: %w", err)
   167  		ce = CloseError{
   168  			Code: StatusInternalError,
   169  		}
   170  		p, _ = ce.bytesErr()
   171  	}
   172  	return p, err
   173  }
   174  
   175  const maxCloseReason = maxControlPayload - 2
   176  
   177  func (ce CloseError) bytesErr() ([]byte, error) {
   178  	if len(ce.Reason) > maxCloseReason {
   179  		return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
   180  	}
   181  
   182  	if !validWireCloseCode(ce.Code) {
   183  		return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
   184  	}
   185  
   186  	buf := make([]byte, 2+len(ce.Reason))
   187  	binary.BigEndian.PutUint16(buf, uint16(ce.Code))
   188  	copy(buf[2:], ce.Reason)
   189  	return buf, nil
   190  }
   191  
   192  func (c *Conn) setCloseErr(err error) {
   193  	c.closeMu.Lock()
   194  	c.setCloseErrLocked(err)
   195  	c.closeMu.Unlock()
   196  }
   197  
   198  func (c *Conn) setCloseErrLocked(err error) {
   199  	if c.closeErr == nil {
   200  		c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
   201  	}
   202  }
   203  
   204  func (c *Conn) isClosed() bool {
   205  	select {
   206  	case <-c.closed:
   207  		return true
   208  	default:
   209  		return false
   210  	}
   211  }