github.com/moqsien/xraycore@v1.8.5/transport/internet/websocket/dialer.go (about)

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	_ "embed"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"io"
     9  	gonet "net"
    10  	"net/http"
    11  	"os"
    12  	"time"
    13  
    14  	"github.com/gorilla/websocket"
    15  	"github.com/moqsien/xraycore/common"
    16  	"github.com/moqsien/xraycore/common/net"
    17  	"github.com/moqsien/xraycore/common/session"
    18  	"github.com/moqsien/xraycore/transport/internet"
    19  	"github.com/moqsien/xraycore/transport/internet/stat"
    20  	"github.com/moqsien/xraycore/transport/internet/tls"
    21  )
    22  
// webpage holds the contents of dialer.html, embedded at build time. It is
// served by the browser dialer HTTP handler for every path other than
// "/websocket".
//
//go:embed dialer.html
var webpage []byte

// conns queues WebSocket connections accepted by the browser dialer HTTP
// handler. It stays nil unless the XRAY_BROWSER_DIALER environment variable is
// set, and dialWebSocket uses that nil check to decide whether to dial through
// a browser.
var conns chan *websocket.Conn
    27  
    28  func init() {
    29  	if addr := os.Getenv("XRAY_BROWSER_DIALER"); addr != "" {
    30  		conns = make(chan *websocket.Conn, 256)
    31  		go http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    32  			if r.URL.Path == "/websocket" {
    33  				if conn, err := upgrader.Upgrade(w, r, nil); err == nil {
    34  					conns <- conn
    35  				} else {
    36  					fmt.Println("unexpected error")
    37  				}
    38  			} else {
    39  				w.Write(webpage)
    40  			}
    41  		}))
    42  	}
    43  }
    44  
    45  // Dial dials a WebSocket connection to the given destination.
    46  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
    47  	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
    48  	var conn net.Conn
    49  	if streamSettings.ProtocolSettings.(*Config).Ed > 0 {
    50  		ctx, cancel := context.WithCancel(ctx)
    51  		conn = &delayDialConn{
    52  			dialed:         make(chan bool, 1),
    53  			cancel:         cancel,
    54  			ctx:            ctx,
    55  			dest:           dest,
    56  			streamSettings: streamSettings,
    57  		}
    58  	} else {
    59  		var err error
    60  		if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil {
    61  			return nil, newError("failed to dial WebSocket").Base(err)
    62  		}
    63  	}
    64  	return stat.Connection(conn), nil
    65  }
    66  
    67  func init() {
    68  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
    69  }
    70  
    71  func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) {
    72  	wsSettings := streamSettings.ProtocolSettings.(*Config)
    73  
    74  	dialer := &websocket.Dialer{
    75  		NetDial: func(network, addr string) (net.Conn, error) {
    76  			return internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
    77  		},
    78  		ReadBufferSize:   4 * 1024,
    79  		WriteBufferSize:  4 * 1024,
    80  		HandshakeTimeout: time.Second * 8,
    81  	}
    82  
    83  	protocol := "ws"
    84  
    85  	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
    86  		protocol = "wss"
    87  		tlsConfig := config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
    88  		dialer.TLSClientConfig = tlsConfig
    89  		if fingerprint := tls.GetFingerprint(config.Fingerprint); fingerprint != nil {
    90  			dialer.NetDialTLSContext = func(_ context.Context, _, addr string) (gonet.Conn, error) {
    91  				// Like the NetDial in the dialer
    92  				pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
    93  				if err != nil {
    94  					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
    95  					return nil, err
    96  				}
    97  				// TLS and apply the handshake
    98  				cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
    99  				if err := cn.WebsocketHandshake(); err != nil {
   100  					newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
   101  					return nil, err
   102  				}
   103  				if !tlsConfig.InsecureSkipVerify {
   104  					if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
   105  						newError("failed to dial to " + addr).Base(err).AtError().WriteToLog()
   106  						return nil, err
   107  					}
   108  				}
   109  				return cn, nil
   110  			}
   111  		}
   112  	}
   113  
   114  	host := dest.NetAddr()
   115  	if (protocol == "ws" && dest.Port == 80) || (protocol == "wss" && dest.Port == 443) {
   116  		host = dest.Address.String()
   117  	}
   118  	uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
   119  
   120  	if conns != nil {
   121  		data := []byte(uri)
   122  		if ed != nil {
   123  			data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...)
   124  		}
   125  		var conn *websocket.Conn
   126  		for {
   127  			conn = <-conns
   128  			if conn.WriteMessage(websocket.TextMessage, data) != nil {
   129  				conn.Close()
   130  			} else {
   131  				break
   132  			}
   133  		}
   134  		if _, p, err := conn.ReadMessage(); err != nil {
   135  			conn.Close()
   136  			return nil, err
   137  		} else if s := string(p); s != "ok" {
   138  			conn.Close()
   139  			return nil, newError(s)
   140  		}
   141  		return newConnection(conn, conn.RemoteAddr(), nil), nil
   142  	}
   143  
   144  	header := wsSettings.GetRequestHeader()
   145  	if ed != nil {
   146  		// RawURLEncoding is support by both V2Ray/V2Fly and XRay.
   147  		header.Set("Sec-WebSocket-Protocol", base64.RawURLEncoding.EncodeToString(ed))
   148  	}
   149  
   150  	conn, resp, err := dialer.Dial(uri, header)
   151  	if err != nil {
   152  		var reason string
   153  		if resp != nil {
   154  			reason = resp.Status
   155  		}
   156  		return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
   157  	}
   158  
   159  	return newConnection(conn, conn.RemoteAddr(), nil), nil
   160  }
   161  
