github.com/xxf098/lite-proxy@v0.15.1-0.20230422081941-12c69f323218/transport/vmess/websocket.go (about) 1 package vmess 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/base64" 8 "errors" 9 "fmt" 10 "io" 11 "net" 12 "net/http" 13 "net/url" 14 "strconv" 15 "strings" 16 "sync" 17 "time" 18 19 "github.com/gorilla/websocket" 20 ) 21 22 type websocketConn struct { 23 conn *websocket.Conn 24 reader io.Reader 25 remoteAddr net.Addr 26 27 // https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency 28 rMux sync.Mutex 29 wMux sync.Mutex 30 } 31 32 type websocketWithEarlyDataConn struct { 33 net.Conn 34 underlay net.Conn 35 closed bool 36 dialed chan bool 37 cancel context.CancelFunc 38 ctx context.Context 39 config *WebsocketConfig 40 } 41 42 type WebsocketConfig struct { 43 Host string 44 Port string 45 Path string 46 Headers http.Header 47 TLS bool 48 TLSConfig *tls.Config 49 SkipCertVerify bool 50 ServerName string 51 SessionCache tls.ClientSessionCache 52 MaxEarlyData int 53 EarlyDataHeaderName string 54 } 55 56 // Read implements net.Conn.Read() 57 func (wsc *websocketConn) Read(b []byte) (int, error) { 58 wsc.rMux.Lock() 59 defer wsc.rMux.Unlock() 60 for { 61 reader, err := wsc.getReader() 62 if err != nil { 63 return 0, err 64 } 65 66 nBytes, err := reader.Read(b) 67 if err == io.EOF { 68 wsc.reader = nil 69 continue 70 } 71 return nBytes, err 72 } 73 } 74 75 // Write implements io.Writer. 76 func (wsc *websocketConn) Write(b []byte) (int, error) { 77 wsc.wMux.Lock() 78 defer wsc.wMux.Unlock() 79 if err := wsc.conn.WriteMessage(websocket.BinaryMessage, b); err != nil { 80 return 0, err 81 } 82 return len(b), nil 83 } 84 85 func (wsc *websocketConn) Close() error { 86 var errors []string 87 if err := wsc.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""), time.Now().Add(time.Second*5)); err != nil { 88 errors = append(errors, err.Error()) 89 } 90 if err := wsc.conn.Close(); err != nil { 91 errors = append(errors, err.Error()) 92 } 93 if len(errors) > 0 { 94 return fmt.Errorf("failed to close connection: %s", strings.Join(errors, ",")) 95 } 96 return nil 97 } 98 99 func (wsc *websocketConn) getReader() (io.Reader, error) { 100 if wsc.reader != nil { 101 return wsc.reader, nil 102 } 103 104 _, reader, err := wsc.conn.NextReader() 105 if err != nil { 106 return nil, err 107 } 108 wsc.reader = reader 109 return reader, nil 110 } 111 112 func (wsc *websocketConn) LocalAddr() net.Addr { 113 return wsc.conn.LocalAddr() 114 } 115 116 func (wsc *websocketConn) RemoteAddr() net.Addr { 117 return wsc.remoteAddr 118 } 119 120 func (wsc *websocketConn) SetDeadline(t time.Time) error { 121 if err := wsc.SetReadDeadline(t); err != nil { 122 return err 123 } 124 return wsc.SetWriteDeadline(t) 125 } 126 127 func (wsc *websocketConn) SetReadDeadline(t time.Time) error { 128 return wsc.conn.SetReadDeadline(t) 129 } 130 131 func (wsc *websocketConn) SetWriteDeadline(t time.Time) error { 132 return wsc.conn.SetWriteDeadline(t) 133 } 134 135 func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { 136 base64DataBuf := &bytes.Buffer{} 137 base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf) 138 139 earlyDataBuf := bytes.NewBuffer(earlyData) 140 if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil { 141 return errors.New("failed to encode early data: " + err.Error()) 142 } 143 144 if errc := base64EarlyDataEncoder.Close(); errc != nil { 145 return errors.New("failed to encode early data tail: " + errc.Error()) 146 } 147 148 var err error 149 if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, base64DataBuf); err != nil { 150 wsedc.Close() 151 return errors.New("failed to dial WebSocket: " + err.Error()) 152 } 153 154 wsedc.dialed <- true 155 if earlyDataBuf.Len() != 0 { 156 _, err = wsedc.Conn.Write(earlyDataBuf.Bytes()) 157 } 158 159 return err 160 } 161 162 func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) { 163 if wsedc.closed { 164 return 0, io.ErrClosedPipe 165 } 166 if wsedc.Conn == nil { 167 if err := wsedc.Dial(b); err != nil { 168 return 0, err 169 } 170 return len(b), nil 171 } 172 173 return wsedc.Conn.Write(b) 174 } 175 176 func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) { 177 if wsedc.closed { 178 return 0, io.ErrClosedPipe 179 } 180 if wsedc.Conn == nil { 181 select { 182 case <-wsedc.ctx.Done(): 183 return 0, io.ErrUnexpectedEOF 184 case <-wsedc.dialed: 185 } 186 } 187 return wsedc.Conn.Read(b) 188 } 189 190 func (wsedc *websocketWithEarlyDataConn) Close() error { 191 wsedc.closed = true 192 wsedc.cancel() 193 if wsedc.Conn == nil { 194 return nil 195 } 196 return wsedc.Conn.Close() 197 } 198 199 func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr { 200 if wsedc.Conn == nil { 201 return wsedc.underlay.LocalAddr() 202 } 203 return wsedc.Conn.LocalAddr() 204 } 205 206 func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr { 207 if wsedc.Conn == nil { 208 return wsedc.underlay.RemoteAddr() 209 } 210 return wsedc.Conn.RemoteAddr() 211 } 212 213 func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error { 214 if err := wsedc.SetReadDeadline(t); err != nil { 215 return err 216 } 217 return wsedc.SetWriteDeadline(t) 218 } 219 220 func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error { 221 if wsedc.Conn == nil { 222 return nil 223 } 224 return wsedc.Conn.SetReadDeadline(t) 225 } 226 227 func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error { 228 if wsedc.Conn == nil { 229 return nil 230 } 231 return wsedc.Conn.SetWriteDeadline(t) 232 } 233 234 func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { 235 ctx, cancel := context.WithCancel(context.Background()) 236 conn = &websocketWithEarlyDataConn{ 237 dialed: make(chan bool, 1), 238 cancel: cancel, 239 ctx: ctx, 240 underlay: conn, 241 config: c, 242 } 243 return conn, nil 244 } 245 246 func streamWebsocketConn(conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) { 247 dialer := &websocket.Dialer{ 248 NetDial: func(network, addr string) (net.Conn, error) { 249 return conn, nil 250 }, 251 ReadBufferSize: 4 * 1024, 252 WriteBufferSize: 4 * 1024, 253 HandshakeTimeout: time.Second * 8, 254 } 255 256 scheme := "ws" 257 if c.TLS { 258 scheme = "wss" 259 if c.TLSConfig != nil { 260 dialer.TLSClientConfig = c.TLSConfig 261 } else { 262 dialer.TLSClientConfig = &tls.Config{ 263 ServerName: c.Host, 264 InsecureSkipVerify: c.SkipCertVerify, 265 ClientSessionCache: c.SessionCache, 266 } 267 } 268 if c.ServerName != "" { 269 dialer.TLSClientConfig.ServerName = c.ServerName 270 } else if host := c.Headers.Get("Host"); host != "" { 271 dialer.TLSClientConfig.ServerName = host 272 } 273 } 274 275 u, err := url.Parse(c.Path) 276 if err != nil { 277 return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) 278 } 279 280 uri := url.URL{ 281 Scheme: scheme, 282 Host: net.JoinHostPort(c.Host, c.Port), 283 Path: u.Path, 284 RawQuery: u.RawQuery, 285 } 286 287 headers := http.Header{} 288 if c.Headers != nil { 289 for k := range c.Headers { 290 headers.Add(k, c.Headers.Get(k)) 291 } 292 } 293 294 if earlyData != nil { 295 if c.EarlyDataHeaderName == "" { 296 uri.Path += earlyData.String() 297 } else { 298 headers.Set(c.EarlyDataHeaderName, earlyData.String()) 299 } 300 } 301 302 wsConn, resp, err := dialer.Dial(uri.String(), headers) 303 if err != nil { 304 reason := err.Error() 305 if resp != nil { 306 reason = resp.Status 307 } 308 return nil, fmt.Errorf("dial %s error: %s", uri.Host, reason) 309 } 310 311 return &websocketConn{ 312 conn: wsConn, 313 remoteAddr: conn.RemoteAddr(), 314 }, nil 315 } 316 317 func StreamWebsocketConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) { 318 319 if u, err := url.Parse(c.Path); err == nil { 320 if q := u.Query(); q.Get("ed") != "" { 321 if ed, err := strconv.Atoi(q.Get("ed")); err == nil { 322 c.MaxEarlyData = ed 323 c.EarlyDataHeaderName = "Sec-WebSocket-Protocol" 324 q.Del("ed") 325 u.RawQuery = q.Encode() 326 c.Path = u.String() 327 } 328 } 329 } 330 331 // dialer := &websocket.Dialer{ 332 // NetDial: func(network, addr string) (net.Conn, error) { 333 // return conn, nil 334 // }, 335 // ReadBufferSize: 4 * 1024, 336 // WriteBufferSize: 4 * 1024, 337 // HandshakeTimeout: time.Second * 8, 338 // } 339 340 // scheme := "ws" 341 // if c.TLS { 342 // scheme = "wss" 343 // if c.TLSConfig != nil { 344 // dialer.TLSClientConfig = c.TLSConfig 345 // } else { 346 // dialer.TLSClientConfig = &tls.Config{ 347 // ServerName: c.Host, 348 // InsecureSkipVerify: c.SkipCertVerify, 349 // ClientSessionCache: c.SessionCache, 350 // } 351 // } 352 // if c.ServerName != "" { 353 // dialer.TLSClientConfig.ServerName = c.ServerName 354 // } else if host := c.Headers.Get("Host"); host != "" { 355 // dialer.TLSClientConfig.ServerName = host 356 // } 357 // } 358 359 // uri := url.URL{ 360 // Scheme: scheme, 361 // Host: net.JoinHostPort(c.Host, c.Port), 362 // Path: c.Path, 363 // } 364 365 // headers := http.Header{} 366 // if c.Headers != nil { 367 // for k := range c.Headers { 368 // headers.Add(k, c.Headers.Get(k)) 369 // } 370 // } 371 372 // wsConn, resp, err := dialer.Dial(uri.String(), headers) 373 // if err != nil { 374 // reason := err.Error() 375 // if resp != nil { 376 // reason = resp.Status 377 // } 378 // return nil, fmt.Errorf("dial %s error: %s", uri.Host, reason) 379 // } 380 381 // return &websocketConn{ 382 // conn: wsConn, 383 // remoteAddr: conn.RemoteAddr(), 384 // }, nil 385 if c.MaxEarlyData > 0 { 386 return streamWebsocketWithEarlyDataConn(conn, c) 387 } 388 389 return streamWebsocketConn(conn, c, nil) 390 }