github.com/aquanetwork/aquachain@v1.7.8/rpc/websocket.go (about) 1 // Copyright 2015 The aquachain Authors 2 // This file is part of the aquachain library. 3 // 4 // The aquachain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The aquachain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package rpc 18 19 import ( 20 "bytes" 21 "context" 22 "crypto/tls" 23 "encoding/json" 24 "fmt" 25 "net" 26 "net/http" 27 "net/url" 28 "os" 29 "strings" 30 "time" 31 32 "gitlab.com/aquachain/aquachain/common/log" 33 "golang.org/x/net/websocket" 34 "gopkg.in/fatih/set.v0" 35 ) 36 37 // websocketJSONCodec is a custom JSON codec with payload size enforcement and 38 // special number parsing. 39 var websocketJSONCodec = websocket.Codec{ 40 // Marshal is the stock JSON marshaller used by the websocket library too. 41 Marshal: func(v interface{}) ([]byte, byte, error) { 42 msg, err := json.Marshal(v) 43 return msg, websocket.TextFrame, err 44 }, 45 // Unmarshal is a specialized unmarshaller to properly convert numbers. 46 Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { 47 dec := json.NewDecoder(bytes.NewReader(msg)) 48 dec.UseNumber() 49 50 return dec.Decode(v) 51 }, 52 } 53 54 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 55 // 56 // allowedOrigins should be a comma-separated list of allowed origin URLs. 57 // To allow connections with any origin, pass "*". 58 func (srv *Server) WebsocketHandler(allowedOrigins []string, allowedIP []string) http.Handler { 59 return websocket.Server{ 60 Handshake: wsHandshakeValidator(allowedOrigins, allowedIP), 61 Handler: func(conn *websocket.Conn) { 62 // Create a custom encode/decode pair to enforce payload size and number encoding 63 conn.MaxPayloadBytes = maxHTTPRequestContentLength 64 65 encoder := func(v interface{}) error { 66 return websocketJSONCodec.Send(conn, v) 67 } 68 decoder := func(v interface{}) error { 69 return websocketJSONCodec.Receive(conn, v) 70 } 71 srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) 72 }, 73 } 74 } 75 76 // NewWSServer creates a new websocket RPC server around an API provider. 77 // 78 // Deprecated: use Server.WebsocketHandler 79 func NewWSServer(allowedOrigins []string, allowedIP []string, srv *Server) *http.Server { 80 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins, allowedIP)} 81 } 82 83 // wsHandshakeValidator returns a handler that verifies the origin during the 84 // websocket upgrade process. When a '*' is specified as an allowed origins all 85 // connections are accepted. 86 func wsHandshakeValidator(allowedOrigins, allowedIP []string) func(*websocket.Config, *http.Request) error { 87 origins := set.New() 88 allowedIPset := set.New() 89 allowAllOrigins := false 90 allowAllIP := false 91 for _, origin := range allowedOrigins { 92 if origin == "*" { 93 allowAllOrigins = true 94 } 95 if origin != "" { 96 origins.Add(strings.ToLower(origin)) 97 } 98 } 99 for _, ip := range allowedIP { 100 if ip == "*" { 101 allowAllIP = true 102 break 103 } 104 if ip != "" { 105 allowedIPset.Add(strings.ToLower(ip)) 106 } 107 } 108 109 // allow localhost if no allowedOrigins are specified. 110 if len(origins.List()) == 0 { 111 origins.Add("http://localhost") 112 if hostname, err := os.Hostname(); err == nil { 113 origins.Add("http://" + strings.ToLower(hostname)) 114 } 115 } 116 117 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.List())) 118 119 f := func(cfg *websocket.Config, req *http.Request) error { 120 if !allowAllIP { 121 checkip := func(r *http.Request) error { 122 ip := getIP(r) 123 if ip == "" { 124 ip = strings.Split(r.RemoteAddr, ":")[0] 125 } 126 if allowedIPset.Has(ip) { 127 return nil 128 } 129 log.Warn("unwarranted websocket request", "ip", ip) 130 return fmt.Errorf("ip not allowed") 131 } 132 if err := checkip(req); err != nil { 133 return err 134 } 135 } 136 origin := strings.ToLower(req.Header.Get("Origin")) 137 if allowAllOrigins || origins.Has(origin) { 138 return nil 139 } 140 log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) 141 return fmt.Errorf("origin %s not allowed", origin) 142 } 143 144 return f 145 } 146 147 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 148 // that is listening on the given endpoint. 149 // 150 // The context is used for the initial connection establishment. It does not 151 // affect subsequent interactions with the client. 152 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 153 if origin == "" { 154 var err error 155 if origin, err = os.Hostname(); err != nil { 156 return nil, err 157 } 158 if strings.HasPrefix(endpoint, "wss") { 159 origin = "https://" + strings.ToLower(origin) 160 } else { 161 origin = "http://" + strings.ToLower(origin) 162 } 163 } 164 config, err := websocket.NewConfig(endpoint, origin) 165 if err != nil { 166 return nil, err 167 } 168 169 return newClient(ctx, func(ctx context.Context) (net.Conn, error) { 170 return wsDialContext(ctx, config) 171 }) 172 } 173 174 func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { 175 var conn net.Conn 176 var err error 177 switch config.Location.Scheme { 178 case "ws": 179 conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) 180 case "wss": 181 dialer := contextDialer(ctx) 182 conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) 183 default: 184 err = websocket.ErrBadScheme 185 } 186 if err != nil { 187 return nil, err 188 } 189 ws, err := websocket.NewClient(config, conn) 190 if err != nil { 191 conn.Close() 192 return nil, err 193 } 194 return ws, err 195 } 196 197 var wsPortMap = map[string]string{"ws": "80", "wss": "443"} 198 199 func wsDialAddress(location *url.URL) string { 200 if _, ok := wsPortMap[location.Scheme]; ok { 201 if _, _, err := net.SplitHostPort(location.Host); err != nil { 202 return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) 203 } 204 } 205 return location.Host 206 } 207 208 func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { 209 d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} 210 return d.DialContext(ctx, network, addr) 211 } 212 213 func contextDialer(ctx context.Context) *net.Dialer { 214 dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} 215 if deadline, ok := ctx.Deadline(); ok { 216 dialer.Deadline = deadline 217 } else { 218 dialer.Deadline = time.Now().Add(defaultDialTimeout) 219 } 220 return dialer 221 }