gitlab.com/aquachain/aquachain@v1.17.16-rc3.0.20221018032414-e3ddf1e1c055/rpc/websocket.go (about) 1 // Copyright 2018 The aquachain Authors 2 // This file is part of the aquachain library. 3 // 4 // The aquachain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The aquachain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "bytes" 21 "encoding/json" 22 "fmt" 23 "net" 24 "net/http" 25 "os" 26 "strings" 27 28 "golang.org/x/net/websocket" 29 30 set "github.com/deckarep/golang-set" 31 "gitlab.com/aquachain/aquachain/common/log" 32 "gitlab.com/aquachain/aquachain/p2p/netutil" 33 ) 34 35 // websocketJSONCodec is a custom JSON codec with payload size enforcement and 36 // special number parsing. 37 var websocketJSONCodec = websocket.Codec{ 38 // Marshal is the stock JSON marshaller used by the websocket library too. 39 Marshal: func(v interface{}) ([]byte, byte, error) { 40 msg, err := json.Marshal(v) 41 return msg, websocket.TextFrame, err 42 }, 43 // Unmarshal is a specialized unmarshaller to properly convert numbers. 44 Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { 45 dec := json.NewDecoder(bytes.NewReader(msg)) 46 dec.UseNumber() 47 48 return dec.Decode(v) 49 }, 50 } 51 52 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 53 // 54 // allowedOrigins should be a comma-separated list of allowed origin URLs. 55 // To allow connections with any origin, pass "*". 56 func (srv *Server) WebsocketHandler(allowedOrigins []string, allowedIP []string, reverseproxy bool) http.Handler { 57 return websocket.Server{ 58 Handshake: wsHandshakeValidator(allowedOrigins, allowedIP, reverseproxy), 59 Handler: func(conn *websocket.Conn) { 60 61 // Create a custom encode/decode pair to enforce payload size and number encoding 62 conn.MaxPayloadBytes = maxHTTPRequestContentLength 63 64 encoder := func(v interface{}) error { 65 return websocketJSONCodec.Send(conn, v) 66 } 67 decoder := func(v interface{}) error { 68 return websocketJSONCodec.Receive(conn, v) 69 } 70 srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) 71 }, 72 } 73 } 74 75 // NewWSServer creates a new websocket RPC server around an API provider. 76 // 77 // Deprecated: use Server.WebsocketHandler 78 func NewWSServer(allowedOrigins []string, allowedIP []string, reverseproxy bool, srv *Server) *http.Server { 79 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins, allowedIP, reverseproxy)} 80 } 81 82 // wsHandshakeValidator returns a handler that verifies the origin during the 83 // websocket upgrade process. When a '*' is specified as an allowed origins all 84 // connections are accepted. 85 func wsHandshakeValidator(allowedOrigins, allowedIP []string, reverseProxy bool) func(*websocket.Config, *http.Request) error { 86 origins := set.NewSet() 87 allowIPset := make(netutil.Netlist, 0) 88 ws := strings.NewReplacer(" ", "", "\n", "", "\t", "") 89 for _, mask := range allowedIP { 90 mask = ws.Replace(mask) 91 if mask == "" { 92 continue 93 } 94 if mask == "*" { 95 log.Warn("Allowing public RPC access. Be sure to run with -nokeys flag!!!") 96 mask = "0.0.0.0/0" 97 } 98 _, n, err := net.ParseCIDR(mask) 99 if err != nil { 100 log.Warn("error parsing allowed IPs, not adding", "badmask", mask, "err", err) 101 continue 102 } 103 allowIPset = append(allowIPset, *n) 104 } 105 allowAllOrigins := false 106 for _, origin := range allowedOrigins { 107 if origin == "*" { 108 allowAllOrigins = true 109 } 110 if origin != "" { 111 origins.Add(strings.ToLower(origin)) 112 } 113 } 114 115 // allow localhost if no allowedOrigins are specified. 116 if len(origins.ToSlice()) == 0 { 117 origins.Add("http://localhost") 118 if hostname, err := os.Hostname(); err == nil { 119 origins.Add("http://" + strings.ToLower(hostname)) 120 } 121 } 122 123 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice())) 124 log.Debug(fmt.Sprintf("Allowed IP(s) for WS RPC interface %s\n", allowIPset.String())) 125 126 f := func(cfg *websocket.Config, req *http.Request) error { 127 checkip := func(r *http.Request, reverseProxy bool) error { 128 ip := getIP(r, reverseProxy) 129 if allowIPset.Contains(ip) { 130 return nil 131 } 132 log.Warn("unwarranted websocket request", "ip", ip) 133 return fmt.Errorf("ip not allowed") 134 } 135 136 // check ip 137 if err := checkip(req, reverseProxy); err != nil { 138 return err 139 } 140 141 // check origin header 142 origin := strings.ToLower(req.Header.Get("Origin")) 143 if allowAllOrigins || origins.Contains(origin) { 144 return nil 145 } 146 log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) 147 return fmt.Errorf("origin %s not allowed", origin) 148 } 149 150 return f 151 }