github.com/sagernet/sing-box@v1.2.7/transport/v2raywebsocket/conn.go (about)

     1  package v2raywebsocket
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"os"
    10  	"time"
    11  
    12  	C "github.com/sagernet/sing-box/constant"
    13  	"github.com/sagernet/sing/common"
    14  	"github.com/sagernet/sing/common/buf"
    15  	E "github.com/sagernet/sing/common/exceptions"
    16  	"github.com/sagernet/websocket"
    17  )
    18  
    19  type WebsocketConn struct {
    20  	*websocket.Conn
    21  	*Writer
    22  	remoteAddr net.Addr
    23  	reader     io.Reader
    24  }
    25  
    26  func NewServerConn(wsConn *websocket.Conn, remoteAddr net.Addr) *WebsocketConn {
    27  	return &WebsocketConn{
    28  		Conn:       wsConn,
    29  		remoteAddr: remoteAddr,
    30  		Writer:     NewWriter(wsConn, true),
    31  	}
    32  }
    33  
    34  func (c *WebsocketConn) Close() error {
    35  	err := c.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(C.TCPTimeout))
    36  	if err != nil {
    37  		return c.Conn.Close()
    38  	}
    39  	return nil
    40  }
    41  
    42  func (c *WebsocketConn) Read(b []byte) (n int, err error) {
    43  	for {
    44  		if c.reader == nil {
    45  			_, c.reader, err = c.NextReader()
    46  			if err != nil {
    47  				err = wrapError(err)
    48  				return
    49  			}
    50  		}
    51  		n, err = c.reader.Read(b)
    52  		if E.IsMulti(err, io.EOF) {
    53  			c.reader = nil
    54  			continue
    55  		}
    56  		err = wrapError(err)
    57  		return
    58  	}
    59  }
    60  
    61  func (c *WebsocketConn) RemoteAddr() net.Addr {
    62  	if c.remoteAddr != nil {
    63  		return c.remoteAddr
    64  	}
    65  	return c.Conn.RemoteAddr()
    66  }
    67  
    68  func (c *WebsocketConn) SetDeadline(t time.Time) error {
    69  	return os.ErrInvalid
    70  }
    71  
    72  func (c *WebsocketConn) SetReadDeadline(t time.Time) error {
    73  	return os.ErrInvalid
    74  }
    75  
    76  func (c *WebsocketConn) SetWriteDeadline(t time.Time) error {
    77  	return os.ErrInvalid
    78  }
    79  
    80  func (c *WebsocketConn) NeedAdditionalReadDeadline() bool {
    81  	return true
    82  }
    83  
    84  func (c *WebsocketConn) Upstream() any {
    85  	return c.Conn.NetConn()
    86  }
    87  
    88  func (c *WebsocketConn) UpstreamWriter() any {
    89  	return c.Writer
    90  }
    91  
    92  type EarlyWebsocketConn struct {
    93  	*Client
    94  	ctx    context.Context
    95  	conn   *WebsocketConn
    96  	create chan struct{}
    97  }
    98  
    99  func (c *EarlyWebsocketConn) Read(b []byte) (n int, err error) {
   100  	if c.conn == nil {
   101  		<-c.create
   102  	}
   103  	return c.conn.Read(b)
   104  }
   105  
   106  func (c *EarlyWebsocketConn) Write(b []byte) (n int, err error) {
   107  	if c.conn != nil {
   108  		return c.conn.Write(b)
   109  	}
   110  	var (
   111  		earlyData []byte
   112  		lateData  []byte
   113  		conn      *websocket.Conn
   114  		response  *http.Response
   115  	)
   116  	if len(b) > int(c.maxEarlyData) {
   117  		earlyData = b[:c.maxEarlyData]
   118  		lateData = b[c.maxEarlyData:]
   119  	} else {
   120  		earlyData = b
   121  	}
   122  	if len(earlyData) > 0 {
   123  		earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
   124  		if c.earlyDataHeaderName == "" {
   125  			conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
   126  		} else {
   127  			headers := c.headers.Clone()
   128  			headers.Set(c.earlyDataHeaderName, earlyDataString)
   129  			conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
   130  		}
   131  	} else {
   132  		conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
   133  	}
   134  	if err != nil {
   135  		return 0, wrapDialError(response, err)
   136  	}
   137  	c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
   138  	close(c.create)
   139  	if len(lateData) > 0 {
   140  		_, err = c.conn.Write(lateData)
   141  	}
   142  	if err != nil {
   143  		return
   144  	}
   145  	return len(b), nil
   146  }
   147  
   148  func (c *EarlyWebsocketConn) WriteBuffer(buffer *buf.Buffer) error {
   149  	if c.conn != nil {
   150  		return c.conn.WriteBuffer(buffer)
   151  	}
   152  	var (
   153  		earlyData []byte
   154  		lateData  []byte
   155  		conn      *websocket.Conn
   156  		response  *http.Response
   157  		err       error
   158  	)
   159  	if buffer.Len() > int(c.maxEarlyData) {
   160  		earlyData = buffer.Bytes()[:c.maxEarlyData]
   161  		lateData = buffer.Bytes()[c.maxEarlyData:]
   162  	} else {
   163  		earlyData = buffer.Bytes()
   164  	}
   165  	if len(earlyData) > 0 {
   166  		earlyDataString := base64.RawURLEncoding.EncodeToString(earlyData)
   167  		if c.earlyDataHeaderName == "" {
   168  			conn, response, err = c.dialer.DialContext(c.ctx, c.uri+earlyDataString, c.headers)
   169  		} else {
   170  			headers := c.headers.Clone()
   171  			headers.Set(c.earlyDataHeaderName, earlyDataString)
   172  			conn, response, err = c.dialer.DialContext(c.ctx, c.uri, headers)
   173  		}
   174  	} else {
   175  		conn, response, err = c.dialer.DialContext(c.ctx, c.uri, c.headers)
   176  	}
   177  	if err != nil {
   178  		return wrapDialError(response, err)
   179  	}
   180  	c.conn = &WebsocketConn{Conn: conn, Writer: NewWriter(conn, false)}
   181  	close(c.create)
   182  	if len(lateData) > 0 {
   183  		_, err = c.conn.Write(lateData)
   184  	}
   185  	return err
   186  }
   187  
   188  func (c *EarlyWebsocketConn) Close() error {
   189  	if c.conn == nil {
   190  		return nil
   191  	}
   192  	return c.conn.Close()
   193  }
   194  
   195  func (c *EarlyWebsocketConn) LocalAddr() net.Addr {
   196  	if c.conn == nil {
   197  		return nil
   198  	}
   199  	return c.conn.LocalAddr()
   200  }
   201  
   202  func (c *EarlyWebsocketConn) RemoteAddr() net.Addr {
   203  	if c.conn == nil {
   204  		return nil
   205  	}
   206  	return c.conn.RemoteAddr()
   207  }
   208  
   209  func (c *EarlyWebsocketConn) SetDeadline(t time.Time) error {
   210  	return os.ErrInvalid
   211  }
   212  
   213  func (c *EarlyWebsocketConn) SetReadDeadline(t time.Time) error {
   214  	return os.ErrInvalid
   215  }
   216  
   217  func (c *EarlyWebsocketConn) SetWriteDeadline(t time.Time) error {
   218  	return os.ErrInvalid
   219  }
   220  
   221  func (c *EarlyWebsocketConn) NeedAdditionalReadDeadline() bool {
   222  	return true
   223  }
   224  
   225  func (c *EarlyWebsocketConn) Upstream() any {
   226  	return common.PtrOrNil(c.conn)
   227  }
   228  
   229  func (c *EarlyWebsocketConn) LazyHeadroom() bool {
   230  	return c.conn == nil
   231  }
   232  
   233  func wrapError(err error) error {
   234  	if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   235  		return io.EOF
   236  	}
   237  	if websocket.IsCloseError(err, websocket.CloseAbnormalClosure) {
   238  		return net.ErrClosed
   239  	}
   240  	return err
   241  }