github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/ws_js.go (about)

     1  package websocket // import "nhooyr.io/websocket"
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"reflect"
    11  	"runtime"
    12  	"strings"
    13  	"sync"
    14  	"syscall/js"
    15  
    16  	"nhooyr.io/websocket/internal/bpool"
    17  	"nhooyr.io/websocket/internal/wsjs"
    18  	"nhooyr.io/websocket/internal/xsync"
    19  )
    20  
    21  // Conn provides a wrapper around the browser WebSocket API.
    22  type Conn struct {
    23  	ws wsjs.WebSocket
    24  
    25  	// read limit for a message in bytes.
    26  	msgReadLimit xsync.Int64
    27  
    28  	closingMu     sync.Mutex
    29  	isReadClosed  xsync.Int64
    30  	closeOnce     sync.Once
    31  	closed        chan struct{}
    32  	closeErrOnce  sync.Once
    33  	closeErr      error
    34  	closeWasClean bool
    35  
    36  	releaseOnClose   func()
    37  	releaseOnMessage func()
    38  
    39  	readSignal chan struct{}
    40  	readBufMu  sync.Mutex
    41  	readBuf    []wsjs.MessageEvent
    42  }
    43  
    44  func (c *Conn) close(err error, wasClean bool) {
    45  	c.closeOnce.Do(func() {
    46  		runtime.SetFinalizer(c, nil)
    47  
    48  		if !wasClean {
    49  			err = fmt.Errorf("unclean connection close: %w", err)
    50  		}
    51  		c.setCloseErr(err)
    52  		c.closeWasClean = wasClean
    53  		close(c.closed)
    54  	})
    55  }
    56  
    57  func (c *Conn) init() {
    58  	c.closed = make(chan struct{})
    59  	c.readSignal = make(chan struct{}, 1)
    60  
    61  	c.msgReadLimit.Store(32768)
    62  
    63  	c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
    64  		err := CloseError{
    65  			Code:   StatusCode(e.Code),
    66  			Reason: e.Reason,
    67  		}
    68  		// We do not know if we sent or received this close as
    69  		// its possible the browser triggered it without us
    70  		// explicitly sending it.
    71  		c.close(err, e.WasClean)
    72  
    73  		c.releaseOnClose()
    74  		c.releaseOnMessage()
    75  	})
    76  
    77  	c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
    78  		c.readBufMu.Lock()
    79  		defer c.readBufMu.Unlock()
    80  
    81  		c.readBuf = append(c.readBuf, e)
    82  
    83  		// Lets the read goroutine know there is definitely something in readBuf.
    84  		select {
    85  		case c.readSignal <- struct{}{}:
    86  		default:
    87  		}
    88  	})
    89  
    90  	runtime.SetFinalizer(c, func(c *Conn) {
    91  		c.setCloseErr(errors.New("connection garbage collected"))
    92  		c.closeWithInternal()
    93  	})
    94  }
    95  
    96  func (c *Conn) closeWithInternal() {
    97  	c.Close(StatusInternalError, "something went wrong")
    98  }
    99  
   100  // Read attempts to read a message from the connection.
   101  // The maximum time spent waiting is bounded by the context.
   102  func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
   103  	if c.isReadClosed.Load() == 1 {
   104  		return 0, nil, errors.New("WebSocket connection read closed")
   105  	}
   106  
   107  	typ, p, err := c.read(ctx)
   108  	if err != nil {
   109  		return 0, nil, fmt.Errorf("failed to read: %w", err)
   110  	}
   111  	if int64(len(p)) > c.msgReadLimit.Load() {
   112  		err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
   113  		c.Close(StatusMessageTooBig, err.Error())
   114  		return 0, nil, err
   115  	}
   116  	return typ, p, nil
   117  }
   118  
   119  func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
   120  	select {
   121  	case <-ctx.Done():
   122  		c.Close(StatusPolicyViolation, "read timed out")
   123  		return 0, nil, ctx.Err()
   124  	case <-c.readSignal:
   125  	case <-c.closed:
   126  		return 0, nil, c.closeErr
   127  	}
   128  
   129  	c.readBufMu.Lock()
   130  	defer c.readBufMu.Unlock()
   131  
   132  	me := c.readBuf[0]
   133  	// We copy the messages forward and decrease the size
   134  	// of the slice to avoid reallocating.
   135  	copy(c.readBuf, c.readBuf[1:])
   136  	c.readBuf = c.readBuf[:len(c.readBuf)-1]
   137  
   138  	if len(c.readBuf) > 0 {
   139  		// Next time we read, we'll grab the message.
   140  		select {
   141  		case c.readSignal <- struct{}{}:
   142  		default:
   143  		}
   144  	}
   145  
   146  	switch p := me.Data.(type) {
   147  	case string:
   148  		return MessageText, []byte(p), nil
   149  	case []byte:
   150  		return MessageBinary, p, nil
   151  	default:
   152  		panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
   153  	}
   154  }
   155  
   156  // Ping is mocked out for Wasm.
   157  func (c *Conn) Ping(ctx context.Context) error {
   158  	return nil
   159  }
   160  
   161  // Write writes a message of the given type to the connection.
   162  // Always non blocking.
   163  func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
   164  	err := c.write(ctx, typ, p)
   165  	if err != nil {
   166  		// Have to ensure the WebSocket is closed after a write error
   167  		// to match the Go API. It can only error if the message type
   168  		// is unexpected or the passed bytes contain invalid UTF-8 for
   169  		// MessageText.
   170  		err := fmt.Errorf("failed to write: %w", err)
   171  		c.setCloseErr(err)
   172  		c.closeWithInternal()
   173  		return err
   174  	}
   175  	return nil
   176  }
   177  
   178  func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
   179  	if c.isClosed() {
   180  		return c.closeErr
   181  	}
   182  	switch typ {
   183  	case MessageBinary:
   184  		return c.ws.SendBytes(p)
   185  	case MessageText:
   186  		return c.ws.SendText(string(p))
   187  	default:
   188  		return fmt.Errorf("unexpected message type: %v", typ)
   189  	}
   190  }
   191  
   192  // Close closes the WebSocket with the given code and reason.
   193  // It will wait until the peer responds with a close frame
   194  // or the connection is closed.
   195  // It thus performs the full WebSocket close handshake.
   196  func (c *Conn) Close(code StatusCode, reason string) error {
   197  	err := c.exportedClose(code, reason)
   198  	if err != nil {
   199  		return fmt.Errorf("failed to close WebSocket: %w", err)
   200  	}
   201  	return nil
   202  }
   203  
   204  func (c *Conn) exportedClose(code StatusCode, reason string) error {
   205  	c.closingMu.Lock()
   206  	defer c.closingMu.Unlock()
   207  
   208  	ce := fmt.Errorf("sent close: %w", CloseError{
   209  		Code:   code,
   210  		Reason: reason,
   211  	})
   212  
   213  	if c.isClosed() {
   214  		return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr)
   215  	}
   216  
   217  	c.setCloseErr(ce)
   218  	err := c.ws.Close(int(code), reason)
   219  	if err != nil {
   220  		return err
   221  	}
   222  
   223  	<-c.closed
   224  	if !c.closeWasClean {
   225  		return c.closeErr
   226  	}
   227  	return nil
   228  }
   229  
   230  // Subprotocol returns the negotiated subprotocol.
   231  // An empty string means the default protocol.
   232  func (c *Conn) Subprotocol() string {
   233  	return c.ws.Subprotocol()
   234  }
   235  
   236  // DialOptions represents the options available to pass to Dial.
   237  type DialOptions struct {
   238  	// Subprotocols lists the subprotocols to negotiate with the server.
   239  	Subprotocols []string
   240  }
   241  
   242  // Dial creates a new WebSocket connection to the given url with the given options.
   243  // The passed context bounds the maximum time spent waiting for the connection to open.
   244  // The returned *http.Response is always nil or a mock. It's only in the signature
   245  // to match the core API.
   246  func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
   247  	c, resp, err := dial(ctx, url, opts)
   248  	if err != nil {
   249  		return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
   250  	}
   251  	return c, resp, nil
   252  }
   253  
   254  func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
   255  	if opts == nil {
   256  		opts = &DialOptions{}
   257  	}
   258  
   259  	url = strings.Replace(url, "http://", "ws://", 1)
   260  	url = strings.Replace(url, "https://", "wss://", 1)
   261  
   262  	ws, err := wsjs.New(url, opts.Subprotocols)
   263  	if err != nil {
   264  		return nil, nil, err
   265  	}
   266  
   267  	c := &Conn{
   268  		ws: ws,
   269  	}
   270  	c.init()
   271  
   272  	opench := make(chan struct{})
   273  	releaseOpen := ws.OnOpen(func(e js.Value) {
   274  		close(opench)
   275  	})
   276  	defer releaseOpen()
   277  
   278  	select {
   279  	case <-ctx.Done():
   280  		c.Close(StatusPolicyViolation, "dial timed out")
   281  		return nil, nil, ctx.Err()
   282  	case <-opench:
   283  		return c, &http.Response{
   284  			StatusCode: http.StatusSwitchingProtocols,
   285  		}, nil
   286  	case <-c.closed:
   287  		return nil, nil, c.closeErr
   288  	}
   289  }
   290  
   291  // Reader attempts to read a message from the connection.
   292  // The maximum time spent waiting is bounded by the context.
   293  func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
   294  	typ, p, err := c.Read(ctx)
   295  	if err != nil {
   296  		return 0, nil, err
   297  	}
   298  	return typ, bytes.NewReader(p), nil
   299  }
   300  
   301  // Writer returns a writer to write a WebSocket data message to the connection.
   302  // It buffers the entire message in memory and then sends it when the writer
   303  // is closed.
   304  func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
   305  	return writer{
   306  		c:   c,
   307  		ctx: ctx,
   308  		typ: typ,
   309  		b:   bpool.Get(),
   310  	}, nil
   311  }
   312  
   313  type writer struct {
   314  	closed bool
   315  
   316  	c   *Conn
   317  	ctx context.Context
   318  	typ MessageType
   319  
   320  	b *bytes.Buffer
   321  }
   322  
   323  func (w writer) Write(p []byte) (int, error) {
   324  	if w.closed {
   325  		return 0, errors.New("cannot write to closed writer")
   326  	}
   327  	n, err := w.b.Write(p)
   328  	if err != nil {
   329  		return n, fmt.Errorf("failed to write message: %w", err)
   330  	}
   331  	return n, nil
   332  }
   333  
   334  func (w writer) Close() error {
   335  	if w.closed {
   336  		return errors.New("cannot close closed writer")
   337  	}
   338  	w.closed = true
   339  	defer bpool.Put(w.b)
   340  
   341  	err := w.c.Write(w.ctx, w.typ, w.b.Bytes())
   342  	if err != nil {
   343  		return fmt.Errorf("failed to close writer: %w", err)
   344  	}
   345  	return nil
   346  }
   347  
   348  // CloseRead implements *Conn.CloseRead for wasm.
   349  func (c *Conn) CloseRead(ctx context.Context) context.Context {
   350  	c.isReadClosed.Store(1)
   351  
   352  	ctx, cancel := context.WithCancel(ctx)
   353  	go func() {
   354  		defer cancel()
   355  		c.read(ctx)
   356  		c.Close(StatusPolicyViolation, "unexpected data message")
   357  	}()
   358  	return ctx
   359  }
   360  
   361  // SetReadLimit implements *Conn.SetReadLimit for wasm.
   362  func (c *Conn) SetReadLimit(n int64) {
   363  	c.msgReadLimit.Store(n)
   364  }
   365  
   366  func (c *Conn) setCloseErr(err error) {
   367  	c.closeErrOnce.Do(func() {
   368  		c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
   369  	})
   370  }
   371  
   372  func (c *Conn) isClosed() bool {
   373  	select {
   374  	case <-c.closed:
   375  		return true
   376  	default:
   377  		return false
   378  	}
   379  }