golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/websocket/websocket.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package websocket implements a client and server for the WebSocket protocol
     6  // as specified in RFC 6455.
     7  //
     8  // This package currently lacks some features found in an alternative
     9  // and more actively maintained WebSocket package:
    10  //
    11  //	https://pkg.go.dev/nhooyr.io/websocket
    12  package websocket // import "golang.org/x/net/websocket"
    13  
    14  import (
    15  	"bufio"
    16  	"crypto/tls"
    17  	"encoding/json"
    18  	"errors"
    19  	"io"
    20  	"io/ioutil"
    21  	"net"
    22  	"net/http"
    23  	"net/url"
    24  	"sync"
    25  	"time"
    26  )
    27  
    28  const (
    29  	ProtocolVersionHybi13    = 13
    30  	ProtocolVersionHybi      = ProtocolVersionHybi13
    31  	SupportedProtocolVersion = "13"
    32  
    33  	ContinuationFrame = 0
    34  	TextFrame         = 1
    35  	BinaryFrame       = 2
    36  	CloseFrame        = 8
    37  	PingFrame         = 9
    38  	PongFrame         = 10
    39  	UnknownFrame      = 255
    40  
    41  	DefaultMaxPayloadBytes = 32 << 20 // 32MB
    42  )
    43  
    44  // ProtocolError represents WebSocket protocol errors.
    45  type ProtocolError struct {
    46  	ErrorString string
    47  }
    48  
    49  func (err *ProtocolError) Error() string { return err.ErrorString }
    50  
    51  var (
    52  	ErrBadProtocolVersion   = &ProtocolError{"bad protocol version"}
    53  	ErrBadScheme            = &ProtocolError{"bad scheme"}
    54  	ErrBadStatus            = &ProtocolError{"bad status"}
    55  	ErrBadUpgrade           = &ProtocolError{"missing or bad upgrade"}
    56  	ErrBadWebSocketOrigin   = &ProtocolError{"missing or bad WebSocket-Origin"}
    57  	ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
    58  	ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
    59  	ErrBadWebSocketVersion  = &ProtocolError{"missing or bad WebSocket Version"}
    60  	ErrChallengeResponse    = &ProtocolError{"mismatch challenge/response"}
    61  	ErrBadFrame             = &ProtocolError{"bad frame"}
    62  	ErrBadFrameBoundary     = &ProtocolError{"not on frame boundary"}
    63  	ErrNotWebSocket         = &ProtocolError{"not websocket protocol"}
    64  	ErrBadRequestMethod     = &ProtocolError{"bad method"}
    65  	ErrNotSupported         = &ProtocolError{"not supported"}
    66  )
    67  
    68  // ErrFrameTooLarge is returned by Codec's Receive method if payload size
    69  // exceeds limit set by Conn.MaxPayloadBytes
    70  var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
    71  
    72  // Addr is an implementation of net.Addr for WebSocket.
    73  type Addr struct {
    74  	*url.URL
    75  }
    76  
    77  // Network returns the network type for a WebSocket, "websocket".
    78  func (addr *Addr) Network() string { return "websocket" }
    79  
