github.com/linapex/ethereum-go-chinese@v0.0.0-20190316121929-f8b7a73c3fa1/rpc/websocket.go (about) 1 2 //<developer> 3 // <name>linapex 曹一峰</name> 4 // <email>linapex@163.com</email> 5 // <wx>superexc</wx> 6 // <qqgroup>128148617</qqgroup> 7 // <url>https://jsq.ink</url> 8 // <role>pku engineer</role> 9 // <date>2019-03-16 19:16:42</date> 10 //</624450110058663936> 11 12 13 package rpc 14 15 import ( 16 "bytes" 17 "context" 18 "crypto/tls" 19 "encoding/base64" 20 "encoding/json" 21 "fmt" 22 "net" 23 "net/http" 24 "net/url" 25 "os" 26 "strings" 27 "time" 28 29 mapset "github.com/deckarep/golang-set" 30 "github.com/ethereum/go-ethereum/log" 31 "golang.org/x/net/websocket" 32 ) 33 34 //WebSocketJSoncodec是一个自定义的JSON编解码器,具有有效负载大小强制和 35 //特殊数字分析。 36 var websocketJSONCodec = websocket.Codec{ 37 //Marshal也是WebSocket库使用的常用JSON Marshaller。 38 Marshal: func(v interface{}) ([]byte, byte, error) { 39 msg, err := json.Marshal(v) 40 return msg, websocket.TextFrame, err 41 }, 42 //解组是一种特殊的解组器,用于正确转换数字。 43 Unmarshal: func(msg []byte, payloadType byte, v interface{}) error { 44 dec := json.NewDecoder(bytes.NewReader(msg)) 45 dec.UseNumber() 46 47 return dec.Decode(v) 48 }, 49 } 50 51 //WebSocketHandler返回一个为JSON-RPC到WebSocket连接提供服务的处理程序。 52 // 53 //allowedorigins应该是允许的原始URL的逗号分隔列表。 54 //要允许与任何来源的连接,请通过“*”。 55 func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 56 return websocket.Server{ 57 Handshake: wsHandshakeValidator(allowedOrigins), 58 Handler: func(conn *websocket.Conn) { 59 //创建自定义编码/解码对以强制有效负载大小和数字编码 60 conn.MaxPayloadBytes = maxRequestContentLength 61 62 encoder := func(v interface{}) error { 63 return websocketJSONCodec.Send(conn, v) 64 } 65 decoder := func(v interface{}) error { 66 return websocketJSONCodec.Receive(conn, v) 67 } 68 srv.ServeCodec(NewCodec(conn, encoder, decoder), OptionMethodInvocation|OptionSubscriptions) 69 }, 70 } 71 } 72 73 //newwsserver围绕API提供程序创建新的WebSocket RPC服务器。 74 // 75 //已弃用:使用server.websockethandler 76 func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { 77 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} 78 } 79 80 //wshandshakevalidator返回一个处理程序,该处理程序在 81 //WebSocket升级过程。当将“*”指定为允许的源时,所有 82 //接受连接。 83 func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { 84 origins := mapset.NewSet() 85 allowAllOrigins := false 86 87 for _, origin := range allowedOrigins { 88 if origin == "*" { 89 allowAllOrigins = true 90 } 91 if origin != "" { 92 origins.Add(strings.ToLower(origin)) 93 } 94 } 95 96 //如果未指定allowedorigins,则允许localhost。 97 if len(origins.ToSlice()) == 0 { 98 origins.Add("http://“本地主机” 99 if hostname, err := os.Hostname(); err == nil { 100 origins.Add("http://“+strings.tolower(主机名)) 101 } 102 } 103 104 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.ToSlice())) 105 106 f := func(cfg *websocket.Config, req *http.Request) error { 107 origin := strings.ToLower(req.Header.Get("Origin")) 108 if allowAllOrigins || origins.Contains(origin) { 109 return nil 110 } 111 log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) 112 return fmt.Errorf("origin %s not allowed", origin) 113 } 114 115 return f 116 } 117 118 func wsGetConfig(endpoint, origin string) (*websocket.Config, error) { 119 if origin == "" { 120 var err error 121 if origin, err = os.Hostname(); err != nil { 122 return nil, err 123 } 124 if strings.HasPrefix(endpoint, "wss") { 125 origin = "https://“+strings.tolower(原点) 126 } else { 127 origin = "http://“+strings.tolower(原点) 128 } 129 } 130 config, err := websocket.NewConfig(endpoint, origin) 131 if err != nil { 132 return nil, err 133 } 134 135 if config.Location.User != nil { 136 b64auth := base64.StdEncoding.EncodeToString([]byte(config.Location.User.String())) 137 config.Header.Add("Authorization", "Basic "+b64auth) 138 config.Location.User = nil 139 } 140 return config, nil 141 } 142 143 //DialWebSocket创建一个新的与JSON-RPC服务器通信的RPC客户端 144 //正在侦听给定的端点。 145 // 146 //上下文用于建立初始连接。它不 147 //影响与客户的后续交互。 148 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 149 config, err := wsGetConfig(endpoint, origin) 150 if err != nil { 151 return nil, err 152 } 153 154 return newClient(ctx, func(ctx context.Context) (net.Conn, error) { 155 return wsDialContext(ctx, config) 156 }) 157 } 158 159 func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { 160 var conn net.Conn 161 var err error 162 switch config.Location.Scheme { 163 case "ws": 164 conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) 165 case "wss": 166 dialer := contextDialer(ctx) 167 conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) 168 default: 169 err = websocket.ErrBadScheme 170 } 171 if err != nil { 172 return nil, err 173 } 174 ws, err := websocket.NewClient(config, conn) 175 if err != nil { 176 conn.Close() 177 return nil, err 178 } 179 return ws, err 180 } 181 182 var wsPortMap = map[string]string{"ws": "80", "wss": "443"} 183 184 func wsDialAddress(location *url.URL) string { 185 if _, ok := wsPortMap[location.Scheme]; ok { 186 if _, _, err := net.SplitHostPort(location.Host); err != nil { 187 return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) 188 } 189 } 190 return location.Host 191 } 192 193 func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { 194 d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} 195 return d.DialContext(ctx, network, addr) 196 } 197 198 func contextDialer(ctx context.Context) *net.Dialer { 199 dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} 200 if deadline, ok := ctx.Deadline(); ok { 201 dialer.Deadline = deadline 202 } else { 203 dialer.Deadline = time.Now().Add(defaultDialTimeout) 204 } 205 return dialer 206 } 207