github.com/EagleQL/Xray-core@v1.4.3/transport/internet/websocket/dialer.go (about) 1 package websocket 2 3 import ( 4 "context" 5 _ "embed" 6 "encoding/base64" 7 "fmt" 8 "io" 9 "net/http" 10 "os" 11 "time" 12 13 "github.com/gorilla/websocket" 14 15 "github.com/xtls/xray-core/common" 16 "github.com/xtls/xray-core/common/net" 17 "github.com/xtls/xray-core/common/session" 18 "github.com/xtls/xray-core/transport/internet" 19 "github.com/xtls/xray-core/transport/internet/tls" 20 ) 21 22 //go:embed dialer.html 23 var webpage []byte 24 var conns chan *websocket.Conn 25 26 func init() { 27 if addr := os.Getenv("XRAY_BROWSER_DIALER"); addr != "" { 28 conns = make(chan *websocket.Conn, 256) 29 go http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 30 if r.URL.Path == "/websocket" { 31 if conn, err := upgrader.Upgrade(w, r, nil); err == nil { 32 conns <- conn 33 } else { 34 fmt.Println("unexpected error") 35 } 36 } else { 37 w.Write(webpage) 38 } 39 })) 40 } 41 } 42 43 // Dial dials a WebSocket connection to the given destination. 
//
// When the stream's Ed setting is positive, the returned connection defers
// the actual dial until its first Write, so that Write's payload (if it is no
// larger than Ed bytes) travels inside the handshake request as early data.
// Otherwise the WebSocket is dialed immediately.
func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (internet.Connection, error) {
	newError("creating connection to ", dest).WriteToLog(session.ExportIDToError(ctx))
	var conn net.Conn
	if streamSettings.ProtocolSettings.(*Config).Ed > 0 {
		ctx, cancel := context.WithCancel(ctx)
		conn = &delayDialConn{
			dialed:         make(chan bool, 1),
			cancel:         cancel,
			ctx:            ctx,
			dest:           dest,
			streamSettings: streamSettings,
		}
	} else {
		var err error
		if conn, err = dialWebSocket(ctx, dest, streamSettings, nil); err != nil {
			return nil, newError("failed to dial WebSocket").Base(err)
		}
	}
	return internet.Connection(conn), nil
}

func init() {
	common.Must(internet.RegisterTransportDialer(protocolName, Dial))
}

// dialWebSocket opens a WebSocket to dest and wraps it with newConnection.
//
// The URI scheme is "wss" when the stream settings carry a TLS config and
// "ws" otherwise; the port is omitted when it is the scheme's default. A
// non-nil ed is attached as early data: base64 (standard encoding) in the
// Sec-WebSocket-Protocol request header on the direct path, or handed to the
// browser dialer when one is enabled (conns != nil).
func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig, ed []byte) (net.Conn, error) {
	wsSettings := streamSettings.ProtocolSettings.(*Config)

	dialer := &websocket.Dialer{
		// The transport connection always goes to dest through Xray's system
		// dialer; the network/addr that gorilla derives from the URI are ignored.
		NetDial: func(network, addr string) (net.Conn, error) {
			return internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
		},
		ReadBufferSize:   4 * 1024,
		WriteBufferSize:  4 * 1024,
		HandshakeTimeout: time.Second * 8,
	}

	protocol := "ws"
	if config := tls.ConfigFromStreamSettings(streamSettings); config != nil {
		protocol = "wss"
		dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("http/1.1"))
	}

	host := dest.NetAddr()
	if (protocol == "ws" && dest.Port == 80) || (protocol == "wss" && dest.Port == 443) {
		host = dest.Address.String()
	}
	uri := protocol + "://" + host + wsSettings.GetNormalizedPath()

	if conns != nil {
		return dialViaBrowser(ctx, uri, ed, dialer.HandshakeTimeout)
	}

	header := wsSettings.GetRequestHeader()
	if ed != nil {
		header.Set("Sec-WebSocket-Protocol", base64.StdEncoding.EncodeToString(ed))
	}

	conn, resp, err := dialer.Dial(uri, header)
	if err != nil {
		var reason string
		if resp != nil {
			reason = resp.Status
		}
		return nil, newError("failed to dial to (", uri, "): ", reason).Base(err)
	}

	return newConnection(conn, conn.RemoteAddr(), nil), nil
}

