gitlab.com/aquachain/aquachain@v1.17.16-rc3.0.20221018032414-e3ddf1e1c055/rpc/rpcclient/websocket.go (about) 1 // Copyright 2018 The aquachain Authors 2 // This file is part of the aquachain library. 3 // 4 // The aquachain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The aquachain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "context" 21 "crypto/tls" 22 "net" 23 "net/url" 24 "os" 25 "strings" 26 "time" 27 28 "golang.org/x/net/websocket" 29 ) 30 31 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 32 // that is listening on the given endpoint. 33 // 34 // The context is used for the initial connection establishment. It does not 35 // affect subsequent interactions with the client. 36 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 37 if origin == "" { 38 var err error 39 if origin, err = os.Hostname(); err != nil { 40 return nil, err 41 } 42 if strings.HasPrefix(endpoint, "wss") { 43 origin = "https://" + strings.ToLower(origin) 44 } else { 45 origin = "http://" + strings.ToLower(origin) 46 } 47 } 48 config, err := websocket.NewConfig(endpoint, origin) 49 if err != nil { 50 return nil, err 51 } 52 53 return newClient(ctx, func(ctx context.Context) (net.Conn, error) { 54 return wsDialContext(ctx, config) 55 }) 56 } 57 58 func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { 59 var conn net.Conn 60 var err error 61 switch config.Location.Scheme { 62 case "ws": 63 conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) 64 case "wss": 65 dialer := contextDialer(ctx) 66 conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) 67 default: 68 err = websocket.ErrBadScheme 69 } 70 if err != nil { 71 return nil, err 72 } 73 ws, err := websocket.NewClient(config, conn) 74 if err != nil { 75 conn.Close() 76 return nil, err 77 } 78 return ws, err 79 } 80 81 var wsPortMap = map[string]string{"ws": "80", "wss": "443"} 82 83 func wsDialAddress(location *url.URL) string { 84 if _, ok := wsPortMap[location.Scheme]; ok { 85 if _, _, err := net.SplitHostPort(location.Host); err != nil { 86 return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) 87 } 88 } 89 return location.Host 90 } 91 92 func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { 93 d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} 94 return d.DialContext(ctx, network, addr) 95 } 96 97 func contextDialer(ctx context.Context) *net.Dialer { 98 dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} 99 if deadline, ok := ctx.Deadline(); ok { 100 dialer.Deadline = deadline 101 } else { 102 dialer.Deadline = time.Now().Add(defaultDialTimeout) 103 } 104 return dialer 105 }