// Config is a WebSocket configuration.
type Config struct {
	// A WebSocket server address.
	// Conn.LocalAddr reports it for server connections and Conn.RemoteAddr
	// for client connections.
	Location *url.URL

	// A WebSocket client origin.
	// Conn.LocalAddr reports it for client connections and Conn.RemoteAddr
	// for server connections.
	Origin *url.URL

	// WebSocket subprotocols.
	Protocol []string

	// WebSocket protocol version.
	// NOTE(review): presumably one of the ProtocolVersion* constants — confirm.
	Version int

	// TLS config for secure WebSocket (wss).
	TlsConfig *tls.Config

	// Additional header fields to be sent in WebSocket opening handshake.
	Header http.Header

	// Dialer used when opening websocket connections.
	Dialer *net.Dialer

	// NOTE(review): not referenced in this file; presumably populated by the
	// handshake code elsewhere in the package — verify.
	handshakeData map[string]string
}
   105  
   106  // serverHandshaker is an interface to handle WebSocket server side handshake.
   107  type serverHandshaker interface {
   108  	// ReadHandshake reads handshake request message from client.
   109  	// Returns http response code and error if any.
   110  	ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)
   111  
   112  	// AcceptHandshake accepts the client handshake request and sends
   113  	// handshake response back to client.
   114  	AcceptHandshake(buf *bufio.Writer) (err error)
   115  
   116  	// NewServerConn creates a new WebSocket connection.
   117  	NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
   118  }
   119  
   120  // frameReader is an interface to read a WebSocket frame.
   121  type frameReader interface {
   122  	// Reader is to read payload of the frame.
   123  	io.Reader
   124  
   125  	// PayloadType returns payload type.
   126  	PayloadType() byte
   127  
   128  	// HeaderReader returns a reader to read header of the frame.
   129  	HeaderReader() io.Reader
   130  
   131  	// TrailerReader returns a reader to read trailer of the frame.
   132  	// If it returns nil, there is no trailer in the frame.
   133  	TrailerReader() io.Reader
   134  
   135  	// Len returns total length of the frame, including header and trailer.
   136  	Len() int
   137  }
   138  
   139  // frameReaderFactory is an interface to creates new frame reader.
   140  type frameReaderFactory interface {
   141  	NewFrameReader() (r frameReader, err error)
   142  }
   143  
   144  // frameWriter is an interface to write a WebSocket frame.
   145  type frameWriter interface {
   146  	// Writer is to write payload of the frame.
   147  	io.WriteCloser
   148  }
   149  
   150  // frameWriterFactory is an interface to create new frame writer.
   151  type frameWriterFactory interface {
   152  	NewFrameWriter(payloadType byte) (w frameWriter, err error)
   153  }
   154  
   155  type frameHandler interface {
   156  	HandleFrame(frame frameReader) (r frameReader, err error)
   157  	WriteClose(status int) (err error)
   158  }
   159  
