github.com/codingeasygo/util@v0.0.0-20231206062002-1ce2f004b7d9/xnet/ws.go (about)

     1  package xnet
     2  
     3  import (
     4  	"crypto/tls"
     5  	"encoding/base64"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"golang.org/x/net/websocket"
    16  )
    17  
    18  // WebsocketDialer is an implementation of Dialer by websocket
    19  type WebsocketDialer struct {
    20  	Dialer    RawDialer
    21  	HeaderGen func(remote string) (header http.Header)
    22  	TlsConfig *tls.Config
    23  }
    24  
    25  // NewWebsocketDialer will create new WebsocketDialer
    26  func NewWebsocketDialer() (dialer *WebsocketDialer) {
    27  	dialer = &WebsocketDialer{
    28  		Dialer:    &net.Dialer{},
    29  		TlsConfig: &tls.Config{},
    30  	}
    31  	return
    32  }
    33  
    34  // Dial dial to remote by websocket
    35  func (w *WebsocketDialer) Dial(remote string) (raw io.ReadWriteCloser, err error) {
    36  	targetURL, err := url.Parse(remote)
    37  	if err != nil {
    38  		return
    39  	}
    40  	username, password := targetURL.Query().Get("username"), targetURL.Query().Get("password")
    41  	if len(username) < 1 {
    42  		username = targetURL.User.Username()
    43  		password, _ = targetURL.User.Password()
    44  	}
    45  	skipVerify := targetURL.Query().Get("skip_verify") == "1" || w.TlsConfig.InsecureSkipVerify
    46  	timeout, _ := strconv.ParseUint(targetURL.Query().Get("timeout"), 10, 32)
    47  	if timeout < 1 {
    48  		timeout = 5
    49  	}
    50  	var origin string
    51  	if targetURL.Scheme == "wss" {
    52  		origin = fmt.Sprintf("https://%v", targetURL.Host)
    53  	} else {
    54  		origin = fmt.Sprintf("http://%v", targetURL.Host)
    55  	}
    56  	config, err := websocket.NewConfig(targetURL.String(), origin)
    57  	if err == nil {
    58  		if w.HeaderGen != nil {
    59  			config.Header = w.HeaderGen(remote)
    60  		}
    61  		if len(username) > 0 && len(password) > 0 {
    62  			config.Header.Set("Authorization", "Basic "+basicAuth(username, password))
    63  		}
    64  		colonPos := strings.LastIndex(config.Location.Host, ":")
    65  		if colonPos == -1 {
    66  			colonPos = len(config.Location.Host)
    67  		}
    68  		hostname := config.Location.Host[:colonPos]
    69  		config.TlsConfig = w.TlsConfig
    70  		if len(config.TlsConfig.ServerName) < 1 {
    71  			config.TlsConfig.ServerName = hostname
    72  		}
    73  		config.TlsConfig.InsecureSkipVerify = skipVerify
    74  		raw, err = w.dial(config, time.Duration(timeout)*time.Second)
    75  	}
    76  	return
    77  }
    78  
    79  var portMap = map[string]string{
    80  	"ws":  "80",
    81  	"wss": "443",
    82  }
    83  
    84  func parseAuthority(location *url.URL) string {
    85  	if _, ok := portMap[location.Scheme]; ok {
    86  		if _, _, err := net.SplitHostPort(location.Host); err != nil {
    87  			return net.JoinHostPort(location.Host, portMap[location.Scheme])
    88  		}
    89  	}
    90  	return location.Host
    91  }
    92  
    93  func tlsHandshake(rawConn net.Conn, timeout time.Duration, config *tls.Config) (conn *tls.Conn, err error) {
    94  	errChannel := make(chan error, 2)
    95  	time.AfterFunc(timeout, func() {
    96  		errChannel <- fmt.Errorf("timeout")
    97  	})
    98  	conn = tls.Client(rawConn, config)
    99  	go func() {
   100  		errChannel <- conn.Handshake()
   101  	}()
   102  	err = <-errChannel
   103  	return
   104  }
   105  
   106  func (w *WebsocketDialer) dial(config *websocket.Config, timeout time.Duration) (conn net.Conn, err error) {
   107  	remote := parseAuthority(config.Location)
   108  	rawConn, err := w.Dialer.Dial("tcp", remote)
   109  	if err == nil {
   110  		if config.Location.Scheme == "wss" {
   111  			conn, err = tlsHandshake(rawConn, timeout, config.TlsConfig)
   112  		} else {
   113  			conn = rawConn
   114  		}
   115  		if err == nil {
   116  			conn, err = websocket.NewClient(config, conn)
   117  		}
   118  		if err != nil {
   119  			rawConn.Close()
   120  		}
   121  	}
   122  	return
   123  }
   124  
   125  func basicAuth(username, password string) string {
   126  	auth := username + ":" + password
   127  	return base64.StdEncoding.EncodeToString([]byte(auth))
   128  }