github.com/sagernet/sing-box@v1.9.0-rc.20/transport/v2raywebsocket/conn.go (about)

     1  package v2raywebsocket
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"io"
     7  	"net"
     8  	"os"
     9  	"sync"
    10  	"time"
    11  
    12  	C "github.com/sagernet/sing-box/constant"
    13  	"github.com/sagernet/sing/common"
    14  	"github.com/sagernet/sing/common/buf"
    15  	"github.com/sagernet/sing/common/debug"
    16  	E "github.com/sagernet/sing/common/exceptions"
    17  	M "github.com/sagernet/sing/common/metadata"
    18  	"github.com/sagernet/ws"
    19  	"github.com/sagernet/ws/wsutil"
    20  )
    21  
    22  type WebsocketConn struct {
    23  	net.Conn
    24  	*Writer
    25  	state          ws.State
    26  	reader         *wsutil.Reader
    27  	controlHandler wsutil.FrameHandlerFunc
    28  	remoteAddr     net.Addr
    29  }
    30  
    31  func NewConn(conn net.Conn, remoteAddr net.Addr, state ws.State) *WebsocketConn {
    32  	controlHandler := wsutil.ControlFrameHandler(conn, state)
    33  	return &WebsocketConn{
    34  		Conn:  conn,
    35  		state: state,
    36  		reader: &wsutil.Reader{
    37  			Source:          conn,
    38  			State:           state,
    39  			SkipHeaderCheck: !debug.Enabled,
    40  			OnIntermediate:  controlHandler,
    41  		},
    42  		controlHandler: controlHandler,
    43  		remoteAddr:     remoteAddr,
    44  		Writer:         NewWriter(conn, state),
    45  	}
    46  }
    47  
    48  func (c *WebsocketConn) Close() error {
    49  	c.Conn.SetWriteDeadline(time.Now().Add(C.TCPTimeout))
    50  	frame := ws.NewCloseFrame(ws.NewCloseFrameBody(
    51  		ws.StatusNormalClosure, "",
    52  	))
    53  	if c.state == ws.StateClientSide {
    54  		frame = ws.MaskFrameInPlace(frame)
    55  	}
    56  	ws.WriteFrame(c.Conn, frame)
    57  	c.Conn.Close()
    58  	return nil
    59  }
    60  
    61  func (c *WebsocketConn) Read(b []byte) (n int, err error) {
    62  	var header ws.Header
    63  	for {
    64  		n, err = c.reader.Read(b)
    65  		if n > 0 {
    66  			err = nil
    67  			return
    68  		}
    69  		if !E.IsMulti(err, io.EOF, wsutil.ErrNoFrameAdvance) {
    70  			return
    71  		}
    72  		header, err = c.reader.NextFrame()
    73  		if err != nil {
    74  			return
    75  		}
    76  		if header.OpCode.IsControl() {
    77  			err = c.controlHandler(header, c.reader)
    78  			if err != nil {
    79  				return
    80  			}
    81  			continue
    82  		}
    83  		if header.OpCode&ws.OpBinary == 0 {
    84  			err = c.reader.Discard()
    85  			if err != nil {
    86  				return
    87  			}
    88  			continue
    89  		}
    90  	}
    91  }
    92  
    93  func (c *WebsocketConn) Write(p []byte) (n int, err error) {
    94  	err = wsutil.WriteMessage(c.Conn, c.state, ws.OpBinary, p)
    95  	if err != nil {
    96  		return
    97  	}
    98  	n = len(p)
    99  	return
   100  }
   101  
   102  func (c *WebsocketConn) RemoteAddr() net.Addr {
   103  	if c.remoteAddr != nil {
   104  		return c.remoteAddr
   105  	}
   106  	return c.Conn.RemoteAddr()
   107  }
   108  
   109  func (c *WebsocketConn) SetDeadline(t time.Time) error {
   110  	return os.ErrInvalid
   111  }
   112  
   113  func (c *WebsocketConn) SetReadDeadline(t time.Time) error {
   114  	return os.ErrInvalid
   115  }
   116  
   117  func (c *WebsocketConn) SetWriteDeadline(t time.Time) error {
   118  	return os.ErrInvalid
   119  }
   120  
   121  func (c *WebsocketConn) NeedAdditionalReadDeadline() bool {
   122  	return true
   123  }
   124  
   125  func (c *WebsocketConn) Upstream() any {
   126  	return c.Conn
   127  }
   128  
   129  type EarlyWebsocketConn struct {
   130  	*Client
   131  	ctx    context.Context
   132  	conn   *WebsocketConn
   133  	access sync.Mutex
   134  	create chan struct{}
   135  	err    error
   136  }
   137  
   138  func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
   139  	if c.conn == nil {
   140  		<-c.create
   141  		if c.err != nil {
   142  			return 0, c.err
   143  		}
   144  	}
   145  	return c.conn.Read(b)
   146  }
   147  
   148  func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
   149  	var (
   150  		earlyData []byte
   151  		lateData  []byte
   152  		conn      *WebsocketConn
   153  		err       error
   154  	)
   155  	if len(content) > int(c.maxEarlyData) {
   156  		earlyData = content[:c.maxEarlyData]
   157  		lateData = content[c.maxEarlyData:]
   158  	} else {
   159  		earlyData = content
   160  	}
   161  	if len(earlyData) > 0 {
   162  		earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
   163  		if c.earlyDataHeaderName == "" {
   164  			requestURL := c.requestURL
   165  			requestURL.Path += earlyDataString
   166  			conn, err = c.dialContext(c.ctx, &requestURL, c.headers)
   167  		} else {
   168  			headers := c.headers.Clone()
   169  			headers.Set(c.earlyDataHeaderName, earlyDataString)
   170  			conn, err = c.dialContext(c.ctx, &c.requestURL, headers)
   171  		}
   172  	} else {
   173  		conn, err = c.dialContext(c.ctx, &c.requestURL, c.headers)
   174  	}
   175  	if err != nil {
   176  		return err
   177  	}
   178  	if len(lateData) > 0 {
   179  		_, err = conn.Write(lateData)
   180  		if err != nil {
   181  			return err
   182  		}
   183  	}
   184  	c.conn = conn
   185  	return nil
   186  }
   187  
   188  func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
   189  	if c.conn != nil {
   190  		return c.conn.Write(b)
   191  	}
   192  	c.access.Lock()
   193  	defer c.access.Unlock()
   194  	if c.err != nil {
   195  		return 0, c.err
   196  	}
   197  	if c.conn != nil {
   198  		return c.conn.Write(b)
   199  	}
   200  	err = c.writeRequest(b)
   201  	c.err = err
   202  	close(c.create)
   203  	if err != nil {
   204  		return
   205  	}
   206  	return len(b), nil
   207  }
   208  
   209  func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
   210  	if c.conn != nil {
   211  		return c.conn.WriteBuffer(buffer)
   212  	}
   213  	c.access.Lock()
   214  	defer c.access.Unlock()
   215  	if c.conn != nil {
   216  		return c.conn.WriteBuffer(buffer)
   217  	}
   218  	if c.err != nil {
   219  		return c.err
   220  	}
   221  	err := c.writeRequest(buffer.Bytes())
   222  	c.err = err
   223  	close(c.create)
   224  	return err
   225  }
   226  
   227  func (c *EarlyWebsocketConn) Close() error {
   228  	if c.conn == nil {
   229  		return nil
   230  	}
   231  	return c.conn.Close()
   232  }
   233  
   234  func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
   235  	if c.conn == nil {
   236  		return M.Socksaddr{}
   237  	}
   238  	return c.conn.LocalAddr()
   239  }
   240  
   241  func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
   242  	if c.conn == nil {
   243  		return M.Socksaddr{}
   244  	}
   245  	return c.conn.RemoteAddr()
   246  }
   247  
   248  func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
   249  	return os.ErrInvalid
   250  }
   251  
   252  func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error {
   253  	return os.ErrInvalid
   254  }
   255  
   256  func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error {
   257  	return os.ErrInvalid
   258  }
   259  
   260  func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool {
   261  	return true
   262  }
   263  
   264  func (c *EarlyWebsocketConn) Upstream() any {
   265  	return common.PtrOrNil(c.conn)
   266  }
   267  
   268  func (c *EarlyWebsocketConn) LazyHeadroom() bool {
   269  	return c.conn == nil
   270  }