github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/x/conn.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
// This file implements the WebSocket framing protocol defined by the IETF
// hybi working-group draft (later published as RFC 6455):
// http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
     9  
    10  import (
    11  	"crypto/rand"
    12  	"encoding/binary"
    13  	"errors"
    14  	"io"
    15  	"net"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    20  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    21  	"github.com/Asutorufa/yuhaiin/pkg/utils/relay"
    22  )
    23  
    24  // Conn represents a WebSocket connection.
    25  //
    26  // Multiple goroutines may invoke methods on a Conn simultaneously.
    27  type Conn struct {
    28  	IsServer bool
    29  
    30  	LastPayloadType opcode
    31  	PayloadType     opcode
    32  
    33  	readHeaderBuf  [8]byte
    34  	writeHeaderBuf [8]byte
    35  
    36  	rio sync.Mutex
    37  	wio sync.Mutex
    38  
    39  	Frame io.ReadCloser
    40  
    41  	RawConn net.Conn
    42  }
    43  
    44  // newConn creates a new WebSocket connection speaking hybi draft protocol.
    45  func newConn(rwc net.Conn, isServer bool) *Conn {
    46  	return &Conn{
    47  		IsServer:    isServer,
    48  		RawConn:     rwc,
    49  		PayloadType: opBinary,
    50  	}
    51  }
    52  
    53  // Read implements the io.Reader interface:
    54  // it reads data of a frame from the WebSocket connection.
    55  // if msg is not large enough for the frame data, it fills the msg and next Read
    56  // will read the rest of the frame data.
    57  // it reads Text frame or Binary frame.
    58  func (ws *Conn) Read(msg []byte) (n int, err error) {
    59  	ws.rio.Lock()
    60  	defer ws.rio.Unlock()
    61  
    62  	for {
    63  		if ws.Frame == nil {
    64  			_, ws.Frame, err = ws.nextFrameReader()
    65  			if err != nil {
    66  				return 0, err
    67  			}
    68  		}
    69  
    70  		n, err = ws.Frame.Read(msg)
    71  		if err == nil || n != 0 {
    72  			return n, err
    73  		}
    74  
    75  		if !errors.Is(err, io.EOF) {
    76  			return n, err
    77  		}
    78  
    79  		ws.Frame = nil
    80  	}
    81  
    82  }
    83  
    84  func (ws *Conn) NextFrameReader(handle func(*Header, io.ReadCloser) error) error {
    85  	ws.rio.Lock()
    86  	defer ws.rio.Unlock()
    87  
    88  	if ws.Frame != nil {
    89  		_ = ws.Frame.Close()
    90  		ws.Frame = nil
    91  	}
    92  
    93  	h, r, err := ws.nextFrameReader()
    94  	if err != nil {
    95  		return err
    96  	}
    97  	defer r.Close()
    98  
    99  	if err := handle(h, r); err != nil {
   100  		return err
   101  	}
   102  
   103  	return nil
   104  }
   105  
   106  func (ws *Conn) nextFrameReader() (*Header, io.ReadCloser, error) {
   107  	for {
   108  		header, err := readFrameHeader(netapi.NewReader(ws.RawConn), ws.readHeaderBuf[:])
   109  		if err != nil {
   110  			return nil, nil, err
   111  		}
   112  
   113  		frame := &frameReader{
   114  			masked:  header.masked,
   115  			maskKey: header.maskKey,
   116  			reader:  io.LimitReader(ws.RawConn, header.payloadLength),
   117  		}
   118  
   119  		frameReader, err := ws.handleFrame(&header, frame)
   120  		if err != nil {
   121  			return nil, nil, err
   122  		}
   123  
   124  		if frameReader != nil {
   125  			return &header, frameReader, nil
   126  		}
   127  	}
   128  }
   129  
   130  // Write implements the io.Writer interface:
   131  // it writes data as a frame to the WebSocket connection.
   132  func (ws *Conn) Write(msg []byte) (n int, err error) { return ws.WriteMsg(msg, ws.PayloadType) }
   133  
   134  func (ws *Conn) WriteMsg(msg []byte, payloadType opcode) (int, error) {
   135  
   136  	frameHeader := Header{
   137  		fin:           true,
   138  		opcode:        payloadType,
   139  		masked:        !ws.IsServer,
   140  		payloadLength: int64(len(msg)),
   141  	}
   142  
   143  	if frameHeader.masked {
   144  		_ = binary.Read(rand.Reader, binary.BigEndian, &frameHeader.maskKey)
   145  	}
   146  
   147  	buf := pool.GetBytesWriter(pool.DefaultSize + len(msg))
   148  	defer buf.Free()
   149  
   150  	if err := writeFrameHeader(frameHeader, buf, ws.writeHeaderBuf[:]); err != nil {
   151  		return 0, err
   152  	}
   153  
   154  	headerLength := buf.Len()
   155  
   156  	_, _ = buf.Write(msg)
   157  
   158  	if frameHeader.masked {
   159  		mask(frameHeader.maskKey, buf.Bytes()[headerLength:])
   160  	}
   161  
   162  	ws.wio.Lock()
   163  	n, err := ws.RawConn.Write(buf.Bytes())
   164  	ws.wio.Unlock()
   165  	if err != nil {
   166  		return n, err
   167  	}
   168  
   169  	return int(frameHeader.payloadLength), nil
   170  }
   171  
   172  func (ws *Conn) handleFrame(header *Header, frame io.ReadCloser) (io.ReadCloser, error) {
   173  	if ws.IsServer && !header.masked {
   174  		// client --> server
   175  		// The client MUST mask all frames sent to the server.
   176  		ws.WriteClose(closeStatusProtocolError)
   177  		return nil, io.EOF
   178  	} else if !ws.IsServer && header.masked {
   179  		// server --> client
   180  		// The server MUST NOT mask all frames.
   181  		ws.WriteClose(closeStatusProtocolError)
   182  		return nil, io.EOF
   183  	}
   184  
   185  	switch header.opcode {
   186  	case opContinuation:
   187  		header.opcode = ws.LastPayloadType
   188  	case opText, opBinary:
   189  		ws.LastPayloadType = header.opcode
   190  	case opClose:
   191  		ws.Close()
   192  		return nil, io.EOF
   193  	case opPing, opPong:
   194  		b := make([]byte, maxControlFramePayloadLength)
   195  		n, err := io.ReadFull(frame, b)
   196  		if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
   197  			return nil, err
   198  		}
   199  		_ = frame.Close()
   200  		if header.opcode == opPing {
   201  			if _, err := ws.WritePong(b[:n]); err != nil {
   202  				return nil, err
   203  			}
   204  		}
   205  		return nil, nil
   206  	}
   207  	return frame, nil
   208  }
   209  
   210  func (ws *Conn) WriteClose(status int) (err error) {
   211  	_, err = ws.WriteMsg(binary.BigEndian.AppendUint16(nil, uint16(status)), opClose)
   212  	return err
   213  }
   214  
   215  func (ws *Conn) WritePong(msg []byte) (n int, err error) { return ws.WriteMsg(msg, opPong) }
   216  
   217  // Close implements the io.Closer interface.
   218  func (ws *Conn) Close() error {
   219  	return ws.RawConn.Close()
   220  }
   221  
   222  func (ws *Conn) LocalAddr() net.Addr                { return ws.RawConn.LocalAddr() }
   223  func (ws *Conn) RemoteAddr() net.Addr               { return ws.RawConn.RemoteAddr() }
   224  func (ws *Conn) SetDeadline(t time.Time) error      { return ws.RawConn.SetDeadline(t) }
   225  func (ws *Conn) SetReadDeadline(t time.Time) error  { return ws.RawConn.SetReadDeadline(t) }
   226  func (ws *Conn) SetWriteDeadline(t time.Time) error { return ws.RawConn.SetWriteDeadline(t) }
   227  
   228  // A frameReader is a reader for hybi frame.
   229  type frameReader struct {
   230  	reader io.Reader
   231  
   232  	masked  bool
   233  	maskKey uint32
   234  }
   235  
   236  func (frame *frameReader) Read(msg []byte) (n int, err error) {
   237  	n, err = frame.reader.Read(msg)
   238  	if frame.masked {
   239  		frame.maskKey = mask(frame.maskKey, msg[:n])
   240  	}
   241  	return n, err
   242  }
   243  
   244  func (f *frameReader) Close() error {
   245  	_, err := relay.Copy(io.Discard, f.reader)
   246  	return err
   247  }