github.com/inazumav/sing-box@v0.0.0-20230926072359-ab51429a14f1/transport/v2raywebsocket/conn.go (about)

     1  package v2raywebsocket
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"os"
    10  	"time"
    11  
    12  	C "github.com/inazumav/sing-box/constant"
    13  	"github.com/sagernet/sing/common"
    14  	"github.com/sagernet/sing/common/buf"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	"github.com/sagernet/websocket"
    17  )
    18  
    19  type WebsocketConn struct {
    20  	*websocket.Conn
    21  	*Writer
    22  	remoteAddr net.Addr
    23  	reader     io.Reader
    24  }
    25  
    26  func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn {
    27  	return &WebsocketConn{
    28  		Conn:       wsConn,
    29  		remoteAddr: remoteAddr,
    30  		Writer:     NewWriter(wsConn, true),
    31  	}
    32  }
    33  
    34  func (c *WebsocketConn) Close() error {
    35  	err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout))
    36  	if err != nil {
    37  		return c.Conn.Close()
    38  	}
    39  	return nil
    40  }
    41  
    42  func (c *WebsocketConn) Read(b []byte) (n int, err error) {
    43  	for {
    44  		if c.reader == nil {
    45  			_, c.reader, err = c.NextReader()
    46  			if err != nil {
    47  				err = wrapError(err)
    48  				return
    49  			}
    50  		}
    51  		n, err = c.reader.Read(b)
    52  		if E.IsMulti(err, io.EOF) {
    53  			c.reader = nil
    54  			continue
    55  		}
    56  		err = wrapError(err)
    57  		return
    58  	}
    59  }
    60  
    61  func (c *WebsocketConn) RemoteAddr() net.Addr {
    62  	if c.remoteAddr != nil {
    63  		return c.remoteAddr
    64  	}
    65  	return c.Conn.RemoteAddr()
    66  }
    67  
    68  func (c *WebsocketConn) SetDeadline(t time.Time) error {
    69  	return os.ErrInvalid
    70  }
    71  
    72  func (c *WebsocketConn) SetReadDeadline(t time.Time) error {
    73  	return os.ErrInvalid
    74  }
    75  
    76  func (c *WebsocketConn) SetWriteDeadline(t time.Time) error {
    77  	return os.ErrInvalid
    78  }
    79  
    80  func (c *WebsocketConn) NeedAdditionalReadDeadline() bool {
    81  	return true
    82  }
    83  
    84  func (c *WebsocketConn) Upstream() any {
    85  	return c.Conn.NetConn()
    86  }
    87  
    88  func (c *WebsocketConn) UpstreamWriter() any {
    89  	return c.Writer
    90  }
    91  
    92  type EarlyWebsocketConn struct {
    93  	*Client
    94  	ctx    context.Context
    95  	conn   *WebsocketConn
    96  	create chan struct{}
    97  	err    error
    98  }
    99  
   100  func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
   101  	if c.conn == nil {
   102  		<-c.create
   103  		if c.err != nil {
   104  			return 0, c.err
   105  		}
   106  	}
   107  	return c.conn.Read(b)
   108  }
   109  
   110  func (c *EarlyWebsocketConn) writeRequest(content []byte) error {
   111  	var (
   112  		earlyData []byte
   113  		lateData  []byte
   114  		conn      *websocket.Conn
   115  		response  *http.Response
   116  		err       error
   117  	)
   118  	if len(content) > int(c.maxEarlyData) {
   119  		earlyData = content[:c.maxEarlyData]
   120  		lateData = content[c.maxEarlyData:]
   121  	} else {
   122  		earlyData = content
   123  	}
   124  	if len(earlyData) > 0 {
   125  		earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
   126  		if c.earlyDataHeaderName == "" {
   127  			requestURL := c.requestURL
   128  			requestURL.Path += earlyDataString
   129  			conn, response, err = c.dialer.DialContext(c.ctx, requestURL.String(), c.headers)
   130  		} else {
   131  			headers := c.headers.Clone()
   132  			headers.Set(c.earlyDataHeaderName, earlyDataString)
   133  			conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, headers)
   134  		}
   135  	} else {
   136  		conn, response, err = c.dialer.DialContext(c.ctx, c.requestURLString, c.headers)
   137  	}
   138  	if err != nil {
   139  		return wrapDialError(response, err)
   140  	}
   141  	c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
   142  	if len(lateData) > 0 {
   143  		_, err = c.conn.Write(lateData)
   144  	}
   145  	return err
   146  }
   147  
   148  func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
   149  	if c.conn != nil {
   150  		return c.conn.Write(b)
   151  	}
   152  	err = c.writeRequest(b)
   153  	c.err = err
   154  	close(c.create)
   155  	if err != nil {
   156  		return
   157  	}
   158  	return len(b), nil
   159  }
   160  
   161  func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
   162  	if c.conn != nil {
   163  		return c.conn.WriteBuffer(buffer)
   164  	}
   165  	err := c.writeRequest(buffer.Bytes())
   166  	c.err = err
   167  	close(c.create)
   168  	return err
   169  }
   170  
   171  func (c *EarlyWebsocketConn) Close() error {
   172  	if c.conn == nil {
   173  		return nil
   174  	}
   175  	return c.conn.Close()
   176  }
   177  
   178  func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
   179  	if c.conn == nil {
   180  		return nil
   181  	}
   182  	return c.conn.LocalAddr()
   183  }
   184  
   185  func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
   186  	if c.conn == nil {
   187  		return nil
   188  	}
   189  	return c.conn.RemoteAddr()
   190  }
   191  
   192  func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
   193  	return os.ErrInvalid
   194  }
   195  
   196  func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error {
   197  	return os.ErrInvalid
   198  }
   199  
   200  func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error {
   201  	return os.ErrInvalid
   202  }
   203  
   204  func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool {
   205  	return true
   206  }
   207  
   208  func (c *EarlyWebsocketConn) Upstream() any {
   209  	return common.PtrOrNil(c.conn)
   210  }
   211  
   212  func (c *EarlyWebsocketConn) LazyHeadroom() bool {
   213  	return c.conn == nil
   214  }
   215  
   216  func wrapError(err error) error {
   217  	if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   218  		return io.EOF
   219  	}
   220  	if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) {
   221  		return net.ErrClosed
   222  	}
   223  	return err
   224  }