// Conn represents a WebSocket connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
// Readers (Read, Codec.Receive) are serialized by rio and writers
// (Write, Codec.Send) by wio; Close takes neither lock.
type Conn struct {
	config  *Config
	request *http.Request // nil for client-side connections (see IsClientConn)

	// NOTE(review): buf is not used in this file; presumably it buffers rwc
	// for the frame factories — verify.
	buf *bufio.ReadWriter
	// rwc is the underlying transport: Close closes it, and the deadline
	// setters use it when it is a net.Conn.
	rwc io.ReadWriteCloser

	// rio guards the read side: the factory and the in-progress frame.
	rio sync.Mutex
	frameReaderFactory
	frameReader // frame currently being consumed; nil between frames

	// wio guards the write side.
	wio sync.Mutex
	frameWriterFactory

	// frameHandler filters every newly read frame and writes the close
	// status on Close.
	frameHandler
	PayloadType        byte // frame type used by Write
	defaultCloseStatus int  // status passed to WriteClose by Close

	// MaxPayloadBytes limits the size of frame payload received over Conn
	// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
	MaxPayloadBytes int
}
   185  
   186  // Read implements the io.Reader interface:
   187  // it reads data of a frame from the WebSocket connection.
   188  // if msg is not large enough for the frame data, it fills the msg and next Read
   189  // will read the rest of the frame data.
   190  // it reads Text frame or Binary frame.
   191  func (ws *Conn) Read(msg []byte) (n int, err error) {
   192  	ws.rio.Lock()
   193  	defer ws.rio.Unlock()
   194  again:
   195  	if ws.frameReader == nil {
   196  		frame, err := ws.frameReaderFactory.NewFrameReader()
   197  		if err != nil {
   198  			return 0, err
   199  		}
   200  		ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
   201  		if err != nil {
   202  			return 0, err
   203  		}
   204  		if ws.frameReader == nil {
   205  			goto again
   206  		}
   207  	}
   208  	n, err = ws.frameReader.Read(msg)
   209  	if err == io.EOF {
   210  		if trailer := ws.frameReader.TrailerReader(); trailer != nil {
   211  			io.Copy(ioutil.Discard, trailer)
   212  		}
   213  		ws.frameReader = nil
   214  		goto again
   215  	}
   216  	return n, err
   217  }
   218  
   219  // Write implements the io.Writer interface:
   220  // it writes data as a frame to the WebSocket connection.
   221  func (ws *Conn) Write(msg []byte) (n int, err error) {
   222  	ws.wio.Lock()
   223  	defer ws.wio.Unlock()
   224  	w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
   225  	if err != nil {
   226  		return 0, err
   227  	}
   228  	n, err = w.Write(msg)
   229  	w.Close()
   230  	return n, err
   231  }
   232  
   233  // Close implements the io.Closer interface.
   234  func (ws *Conn) Close() error {
   235  	err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
   236  	err1 := ws.rwc.Close()
   237  	if err != nil {
   238  		return err
   239  	}
   240  	return err1
   241  }
   242  
   243  // IsClientConn reports whether ws is a client-side connection.
   244  func (ws *Conn) IsClientConn() bool { return ws.request == nil }
   245  
   246  // IsServerConn reports whether ws is a server-side connection.
   247  func (ws *Conn) IsServerConn() bool { return ws.request != nil }
   248  
   249  // LocalAddr returns the WebSocket Origin for the connection for client, or
   250  // the WebSocket location for server.
   251  func (ws *Conn) LocalAddr() net.Addr {
   252  	if ws.IsClientConn() {
   253  		return &Addr{ws.config.Origin}
   254  	}
   255  	return &Addr{ws.config.Location}
   256  }
   257  
   258  // RemoteAddr returns the WebSocket location for the connection for client, or
   259  // the Websocket Origin for server.
   260  func (ws *Conn) RemoteAddr() net.Addr {
   261  	if ws.IsClientConn() {
   262  		return &Addr{ws.config.Location}
   263  	}
   264  	return &Addr{ws.config.Origin}
   265  }
   266  
   267  var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
   268  
   269  // SetDeadline sets the connection's network read & write deadlines.
   270  func (ws *Conn) SetDeadline(t time.Time) error {
   271  	if conn, ok := ws.rwc.(net.Conn); ok {
   272  		return conn.SetDeadline(t)
   273  	}
   274  	return errSetDeadline
   275  }
   276  
   277  // SetReadDeadline sets the connection's network read deadline.
   278  func (ws *Conn) SetReadDeadline(t time.Time) error {
   279  	if conn, ok := ws.rwc.(net.Conn); ok {
   280  		return conn.SetReadDeadline(t)
   281  	}
   282  	return errSetDeadline
   283  }
   284  
   285  // SetWriteDeadline sets the connection's network write deadline.
   286  func (ws *Conn) SetWriteDeadline(t time.Time) error {
   287  	if conn, ok := ws.rwc.(net.Conn); ok {
   288  		return conn.SetWriteDeadline(t)
   289  	}
   290  	return errSetDeadline
   291  }
   292  
   293  // Config returns the WebSocket config.
   294  func (ws *Conn) Config() *Config { return ws.config }
   295  
   296  // Request returns the http request upgraded to the WebSocket.
   297  // It is nil for client side.
   298  func (ws *Conn) Request() *http.Request { return ws.request }
   299  
   300  // Codec represents a symmetric pair of functions that implement a codec.
   301  type Codec struct {
   302  	Marshal   func(v interface{}) (data []byte, payloadType byte, err error)
   303  	Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
   304  }
   305  
   306  // Send sends v marshaled by cd.Marshal as single frame to ws.
   307  func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
   308  	data, payloadType, err := cd.Marshal(v)
   309  	if err != nil {
   310  		return err
   311  	}
   312  	ws.wio.Lock()
   313  	defer ws.wio.Unlock()
   314  	w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
   315  	if err != nil {
   316  		return err
   317  	}
   318  	_, err = w.Write(data)
   319  	w.Close()
   320  	return err
   321  }
   322  
   323  // Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
   324  // in v. The whole frame payload is read to an in-memory buffer; max size of
   325  // payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
   326  // limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
   327  // completely. The next call to Receive would read and discard leftover data of
   328  // previous oversized frame before processing next frame.
   329  func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
   330  	ws.rio.Lock()
   331  	defer ws.rio.Unlock()
   332  	if ws.frameReader != nil {
   333  		_, err = io.Copy(ioutil.Discard, ws.frameReader)
   334  		if err != nil {
   335  			return err
   336  		}
   337  		ws.frameReader = nil
   338  	}
   339  again:
   340  	frame, err := ws.frameReaderFactory.NewFrameReader()
   341  	if err != nil {
   342  		return err
   343  	}
   344  	frame, err = ws.frameHandler.HandleFrame(frame)
   345  	if err != nil {
   346  		return err
   347  	}
   348  	if frame == nil {
   349  		goto again
   350  	}
   351  	maxPayloadBytes := ws.MaxPayloadBytes
   352  	if maxPayloadBytes == 0 {
   353  		maxPayloadBytes = DefaultMaxPayloadBytes
   354  	}
   355  	if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
   356  		// payload size exceeds limit, no need to call Unmarshal
   357  		//
   358  		// set frameReader to current oversized frame so that
   359  		// the next call to this function can drain leftover
   360  		// data before processing the next frame
   361  		ws.frameReader = frame
   362  		return ErrFrameTooLarge
   363  	}
   364  	payloadType := frame.PayloadType()
   365  	data, err := ioutil.ReadAll(frame)
   366  	if err != nil {
   367  		return err
   368  	}
   369  	return cd.Unmarshal(data, payloadType, v)
   370  }
   371  
   372  func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
   373  	switch data := v.(type) {
   374  	case string:
   375  		return []byte(data), TextFrame, nil
   376  	case []byte:
   377  		return data, BinaryFrame, nil
   378  	}
   379  	return nil, UnknownFrame, ErrNotSupported
   380  }
   381  
   382  func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
   383  	switch data := v.(type) {
   384  	case *string:
   385  		*data = string(msg)
   386  		return nil
   387  	case *[]byte:
   388  		*data = msg
   389  		return nil
   390  	}
   391  	return ErrNotSupported
   392  }
   393  
   394  /*
   395  Message is a codec to send/receive text/binary data in a frame on WebSocket connection.
   396  To send/receive text frame, use string type.
   397  To send/receive binary frame, use []byte type.
   398  
   399  Trivial usage:
   400  
   401  	import "websocket"
   402  
   403  	// receive text frame
   404  	var message string
   405  	websocket.Message.Receive(ws, &message)
   406  
   407  	// send text frame
   408  	message = "hello"
   409  	websocket.Message.Send(ws, message)
   410  
   411  	// receive binary frame
   412  	var data []byte
   413  	websocket.Message.Receive(ws, &data)
   414  
   415  	// send binary frame
   416  	data = []byte{0, 1, 2}
   417  	websocket.Message.Send(ws, data)
   418  */
   419  var Message = Codec{marshal, unmarshal}
   420  
   421  func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
   422  	msg, err = json.Marshal(v)
   423  	return msg, TextFrame, err
   424  }
   425  
   426  func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
   427  	return json.Unmarshal(msg, v)
   428  }
   429  
   430  /*
   431  JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
   432  
   433  Trivial usage:
   434  
   435  	import "websocket"
   436  
   437  	type T struct {
   438  		Msg string
   439  		Count int
   440  	}
   441  
   442  	// receive JSON type T
   443  	var data T
   444  	websocket.JSON.Receive(ws, &data)
   445  
   446  	// send JSON type T
   447  	websocket.JSON.Send(ws, data)
   448  */
   449  var JSON = Codec{jsonMarshal, jsonUnmarshal}