// dialViaBrowser asks a browser page connected to the browser-dialer endpoint
// to open uri. The request is sent as a text message: uri, followed (when ed
// is non-nil) by a space and ed in unpadded URL-safe base64. The browser must
// reply "ok"; any other reply becomes the error text.
//
// Waiting for a browser connection is aborted when ctx is done, and the reply
// must arrive within timeout (when positive). Before this, both waits could
// block forever.
func dialViaBrowser(ctx context.Context, uri string, ed []byte, timeout time.Duration) (net.Conn, error) {
	data := []byte(uri)
	if ed != nil {
		data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...)
	}

	// Take queued browser connections until one accepts the request; a
	// connection whose write fails is discarded.
	var conn *websocket.Conn
	for conn == nil {
		select {
		case <-ctx.Done():
			return nil, newError("no browser dialer connection available").Base(ctx.Err())
		case c := <-conns:
			if err := c.WriteMessage(websocket.TextMessage, data); err != nil {
				c.Close()
			} else {
				conn = c
			}
		}
	}

	// Bound the wait for the browser's verdict, mirroring the handshake
	// timeout of the direct path, so an unresponsive page cannot hang the dial.
	if timeout > 0 {
		if err := conn.SetReadDeadline(time.Now().Add(timeout)); err != nil {
			conn.Close()
			return nil, err
		}
	}
	_, p, err := conn.ReadMessage()
	if err != nil {
		conn.Close()
		return nil, err
	}
	if s := string(p); s != "ok" {
		conn.Close()
		return nil, newError(s)
	}
	// Clear the deadline: the connection now carries proxied traffic.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		conn.Close()
		return nil, err
	}
	return newConnection(conn, conn.RemoteAddr(), nil), nil
}

// delayDialConn postpones dialing the WebSocket until the first Write, so
// that Write's payload can be sent as early data. The embedded Conn stays nil
// until that dial succeeds. The ctx it dials with is cancelled by Close.
type delayDialConn struct {
	net.Conn
	closed         bool
	dialed         chan bool
	cancel         context.CancelFunc
	ctx            context.Context
	dest           net.Destination
	streamSettings *internet.MemoryStreamConfig
}

// Write dials on the first call. If b is no longer than the configured Ed
// limit, it is passed to the dial as early data and len(b) is reported as
// written. Otherwise the dial carries no early data and b is written on the
// new connection. A failed dial closes d.
func (d *delayDialConn) Write(b []byte) (int, error) {
	if d.closed {
		return 0, io.ErrClosedPipe
	}
	if d.Conn == nil {
		ed := b
		if len(ed) > int(d.streamSettings.ProtocolSettings.(*Config).Ed) {
			ed = nil
		}
		var err error
		if d.Conn, err = dialWebSocket(d.ctx, d.dest, d.streamSettings, ed); err != nil {
			d.Close()
			return 0, newError("failed to dial WebSocket").Base(err)
		}
		// Wake a Read that is waiting for the dial to finish.
		d.dialed <- true
		if ed != nil {
			return len(ed), nil
		}
	}
	return d.Conn.Write(b)
}

// Read waits for the dial triggered by the first Write and then reads from
// the WebSocket. It returns io.ErrUnexpectedEOF if d's context is cancelled
// (for example by Close, or by a failed dial) before the dial completes.
func (d *delayDialConn) Read(b []byte) (int, error) {
	if d.closed {
		return 0, io.ErrClosedPipe
	}
	if d.Conn == nil {
		select {
		case <-d.ctx.Done():
			return 0, io.ErrUnexpectedEOF
		case <-d.dialed:
		}
	}
	return d.Conn.Read(b)
}

// Close marks d closed, cancels its context (aborting a pending dial wait),
// and closes the underlying connection if one was dialed.
func (d *delayDialConn) Close() error {
	d.closed = true
	d.cancel()
	if d.Conn == nil {
		return nil
	}
	return d.Conn.Close()
}