github.com/chwjbn/xclash@v0.2.0/transport/vmess/websocket.go (about) 1 package vmess 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "net/http" 13 "net/url" 14 "strconv" 15 "strings" 16 "sync" 17 "time" 18 19 "github.com/gorilla/websocket" 20 ) 21 22 type websocketConn struct { 23 conn *websocket.Conn 24 reader io.Reader 25 remoteAddr net.Addr 26 27 // https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency 28 rMux sync.Mutex 29 wMux sync.Mutex 30 } 31 32 type websocketWithEarlyDataConn struct { 33 net.Conn 34 underlay net.Conn 35 closed bool 36 dialed chan bool 37 cancel context.CancelFunc 38 ctx context.Context 39 config *WebsocketConfig 40 } 41 42 type WebsocketConfig struct { 43 Host string 44 Port string 45 Path string 46 Headers http.Header 47 TLS bool 48 TLSConfig *tls.Config 49 MaxEarlyData int 50 EarlyDataHeaderName string 51 } 52 53 // Read implements net.Conn.Read() 54 func (wsc *websocketConn) Read(b []byte) (int, error) { 55 wsc.rMux.Lock() 56 defer wsc.rMux.Unlock() 57 for { 58 reader, err := wsc.getReader() 59 if err != nil { 60 return 0, err 61 } 62 63 nBytes, err := reader.Read(b) 64 if err == io.EOF { 65 wsc.reader = nil 66 continue 67 } 68 return nBytes, err 69 } 70 } 71 72 // Write implements io.Writer. 73 func (wsc *websocketConn) Write(b []byte) (int, error) { 74 wsc.wMux.Lock() 75 defer wsc.wMux.Unlock() 76 if err := wsc.conn.WriteMessage(websocket.BinaryMessage, b); err != nil { 77 return 0, err 78 } 79 return len(b), nil 80 } 81 82 func (wsc *websocketConn) Close() error { 83 var errors []string 84 if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil { 85 errors = append(errors, err.Error()) 86 } 87 if err := wsc.conn.Close(); err != nil { 88 errors = append(errors, err.Error()) 89 } 90 if len(errors) > 0 { 91 return fmt.Errorf("failed to close connection: %s", strings.Join(errors, ",")) 92 } 93 return nil 94 } 95 96 func (wsc *websocketConn) getReader() (io.Reader, error) { 97 if wsc.reader != nil { 98 return wsc.reader, nil 99 } 100 101 _, reader, err := wsc.conn.NextReader() 102 if err != nil { 103 return nil, err 104 } 105 wsc.reader = reader 106 return reader, nil 107 } 108 109 func (wsc *websocketConn) LocalAddr() net.Addr { 110 return wsc.conn.LocalAddr() 111 } 112 113 func (wsc *websocketConn) RemoteAddr() net.Addr { 114 return wsc.remoteAddr 115 } 116 117 func (wsc *websocketConn) SetDeadline(t time.Time) error { 118 if err := wsc.SetReadDeadline(t); err != nil { 119 return err 120 } 121 return wsc.SetWriteDeadline(t) 122 } 123 124 func (wsc *websocketConn) SetReadDeadline(t time.Time) error { 125 return wsc.conn.SetReadDeadline(t) 126 } 127 128 func (wsc *websocketConn) SetWriteDeadline(t time.Time) error { 129 return wsc.conn.SetWriteDeadline(t) 130 } 131 132 func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { 133 base64DataBuf := &bytes.Buffer{} 134 base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf) 135 136 earlyDataBuf := bytes.NewBuffer(earlyData) 137 if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil { 138 return errors.New("failed to encode early data: " + err.Error()) 139 } 140 141 if errc := base64EarlyDataEncoder.Close(); errc != nil { 142 return errors.New("failed to encode early data tail: " + errc.Error()) 143 } 144 145 var err error 146 if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, base64DataBuf); err != nil { 147 wsedc.Close() 148 return errors.New("failed to dial WebSocket: " + err.Error()) 149 } 150 151 wsedc.dialed <- true 152 if earlyDataBuf.Len() != 0 { 153 _, err = wsedc.Conn.Write(earlyDataBuf.Bytes()) 154 } 155 156 return err 157 } 158 159 func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) { 160 if wsedc.closed { 161 return 0, io.ErrClosedPipe 162 } 163 if wsedc.Conn == nil { 164 if err := wsedc.Dial(b); err != nil { 165 return 0, err 166 } 167 return len(b), nil 168 } 169 170 return wsedc.Conn.Write(b) 171 } 172 173 func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) { 174 if wsedc.closed { 175 return 0, io.ErrClosedPipe 176 } 177 if wsedc.Conn == nil { 178 select { 179 case <-wsedc.ctx.Done(): 180 return 0, io.ErrUnexpectedEOF 181 case <-wsedc.dialed: 182 } 183 } 184 return wsedc.Conn.Read(b) 185 } 186 187 func (wsedc *websocketWithEarlyDataConn) Close() error { 188 wsedc.closed = true 189 wsedc.cancel() 190 if wsedc.Conn == nil { 191 return nil 192 } 193 return wsedc.Conn.Close() 194 } 195 196 func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr { 197 if wsedc.Conn == nil { 198 return wsedc.underlay.LocalAddr() 199 } 200 return wsedc.Conn.LocalAddr() 201 } 202 203 func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr { 204 if wsedc.Conn == nil { 205 return wsedc.underlay.RemoteAddr() 206 } 207 return wsedc.Conn.RemoteAddr() 208 } 209 210 func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error { 211 if err := wsedc.SetReadDeadline(t); err != nil { 212 return err 213 } 214 return wsedc.SetWriteDeadline(t) 215 } 216 217 func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error { 218 if wsedc.Conn == nil { 219 return nil 220 } 221 return wsedc.Conn.SetReadDeadline(t) 222 } 223 224 func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error { 225 if wsedc.Conn == nil { 226 return nil 227 } 228 return wsedc.Conn.SetWriteDeadline(t) 229 } 230 231 func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { 232 ctx, cancel := context.WithCancel(context.Background()) 233 conn = &websocketWithEarlyDataConn{ 234 dialed: make(chan bool, 1), 235 cancel: cancel, 236 ctx: ctx, 237 underlay: conn, 238 config: c, 239 } 240 return conn, nil 241 } 242 243 func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { 244 dialer := &websocket.Dialer{ 245 NetDial: func(network, addr string) (net.Conn, error) { 246 return conn, nil 247 }, 248 ReadBufferSize: 4 * 1024, 249 WriteBufferSize: 4 * 1024, 250 HandshakeTimeout: time.Second * 8, 251 } 252 253 scheme := "ws" 254 if c.TLS { 255 scheme = "wss" 256 dialer.TLSClientConfig = c.TLSConfig 257 } 258 259 uri := url.URL{ 260 Scheme: scheme, 261 Host: net.JoinHostPort(c.Host, c.Port), 262 Path: c.Path, 263 } 264 265 headers := http.Header{} 266 if c.Headers != nil { 267 for k := range c.Headers { 268 headers.Add(k, c.Headers.Get(k)) 269 } 270 } 271 272 if earlyData != nil { 273 if c.EarlyDataHeaderName == "" { 274 uri.Path += earlyData.String() 275 } else { 276 headers.Set(c.EarlyDataHeaderName, earlyData.String()) 277 } 278 } 279 280 wsConn, resp, err := dialer.Dial(uri.String(), headers) 281 if err != nil { 282 reason := err.Error() 283 if resp != nil { 284 reason = resp.Status 285 } 286 return nil, fmt.Errorf("dial %s error: %s", uri.Host, reason) 287 } 288 289 return &websocketConn{ 290 conn: wsConn, 291 remoteAddr: conn.RemoteAddr(), 292 }, nil 293 } 294 295 func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { 296 if u, err := url.Parse(c.Path); err == nil { 297 if q := u.Query(); q.Get("ed") != "" { 298 if ed, err := strconv.Atoi(q.Get("ed")); err == nil { 299 c.MaxEarlyData = ed 300 c.EarlyDataHeaderName = "Sec-WebSocket-Protocol" 301 q.Del("ed") 302 u.RawQuery = q.Encode() 303 c.Path = u.String() 304 } 305 } 306 } 307 308 if c.MaxEarlyData > 0 { 309 return streamWebsocketWithEarlyDataConn(conn, c) 310 } 311 312 return streamWebsocketConn(conn, c, nil) 313 }