github.com/codingeasygo/util@v0.0.0-20231206062002-1ce2f004b7d9/xnet/ws.go (about) 1 package xnet 2 3 import ( 4 "crypto/tls" 5 "encoding/base64" 6 "fmt" 7 "io" 8 "net" 9 "net/http" 10 "net/url" 11 "strconv" 12 "strings" 13 "time" 14 15 "golang.org/x/net/websocket" 16 ) 17 18 // WebsocketDialer is an implementation of Dialer by websocket 19 type WebsocketDialer struct { 20 Dialer RawDialer 21 HeaderGen func(remote string) (header http.Header) 22 TlsConfig *tls.Config 23 } 24 25 // NewWebsocketDialer will create new WebsocketDialer 26 func NewWebsocketDialer() (dialer *WebsocketDialer) { 27 dialer = &WebsocketDialer{ 28 Dialer: &net.Dialer{}, 29 TlsConfig: &tls.Config{}, 30 } 31 return 32 } 33 34 // Dial dial to remote by websocket 35 func (w *WebsocketDialer) Dial(remote string) (raw io.ReadWriteCloser, err error) { 36 targetURL, err := url.Parse(remote) 37 if err != nil { 38 return 39 } 40 username, password := targetURL.Query().Get("username"), targetURL.Query().Get("password") 41 if len(username) < 1 { 42 username = targetURL.User.Username() 43 password, _ = targetURL.User.Password() 44 } 45 skipVerify := targetURL.Query().Get("skip_verify") == "1" || w.TlsConfig.InsecureSkipVerify 46 timeout, _ := strconv.ParseUint(targetURL.Query().Get("timeout"), 10, 32) 47 if timeout < 1 { 48 timeout = 5 49 } 50 var origin string 51 if targetURL.Scheme == "wss" { 52 origin = fmt.Sprintf("https://%v", targetURL.Host) 53 } else { 54 origin = fmt.Sprintf("http://%v", targetURL.Host) 55 } 56 config, err := websocket.NewConfig(targetURL.String(), origin) 57 if err == nil { 58 if w.HeaderGen != nil { 59 config.Header = w.HeaderGen(remote) 60 } 61 if len(username) > 0 && len(password) > 0 { 62 config.Header.Set("Authorization", "Basic "+basicAuth(username, password)) 63 } 64 colonPos := strings.LastIndex(config.Location.Host, ":") 65 if colonPos == -1 { 66 colonPos = len(config.Location.Host) 67 } 68 hostname := config.Location.Host[:colonPos] 69 config.TlsConfig = w.TlsConfig 70 if len(config.TlsConfig.ServerName) < 1 { 71 config.TlsConfig.ServerName = hostname 72 } 73 config.TlsConfig.InsecureSkipVerify = skipVerify 74 raw, err = w.dial(config, time.Duration(timeout)*time.Second) 75 } 76 return 77 } 78 79 var portMap = map[string]string{ 80 "ws": "80", 81 "wss": "443", 82 } 83 84 func parseAuthority(location *url.URL) string { 85 if _, ok := portMap[location.Scheme]; ok { 86 if _, _, err := net.SplitHostPort(location.Host); err != nil { 87 return net.JoinHostPort(location.Host, portMap[location.Scheme]) 88 } 89 } 90 return location.Host 91 } 92 93 func tlsHandshake(rawConn net.Conn, timeout time.Duration, config *tls.Config) (conn *tls.Conn, err error) { 94 errChannel := make(chan error, 2) 95 time.AfterFunc(timeout, func() { 96 errChannel <- fmt.Errorf("timeout") 97 }) 98 conn = tls.Client(rawConn, config) 99 go func() { 100 errChannel <- conn.Handshake() 101 }() 102 err = <-errChannel 103 return 104 } 105 106 func (w *WebsocketDialer) dial(config *websocket.Config, timeout time.Duration) (conn net.Conn, err error) { 107 remote := parseAuthority(config.Location) 108 rawConn, err := w.Dialer.Dial("tcp", remote) 109 if err == nil { 110 if config.Location.Scheme == "wss" { 111 conn, err = tlsHandshake(rawConn, timeout, config.TlsConfig) 112 } else { 113 conn = rawConn 114 } 115 if err == nil { 116 conn, err = websocket.NewClient(config, conn) 117 } 118 if err != nil { 119 rawConn.Close() 120 } 121 } 122 return 123 } 124 125 func basicAuth(username, password string) string { 126 auth := username + ":" + password 127 return base64.StdEncoding.EncodeToString([]byte(auth)) 128 }