github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/rpc/websocket.go (about) 1 package rpc 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "encoding/json" 8 "fmt" 9 "net" 10 "net/http" 11 "net/url" 12 "os" 13 "strings" 14 "time" 15 16 "github.com/quickchainproject/quickchain/log" 17 "golang.org/x/net/websocket" 18 "gopkg.in/fatih/set.v0" 19 ) 20 21 // websocketJSONCodec is a custom JSON codec with payload size enforcement and 22 // special number parsing. 23 var websocketJSONCodec = websocket.Codec{ 24 // Marshal is the stock JSON marshaller used by the websocket library too. 25 Marshal: func(v interface{}) ([]byte, byte, error) { 26 msg, err := json.Marshal(v) 27 return msg, websocket.TextFrame, err 28 }, 29 // Unmarshal is a specialized unmarshaller to properly convert numbers. 30 Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { 31 dec := json.NewDecoder(bytes.NewReader(msg)) 32 dec.UseNumber() 33 34 return dec.Decode(v) 35 }, 36 } 37 38 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 39 // 40 // allowedOrigins should be a comma-separated list of allowed origin URLs. 41 // To allow connections with any origin, pass "*". 42 func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 43 return websocket.Server{ 44 Handshake: wsHandshakeValidator(allowedOrigins), 45 Handler: func(conn *websocket.Conn) { 46 // Create a custom encode/decode pair to enforce payload size and number encoding 47 conn.MaxPayloadBytes = maxRequestContentLength 48 49 encoder := func(v interface{}) error { 50 return websocketJSONCodec.Send(conn, v) 51 } 52 decoder := func(v interface{}) error { 53 return websocketJSONCodec.Receive(conn, v) 54 } 55 srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) 56 }, 57 } 58 } 59 60 // NewWSServer creates a new websocket RPC server around an API provider. 61 // 62 // Deprecated: use Server.WebsocketHandler 63 func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { 64 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} 65 } 66 67 // wsHandshakeValidator returns a handler that verifies the origin during the 68 // websocket upgrade process. When a '*' is specified as an allowed origins all 69 // connections are accepted. 70 func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { 71 origins := set.New() 72 allowAllOrigins := false 73 74 for _, origin := range allowedOrigins { 75 if origin == "*" { 76 allowAllOrigins = true 77 } 78 if origin != "" { 79 origins.Add(strings.ToLower(origin)) 80 } 81 } 82 83 // allow localhost if no allowedOrigins are specified. 84 if len(origins.List()) == 0 { 85 origins.Add("http://localhost") 86 if hostname, err := os.Hostname(); err == nil { 87 origins.Add("http://" + strings.ToLower(hostname)) 88 } 89 } 90 91 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.List())) 92 93 f := func(cfg *websocket.Config, req *http.Request) error { 94 origin := strings.ToLower(req.Header.Get("Origin")) 95 if allowAllOrigins || origins.Has(origin) { 96 return nil 97 } 98 log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) 99 return fmt.Errorf("origin %s not allowed", origin) 100 } 101 102 return f 103 } 104 105 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 106 // that is listening on the given endpoint. 107 // 108 // The context is used for the initial connection establishment. It does not 109 // affect subsequent interactions with the client. 110 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 111 if origin == "" { 112 var err error 113 if origin, err = os.Hostname(); err != nil { 114 return nil, err 115 } 116 if strings.HasPrefix(endpoint, "wss") { 117 origin = "https://" + strings.ToLower(origin) 118 } else { 119 origin = "http://" + strings.ToLower(origin) 120 } 121 } 122 config, err := websocket.NewConfig(endpoint, origin) 123 if err != nil { 124 return nil, err 125 } 126 127 return newClient(ctx, func(ctx context.Context) (net.Conn, error) { 128 return wsDialContext(ctx, config) 129 }) 130 } 131 132 func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { 133 var conn net.Conn 134 var err error 135 switch config.Location.Scheme { 136 case "ws": 137 conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) 138 case "wss": 139 dialer := contextDialer(ctx) 140 conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) 141 default: 142 err = websocket.ErrBadScheme 143 } 144 if err != nil { 145 return nil, err 146 } 147 ws, err := websocket.NewClient(config, conn) 148 if err != nil { 149 conn.Close() 150 return nil, err 151 } 152 return ws, err 153 } 154 155 var wsPortMap = map[string]string{"ws": "80", "wss": "443"} 156 157 func wsDialAddress(location *url.URL) string { 158 if _, ok := wsPortMap[location.Scheme]; ok { 159 if _, _, err := net.SplitHostPort(location.Host); err != nil { 160 return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) 161 } 162 } 163 return location.Host 164 } 165 166 func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { 167 d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} 168 return d.DialContext(ctx, network, addr) 169 } 170 171 func contextDialer(ctx context.Context) *net.Dialer { 172 dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} 173 if deadline, ok := ctx.Deadline(); ok { 174 dialer.Deadline = deadline 175 } else { 176 dialer.Deadline = time.Now().Add(defaultDialTimeout) 177 } 178 return dialer 179 }