github.com/SmartMeshFoundation/Spectrum@v0.0.0-20220621030607-452a266fee1e/rpc/websocket.go (about) 1 // Copyright 2015 The Spectrum Authors 2 // This file is part of the Spectrum library. 3 // 4 // The Spectrum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The Spectrum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the Spectrum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "context" 21 "crypto/tls" 22 "fmt" 23 "net" 24 "net/http" 25 "net/url" 26 "os" 27 "strings" 28 "time" 29 30 "github.com/SmartMeshFoundation/Spectrum/log" 31 mapset "github.com/deckarep/golang-set" 32 "golang.org/x/net/websocket" 33 ) 34 35 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 36 // 37 // allowedOrigins should be a comma-separated list of allowed origin URLs. 38 // To allow connections with any origin, pass "*". 39 func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 40 return websocket.Server{ 41 Handshake: wsHandshakeValidator(allowedOrigins), 42 Handler: func(conn *websocket.Conn) { 43 srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions) 44 }, 45 } 46 } 47 48 // NewWSServer creates a new websocket RPC server around an API provider. 49 // 50 // Deprecated: use Server.WebsocketHandler 51 func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { 52 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} 53 } 54 55 // wsHandshakeValidator returns a handler that verifies the origin during the 56 // websocket upgrade process. When a '*' is specified as an allowed origins all 57 // connections are accepted. 58 func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { 59 //origins := set.New() 60 origins := mapset.NewSet() 61 62 allowAllOrigins := false 63 64 for _, origin := range allowedOrigins { 65 if origin == "*" { 66 allowAllOrigins = true 67 } 68 if origin != "" { 69 origins.Add(strings.ToLower(origin)) 70 } 71 } 72 73 // allow localhost if no allowedOrigins are specified. 74 if len(origins.ToSlice()) == 0 { 75 origins.Add("http://localhost") 76 if hostname, err := os.Hostname(); err == nil { 77 origins.Add("http://" + strings.ToLower(hostname)) 78 } 79 } 80 81 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice())) 82 83 f := func(cfg *websocket.Config, req *http.Request) error { 84 origin := strings.ToLower(req.Header.Get("Origin")) 85 if allowAllOrigins || origins.Contains(origin) { 86 return nil 87 } 88 log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) 89 return fmt.Errorf("origin %s not allowed", origin) 90 } 91 92 return f 93 } 94 95 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 96 // that is listening on the given endpoint. 97 // 98 // The context is used for the initial connection establishment. It does not 99 // affect subsequent interactions with the client. 100 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 101 if origin == "" { 102 var err error 103 if origin, err = os.Hostname(); err != nil { 104 return nil, err 105 } 106 if strings.HasPrefix(endpoint, "wss") { 107 origin = "https://" + strings.ToLower(origin) 108 } else { 109 origin = "http://" + strings.ToLower(origin) 110 } 111 } 112 config, err := websocket.NewConfig(endpoint, origin) 113 if err != nil { 114 return nil, err 115 } 116 117 return newClient(ctx, func(ctx context.Context) (net.Conn, error) { 118 return wsDialContext(ctx, config) 119 }) 120 } 121 122 func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { 123 var conn net.Conn 124 var err error 125 switch config.Location.Scheme { 126 case "ws": 127 conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) 128 case "wss": 129 dialer := contextDialer(ctx) 130 conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) 131 default: 132 err = websocket.ErrBadScheme 133 } 134 if err != nil { 135 return nil, err 136 } 137 ws, err := websocket.NewClient(config, conn) 138 if err != nil { 139 conn.Close() 140 return nil, err 141 } 142 return ws, err 143 } 144 145 var wsPortMap = map[string]string{"ws": "80", "wss": "443"} 146 147 func wsDialAddress(location *url.URL) string { 148 if _, ok := wsPortMap[location.Scheme]; ok { 149 if _, _, err := net.SplitHostPort(location.Host); err != nil { 150 return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) 151 } 152 } 153 return location.Host 154 } 155 156 func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { 157 d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} 158 return d.DialContext(ctx, network, addr) 159 } 160 161 func contextDialer(ctx context.Context) *net.Dialer { 162 dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} 163 if deadline, ok := ctx.Deadline(); ok { 164 dialer.Deadline = deadline 165 } else { 166 dialer.Deadline = time.Now().Add(defaultDialTimeout) 167 } 168 return dialer 169 }