github.com/EagleQL/Xray-core@v1.4.3/transport/internet/websocket/dialer.go (about)

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	_ "embed"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"io"
     9  	"net/http"
    10  	"os"
    11  	"time"
    12  
    13  	"github.com/gorilla/websocket"
    14  
    15  	"github.com/xtls/xray-core/common"
    16  	"github.com/xtls/xray-core/common/net"
    17  	"github.com/xtls/xray-core/common/session"
    18  	"github.com/xtls/xray-core/transport/internet"
    19  	"github.com/xtls/xray-core/transport/internet/tls"
    20  )
    21  
    22  //go:embed dialer.html
    23  var webpage []byte
    24  var conns chan *websocket.Conn
    25  
    26  func init() {
    27  	if addr := os.Getenv("XRAY_BROWSER_DIALER"); addr != "" {
    28  		conns = make(chan *websocket.Conn, 256)
    29  		go http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    30  			if r.URL.Path == "/websocket" {
    31  				if conn, err := upgrader.Upgrade(w, r, nil); err == nil {
    32  					conns <- conn
    33  				} else {
    34  					fmt.Println("unexpected error")
    35  				}
    36  			} else {
    37  				w.Write(webpage)
    38  			}
    39  		}))
    40  	}
    41  }
    42  
    43  // Dial dials a WebSocket connection to the given destination.
    44  func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
    45  	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
    46  	var conn net.Conn
    47  	if streamSettings.ProtocolSettings.(*Config).Ed > 0 {
    48  		ctx, cancel := context.WithCancel(ctx)
    49  		conn = &delayDialConn{
    50  			dialed:         make(chan bool, 1),
    51  			cancel:         cancel,
    52  			ctx:            ctx,
    53  			dest:           dest,
    54  			streamSettings: streamSettings,
    55  		}
    56  	} else {
    57  		var err error
    58  		if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil {
    59  			return nil, newError("failed to dial WebSocket").Base(err)
    60  		}
    61  	}
    62  	return internet.Connection(conn), nil
    63  }
    64  
    65  func init() {
    66  	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
    67  }
    68  
    69  func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) {
    70  	wsSettings := streamSettings.ProtocolSettings.(*Config)
    71  
    72  	dialer := &websocket.Dialer{
    73  		NetDial: func(network, addr string) (net.Conn, error) {
    74  			return internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
    75  		},
    76  		ReadBufferSize:   4 * 1024,
    77  		WriteBufferSize:  4 * 1024,
    78  		HandshakeTimeout: time.Second * 8,
    79  	}
    80  
    81  	protocol := "ws"
    82  
    83  	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
    84  		protocol = "wss"
    85  		dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
    86  	}
    87  
    88  	host := dest.NetAddr()
    89  	if (protocol == "ws" && dest.Port == 80) || (protocol == "wss" && dest.Port == 443) {
    90  		host = dest.Address.String()
    91  	}
    92  	uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
    93  
    94  	if conns != nil {
    95  		data := []byte(uri)
    96  		if ed != nil {
    97  			data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...)
    98  		}
    99  		var conn *websocket.Conn
   100  		for {
   101  			conn = <-conns
   102  			if conn.WriteMessage(websocket.TextMessage, data) != nil {
   103  				conn.Close()
   104  			} else {
   105  				break
   106  			}
   107  		}
   108  		if _, p, err := conn.ReadMessage(); err != nil {
   109  			conn.Close()
   110  			return nil, err
   111  		} else if s := string(p); s != "ok" {
   112  			conn.Close()
   113  			return nil, newError(s)
   114  		}
   115  		return newConnection(conn, conn.RemoteAddr(), nil), nil
   116  	}
   117  
   118  	header := wsSettings.GetRequestHeader()
   119  	if ed != nil {
   120  		header.Set("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed))
   121  	}
   122  
   123  	conn, resp, err := dialer.Dial(uri, header)
   124  	if err != nil {
   125  		var reason string
   126  		if resp != nil {
   127  			reason = resp.Status
   128  		}
   129  		return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
   130  	}
   131  
   132  	return newConnection(conn, conn.RemoteAddr(), nil), nil
   133  }
   134  
   135  type delayDialConn struct {
   136  	net.Conn
   137  	closed         bool
   138  	dialed         chan bool
   139  	cancel         context.CancelFunc
   140  	ctx            context.Context
   141  	dest           net.Destination
   142  	streamSettings *internet.MemoryStreamConfig
   143  }
   144  
   145  func (d *delayDialConn) Write(b []byte) (int, error) {
   146  	if d.closed {
   147  		return 0, io.ErrClosedPipe
   148  	}
   149  	if d.Conn == nil {
   150  		ed := b
   151  		if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) {
   152  			ed = nil
   153  		}
   154  		var err error
   155  		if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil {
   156  			d.Close()
   157  			return 0, newError("failed to dial WebSocket").Base(err)
   158  		}
   159  		d.dialed <- true
   160  		if ed != nil {
   161  			return len(ed), nil
   162  		}
   163  	}
   164  	return d.Conn.Write(b)
   165  }
   166  
   167  func (d *delayDialConn) Read(b []byte) (int, error) {
   168  	if d.closed {
   169  		return 0, io.ErrClosedPipe
   170  	}
   171  	if d.Conn == nil {
   172  		select {
   173  		case <-d.ctx.Done():
   174  			return 0, io.ErrUnexpectedEOF
   175  		case <-d.dialed:
   176  		}
   177  	}
   178  	return d.Conn.Read(b)
   179  }
   180  
   181  func (d *delayDialConn) Close() error {
   182  	d.closed = true
   183  	d.cancel()
   184  	if d.Conn == nil {
   185  		return nil
   186  	}
   187  	return d.Conn.Close()
   188  }