// delayDialConn is a net.Conn whose underlying WebSocket connection is dialed
// on the first Write rather than at construction, so that the first payload
// can be passed to dialWebSocket as early data.
type delayDialConn struct {
	// Conn is nil until the first Write dials successfully.
	net.Conn
	// closed is set by Close; Read and Write then return io.ErrClosedPipe.
	closed bool
	// dialed receives a value from Write once the dial succeeds, releasing a
	// Read that is waiting for the connection.
	dialed chan bool
	// cancel cancels ctx; it is called by Close.
	cancel context.CancelFunc
	// ctx is passed to dialWebSocket and watched by Read while waiting.
	ctx context.Context
	// dest and streamSettings are the arguments for the deferred dial.
	dest           net.Destination
	streamSettings *internet.MemoryStreamConfig
}
   171  
   172  func (d *delayDialConn) Write(b []byte) (int, error) {
   173  	if d.closed {
   174  		return 0, io.ErrClosedPipe
   175  	}
   176  	if d.Conn == nil {
   177  		ed := b
   178  		if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) {
   179  			ed = nil
   180  		}
   181  		var err error
   182  		if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil {
   183  			d.Close()
   184  			return 0, newError("failed to dial WebSocket").Base(err)
   185  		}
   186  		d.dialed <- true
   187  		if ed != nil {
   188  			return len(ed), nil
   189  		}
   190  	}
   191  	return d.Conn.Write(b)
   192  }
   193  
   194  func (d *delayDialConn) Read(b []byte) (int, error) {
   195  	if d.closed {
   196  		return 0, io.ErrClosedPipe
   197  	}
   198  	if d.Conn == nil {
   199  		select {
   200  		case <-d.ctx.Done():
   201  			return 0, io.ErrUnexpectedEOF
   202  		case <-d.dialed:
   203  		}
   204  	}
   205  	return d.Conn.Read(b)
   206  }
   207  
   208  func (d *delayDialConn) Close() error {
   209  	d.closed = true
   210  	d.cancel()
   211  	if d.Conn == nil {
   212  		return nil
   213  	}
   214  	return d.Conn.Close()
   215  }