github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/network/rpc/websocket.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "bytes" 21 "context" 22 "crypto/tls" 23 "encoding/base64" 24 "encoding/json" 25 "errors" 26 "fmt" 27 "net" 28 "net/http" 29 "net/url" 30 "os" 31 "strings" 32 "time" 33 34 mapset "github.com/deckarep/golang-set" 35 "github.com/neatlab/neatio/chain/log" 36 "golang.org/x/net/websocket" 37 ) 38 39 // websocketJSONCodec is a custom JSON codec with payload size enforcement and 40 // special number parsing. 41 var websocketJSONCodec = websocket.Codec{ 42 // Marshal is the stock JSON marshaller used by the websocket library too. 43 Marshal: func(v interface{}) ([]byte, byte, error) { 44 msg, err := json.Marshal(v) 45 return msg, websocket.TextFrame, err 46 }, 47 // Unmarshal is a specialized unmarshaller to properly convert numbers. 48 Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { 49 dec := json.NewDecoder(bytes.NewReader(msg)) 50 dec.UseNumber() 51 52 return dec.Decode(v) 53 }, 54 } 55 56 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 57 // 58 // allowedOrigins should be a comma-separated list of allowed origin URLs. 59 // To allow connections with any origin, pass "*". 60 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 61 return websocket.Server{ 62 Handshake: wsHandshakeValidator(allowedOrigins), 63 Handler: func(conn *websocket.Conn) { 64 codec := newWebsocketCodec(conn) 65 s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions) 66 }, 67 } 68 } 69 70 func newWebsocketCodec(conn *websocket.Conn) ServerCodec { 71 // Create a custom encode/decode pair to enforce payload size and number encoding 72 conn.MaxPayloadBytes = maxRequestContentLength 73 encoder := func(v interface{}) error { 74 return websocketJSONCodec.Send(conn, v) 75 } 76 decoder := func(v interface{}) error { 77 return websocketJSONCodec.Receive(conn, v) 78 } 79 rpcconn := Conn(conn) 80 if conn.IsServerConn() { 81 // Override remote address with the actual socket address because 82 // package websocket crashes if there is no request origin. 83 addr := conn.Request().RemoteAddr 84 if wsaddr := conn.RemoteAddr().(*websocket.Addr); wsaddr.URL != nil { 85 // Add origin if present. 86 addr += "(" + wsaddr.URL.String() + ")" 87 } 88 rpcconn = connWithRemoteAddr{conn, addr} 89 } 90 return NewCodec(rpcconn, encoder, decoder) 91 } 92 93 // NewWSServer creates a new websocket RPC server around an API provider. 94 // 95 // Deprecated: use Server.WebsocketHandler 96 func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { 97 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} 98 } 99 100 // wsHandshakeValidator returns a handler that verifies the origin during the 101 // websocket upgrade process. When a '*' is specified as an allowed origins all 102 // connections are accepted. 103 func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { 104 origins := mapset.NewSet() 105 allowAllOrigins := false 106 107 for _, origin := range allowedOrigins { 108 if origin == "*" { 109 allowAllOrigins = true 110 } 111 if origin != "" { 112 origins.Add(strings.ToLower(origin)) 113 } 114 } 115 116 // allow localhost if no allowedOrigins are specified. 117 if len(origins.ToSlice()) == 0 { 118 origins.Add("http://localhost") 119 if hostname, err := os.Hostname(); err == nil { 120 origins.Add("http://" + strings.ToLower(hostname)) 121 } 122 } 123 124 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 125 126 f := func(cfg *websocket.Config, req *http.Request) error { 127 // Skip origin verification if no Origin header is present. The origin check 128 // is supposed to protect against browser based attacks. Browsers always set 129 // Origin. Non-browser software can put anything in origin and checking it doesn't 130 // provide additional security. 131 if _, ok := req.Header["Origin"]; !ok { 132 return nil 133 } 134 // Verify origin against whitelist. 135 origin := strings.ToLower(req.Header.Get("Origin")) 136 if allowAllOrigins || origins.Contains(origin) { 137 return nil 138 } 139 log.Warn("Rejected WebSocket connection", "origin", origin) 140 return errors.New("origin not allowed") 141 } 142 143 return f 144 } 145 146 func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { 147 if origin == "" { 148 var err error 149 if origin, err = os.Hostname(); err != nil { 150 return nil, err 151 } 152 if strings.HasPrefix(endpoint, "wss") { 153 origin = "https://" + strings.ToLower(origin) 154 } else { 155 origin = "http://" + strings.ToLower(origin) 156 } 157 } 158 config, err := websocket.NewConfig(endpoint, origin) 159 if err != nil { 160 return nil, err 161 } 162 163 if config.Location.User != nil { 164 b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String())) 165 config.Header.Add("Authorization", "Basic "+b64auth) 166 config.Location.User = nil 167 } 168 return config, nil 169 } 170 171 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 172 // that is listening on the given endpoint. 173 // 174 // The context is used for the initial connection establishment. It does not 175 // affect subsequent interactions with the client. 176 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 177 config, err := wsGetConfig(endpoint, origin) 178 if err != nil { 179 return nil, err 180 } 181 182 return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { 183 conn, err := wsDialContext(ctx, config) 184 if err != nil { 185 return nil, err 186 } 187 return newWebsocketCodec(conn), nil 188 }) 189 } 190 191 func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { 192 var conn net.Conn 193 var err error 194 switch config.Location.Scheme { 195 case "ws": 196 conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) 197 case "wss": 198 dialer := contextDialer(ctx) 199 conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) 200 default: 201 err = websocket.ErrBadScheme 202 } 203 if err != nil { 204 return nil, err 205 } 206 ws, err := websocket.NewClient(config, conn) 207 if err != nil { 208 conn.Close() 209 return nil, err 210 } 211 return ws, err 212 } 213 214 var wsPortMap = map[string]string{"ws": "80", "wss": "443"} 215 216 func wsDialAddress(location *url.URL) string { 217 if _, ok := wsPortMap[location.Scheme]; ok { 218 if _, _, err := net.SplitHostPort(location.Host); err != nil { 219 return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) 220 } 221 } 222 return location.Host 223 } 224 225 func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { 226 d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} 227 return d.DialContext(ctx, network, addr) 228 } 229 230 func contextDialer(ctx context.Context) *net.Dialer { 231 dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} 232 if deadline, ok := ctx.Deadline(); ok { 233 dialer.Deadline = deadline 234 } else { 235 dialer.Deadline = time.Now().Add(defaultDialTimeout) 236 } 237 return dialer 238 }