github.com/pslzym/go-ethereum@v1.8.17-0.20180926104442-4b6824e07b1b/rpc/websocket.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "bytes" 21 "context" 22 "crypto/tls" 23 "encoding/base64" 24 "encoding/json" 25 "fmt" 26 "net" 27 "net/http" 28 "net/url" 29 "os" 30 "strings" 31 "time" 32 33 mapset "github.com/deckarep/golang-set" 34 "github.com/ethereum/go-ethereum/log" 35 "golang.org/x/net/websocket" 36 ) 37 38 // websocketJSONCodec is a custom JSON codec with payload size enforcement and 39 // special number parsing. 40 var websocketJSONCodec = websocket.Codec{ 41 // Marshal is the stock JSON marshaller used by the websocket library too. 42 Marshal: func(v interface{}) ([]byte, byte, error) { 43 msg, err := json.Marshal(v) 44 return msg, websocket.TextFrame, err 45 }, 46 // Unmarshal is a specialized unmarshaller to properly convert numbers. 47 Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { 48 dec := json.NewDecoder(bytes.NewReader(msg)) 49 dec.UseNumber() 50 51 return dec.Decode(v) 52 }, 53 } 54 55 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 56 // 57 // allowedOrigins should be a comma-separated list of allowed origin URLs. 58 // To allow connections with any origin, pass "*". 59 func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 60 return websocket.Server{ 61 Handshake: wsHandshakeValidator(allowedOrigins), 62 Handler: func(conn *websocket.Conn) { 63 // Create a custom encode/decode pair to enforce payload size and number encoding 64 conn.MaxPayloadBytes = maxRequestContentLength 65 66 encoder := func(v interface{}) error { 67 return websocketJSONCodec.Send(conn, v) 68 } 69 decoder := func(v interface{}) error { 70 return websocketJSONCodec.Receive(conn, v) 71 } 72 srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) 73 }, 74 } 75 } 76 77 // NewWSServer creates a new websocket RPC server around an API provider. 78 // 79 // Deprecated: use Server.WebsocketHandler 80 func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { 81 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} 82 } 83 84 // wsHandshakeValidator returns a handler that verifies the origin during the 85 // websocket upgrade process. When a '*' is specified as an allowed origins all 86 // connections are accepted. 87 func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { 88 origins := mapset.NewSet() 89 allowAllOrigins := false 90 91 for _, origin := range allowedOrigins { 92 if origin == "*" { 93 allowAllOrigins = true 94 } 95 if origin != "" { 96 origins.Add(strings.ToLower(origin)) 97 } 98 } 99 100 // allow localhost if no allowedOrigins are specified. 101 if len(origins.ToSlice()) == 0 { 102 origins.Add("http://localhost") 103 if hostname, err := os.Hostname(); err == nil { 104 origins.Add("http://" + strings.ToLower(hostname)) 105 } 106 } 107 108 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice())) 109 110 f := func(cfg *websocket.Config, req *http.Request) error { 111 origin := strings.ToLower(req.Header.Get("Origin")) 112 if allowAllOrigins || origins.Contains(origin) { 113 return nil 114 } 115 log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) 116 return fmt.Errorf("origin %s not allowed", origin) 117 } 118 119 return f 120 } 121 122 func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { 123 if origin == "" { 124 var err error 125 if origin, err = os.Hostname(); err != nil { 126 return nil, err 127 } 128 if strings.HasPrefix(endpoint, "wss") { 129 origin = "https://" + strings.ToLower(origin) 130 } else { 131 origin = "http://" + strings.ToLower(origin) 132 } 133 } 134 config, err := websocket.NewConfig(endpoint, origin) 135 if err != nil { 136 return nil, err 137 } 138 139 if config.Location.User != nil { 140 b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String())) 141 config.Header.Add("Authorization", "Basic "+b64auth) 142 config.Location.User = nil 143 } 144 return config, nil 145 } 146 147 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 148 // that is listening on the given endpoint. 149 // 150 // The context is used for the initial connection establishment. It does not 151 // affect subsequent interactions with the client. 152 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 153 config, err := wsGetConfig(endpoint, origin) 154 if err != nil { 155 return nil, err 156 } 157 158 return newClient(ctx, func(ctx context.Context) (net.Conn, error) { 159 return wsDialContext(ctx, config) 160 }) 161 } 162 163 func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { 164 var conn net.Conn 165 var err error 166 switch config.Location.Scheme { 167 case "ws": 168 conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) 169 case "wss": 170 dialer := contextDialer(ctx) 171 conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) 172 default: 173 err = websocket.ErrBadScheme 174 } 175 if err != nil { 176 return nil, err 177 } 178 ws, err := websocket.NewClient(config, conn) 179 if err != nil { 180 conn.Close() 181 return nil, err 182 } 183 return ws, err 184 } 185 186 var wsPortMap = map[string]string{"ws": "80", "wss": "443"} 187 188 func wsDialAddress(location *url.URL) string { 189 if _, ok := wsPortMap[location.Scheme]; ok { 190 if _, _, err := net.SplitHostPort(location.Host); err != nil { 191 return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) 192 } 193 } 194 return location.Host 195 } 196 197 func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { 198 d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} 199 return d.DialContext(ctx, network, addr) 200 } 201 202 func contextDialer(ctx context.Context) *net.Dialer { 203 dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} 204 if deadline, ok := ctx.Deadline(); ok { 205 dialer.Deadline = deadline 206 } else { 207 dialer.Deadline = time.Now().Add(defaultDialTimeout) 208 } 209 return dialer 210 }