github.com/ethereum-optimism/optimism/l2geth@v0.0.0-20230612200230-50b04ade19e3/rpc/websocket.go (about) 1 // Copyright 2015 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "context" 21 "encoding/base64" 22 "fmt" 23 "net/http" 24 "net/url" 25 "os" 26 "strings" 27 "sync" 28 29 mapset "github.com/deckarep/golang-set" 30 "github.com/ethereum-optimism/optimism/l2geth/log" 31 "github.com/gorilla/websocket" 32 ) 33 34 const ( 35 wsReadBuffer = 1024 36 wsWriteBuffer = 1024 37 ) 38 39 var wsBufferPool = new(sync.Pool) 40 41 // NewWSServer creates a new websocket RPC server around an API provider. 42 // 43 // Deprecated: use Server.WebsocketHandler 44 func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { 45 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} 46 } 47 48 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 49 // 50 // allowedOrigins should be a comma-separated list of allowed origin URLs. 51 // To allow connections with any origin, pass "*". 52 func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 53 var upgrader = websocket.Upgrader{ 54 ReadBufferSize: wsReadBuffer, 55 WriteBufferSize: wsWriteBuffer, 56 WriteBufferPool: wsBufferPool, 57 CheckOrigin: wsHandshakeValidator(allowedOrigins), 58 } 59 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 60 conn, err := upgrader.Upgrade(w, r, nil) 61 if err != nil { 62 log.Debug("WebSocket upgrade failed", "err", err) 63 return 64 } 65 codec := newWebsocketCodec(conn) 66 s.ServeCodec(codec, 0) 67 }) 68 } 69 70 // wsHandshakeValidator returns a handler that verifies the origin during the 71 // websocket upgrade process. When a '*' is specified as an allowed origins all 72 // connections are accepted. 73 func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { 74 origins := mapset.NewSet() 75 allowAllOrigins := false 76 77 for _, origin := range allowedOrigins { 78 if origin == "*" { 79 allowAllOrigins = true 80 } 81 if origin != "" { 82 origins.Add(strings.ToLower(origin)) 83 } 84 } 85 // allow localhost if no allowedOrigins are specified. 86 if len(origins.ToSlice()) == 0 { 87 origins.Add("http://localhost") 88 if hostname, err := os.Hostname(); err == nil { 89 origins.Add("http://" + strings.ToLower(hostname)) 90 } 91 } 92 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) 93 94 f := func(req *http.Request) bool { 95 // Skip origin verification if no Origin header is present. The origin check 96 // is supposed to protect against browser based attacks. Browsers always set 97 // Origin. Non-browser software can put anything in origin and checking it doesn't 98 // provide additional security. 99 if _, ok := req.Header["Origin"]; !ok { 100 return true 101 } 102 // Verify origin against whitelist. 103 origin := strings.ToLower(req.Header.Get("Origin")) 104 if allowAllOrigins || origins.Contains(origin) { 105 return true 106 } 107 log.Warn("Rejected WebSocket connection", "origin", origin) 108 return false 109 } 110 111 return f 112 } 113 114 type wsHandshakeError struct { 115 err error 116 status string 117 } 118 119 func (e wsHandshakeError) Error() string { 120 s := e.err.Error() 121 if e.status != "" { 122 s += " (HTTP status " + e.status + ")" 123 } 124 return s 125 } 126 127 // DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server 128 // that is listening on the given endpoint using the provided dialer. 129 func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, dialer websocket.Dialer) (*Client, error) { 130 endpoint, header, err := wsClientHeaders(endpoint, origin) 131 if err != nil { 132 return nil, err 133 } 134 return newClient(ctx, func(ctx context.Context) (ServerCodec, error) { 135 conn, resp, err := dialer.DialContext(ctx, endpoint, header) 136 if err != nil { 137 hErr := wsHandshakeError{err: err} 138 if resp != nil { 139 hErr.status = resp.Status 140 } 141 return nil, hErr 142 } 143 return newWebsocketCodec(conn), nil 144 }) 145 } 146 147 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 148 // that is listening on the given endpoint. 149 // 150 // The context is used for the initial connection establishment. It does not 151 // affect subsequent interactions with the client. 152 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 153 dialer := websocket.Dialer{ 154 ReadBufferSize: wsReadBuffer, 155 WriteBufferSize: wsWriteBuffer, 156 WriteBufferPool: wsBufferPool, 157 } 158 return DialWebsocketWithDialer(ctx, endpoint, origin, dialer) 159 } 160 161 func wsClientHeaders(endpoint, origin string) (string, http.Header, error) { 162 endpointURL, err := url.Parse(endpoint) 163 if err != nil { 164 return endpoint, nil, err 165 } 166 header := make(http.Header) 167 if origin != "" { 168 header.Add("origin", origin) 169 } 170 if endpointURL.User != nil { 171 b64auth := base64.StdEncoding.EncodeToString([]byte(endpointURL.User.String())) 172 header.Add("authorization", "Basic "+b64auth) 173 endpointURL.User = nil 174 } 175 return endpointURL.String(), header, nil 176 } 177 178 func newWebsocketCodec(conn *websocket.Conn) ServerCodec { 179 conn.SetReadLimit(maxRequestContentLength) 180 return NewFuncCodec(conn, conn.WriteJSON, conn.ReadJSON) 181 }