github.com/ConsenSys/Quorum@v20.10.0+incompatible/rpc/websocket.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "context" 21 "crypto/tls" 22 "encoding/base64" 23 "fmt" 24 "net/http" 25 "net/url" 26 "os" 27 "strings" 28 "sync" 29 30 mapset "github.com/deckarep/golang-set" 31 "github.com/ethereum/go-ethereum/log" 32 "github.com/gorilla/websocket" 33 ) 34 35 const ( 36 wsReadBuffer = 1024 37 wsWriteBuffer = 1024 38 ) 39 40 var wsBufferPool = new(sync.Pool) 41 42 // NewWSServer creates a new websocket RPC server around an API provider. 43 // 44 // Deprecated: use Server.WebsocketHandler 45 func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { 46 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} 47 } 48 49 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 50 // 51 // allowedOrigins should be a comma-separated list of allowed origin URLs. 52 // To allow connections with any origin, pass "*". 53 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 54 var upgrader = websocket.Upgrader{ 55 ReadBufferSize: wsReadBuffer, 56 WriteBufferSize: wsWriteBuffer, 57 WriteBufferPool: wsBufferPool, 58 CheckOrigin: wsHandshakeValidator(allowedOrigins), 59 } 60 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 conn, err := upgrader.Upgrade(w, r, nil) 62 if err != nil { 63 log.Debug("WebSocket upgrade failed", "err", err) 64 return 65 } 66 codec := newWebsocketCodec(conn) 67 s.authenticateHttpRequest(r, codec) 68 s.ServeCodec(codec, OptionMethodInvocation|OptionSubscriptions) 69 }) 70 } 71 72 // wsHandshakeValidator returns a handler that verifies the origin during the 73 // websocket upgrade process. When a '*' is specified as an allowed origins all 74 // connections are accepted. 75 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 76 origins := mapset.NewSet() 77 allowAllOrigins := false 78 79 for _, origin := range allowedOrigins { 80 if origin == "*" { 81 allowAllOrigins = true 82 } 83 if origin != "" { 84 origins.Add(strings.ToLower(origin)) 85 } 86 } 87 // allow localhost if no allowedOrigins are specified. 88 if len(origins.ToSlice()) == 0 { 89 origins.Add("http://localhost") 90 origins.Add("https://localhost") 91 if hostname, err := os.Hostname(); err == nil { 92 origins.Add("http://" + strings.ToLower(hostname)) 93 origins.Add("https://" + strings.ToLower(hostname)) 94 } 95 } 96 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 97 98 f := func(req *http.Request) bool { 99 // Skip origin verification if no Origin header is present. The origin check 100 // is supposed to protect against browser based attacks. Browsers always set 101 // Origin. Non-browser software can put anything in origin and checking it doesn't 102 // provide additional security. 103 if _, ok := req.Header["Origin"]; !ok { 104 return true 105 } 106 // Verify origin against whitelist. 107 origin := strings.ToLower(req.Header.Get("Origin")) 108 if allowAllOrigins || origins.Contains(origin) { 109 return true 110 } 111 log.Warn("Rejected WebSocket connection", "origin", origin) 112 return false 113 } 114 115 return f 116 } 117 118 type wsHandshakeError struct { 119 err error 120 status string 121 } 122 123 func (e wsHandshakeError) Error() string { 124 s := e.err.Error() 125 if e.status != "" { 126 s += " (HTTP status " + e.status + ")" 127 } 128 return s 129 } 130 131 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 132 // that is listening on the given endpoint. 133 // 134 // The context is used for the initial connection establishment. It does not 135 // affect subsequent interactions with the client. 136 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 137 return DialWebsocketWithCustomTLS(ctx, endpoint, origin, nil) 138 } 139 140 // Quorum 141 // 142 // DialWebsocketWithCustomTLS creates a new RPC client that communicates with a JSON-RPC server 143 // that is listening on the given endpoint. 144 // At the same time, allowing to customize TLSClientConfig of the dialer 145 // 146 // The context is used for the initial connection establishment. It does not 147 // affect subsequent interactions with the client. 148 func DialWebsocketWithCustomTLS(ctx context.Context, endpoint, origin string, tlsConfig *tls.Config) (*Client, error) { 149 endpoint, header, err := wsClientHeaders(endpoint, origin) 150 if err != nil { 151 return nil, err 152 } 153 dialer := websocket.Dialer{ 154 ReadBufferSize: wsReadBuffer, 155 WriteBufferSize: wsWriteBuffer, 156 WriteBufferPool: wsBufferPool, 157 } 158 if tlsConfig != nil { 159 dialer.TLSClientConfig = tlsConfig 160 } 161 credProviderFunc, hasCredProviderFunc := ctx.Value(CtxCredentialsProvider).(HttpCredentialsProviderFunc) 162 return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { 163 if hasCredProviderFunc { 164 token, err := credProviderFunc(ctx) 165 if err != nil { 166 log.Warn("unable to obtain credentials from provider", "err", err) 167 } else { 168 header.Set(HttpAuthorizationHeader, token) 169 } 170 } 171 conn, resp, err := dialer.DialContext(ctx, endpoint, header) 172 if err != nil { 173 hErr := wsHandshakeError{err: err} 174 if resp != nil { 175 hErr.status = resp.Status 176 } 177 return nil, hErr 178 } 179 return newWebsocketCodec(conn), nil 180 }) 181 } 182 183 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 184 endpointURL, err := url.Parse(endpoint) 185 if err != nil { 186 return endpoint, nil, err 187 } 188 header := make(http.Header) 189 if origin != "" { 190 header.Add("origin", origin) 191 } 192 if endpointURL.User != nil { 193 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 194 header.Add(HttpAuthorizationHeader, "Basic "+b64auth) 195 endpointURL.User = nil 196 } 197 return endpointURL.String(), header, nil 198 } 199 200 func newWebsocketCodec(conn *websocket.Conn) ServerCodec { 201 conn.SetReadLimit(maxRequestContentLength) 202 return newCodec(conn, conn.WriteJSON, conn.ReadJSON) 203 }