github.com/sberex/go-sberex@v1.8.2-0.20181113200658-ed96ac38f7d7/rpc/websocket.go (about) 1 // This file is part of the go-sberex library. The go-sberex library is 2 // free software: you can redistribute it and/or modify it under the terms 3 // of the GNU Lesser General Public License as published by the Free 4 // Software Foundation, either version 3 of the License, or (at your option) 5 // any later version. 6 // 7 // The go-sberex library is distributed in the hope that it will be useful, 8 // but WITHOUT ANY WARRANTY; without even the implied warranty of 9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser 10 // General Public License <http://www.gnu.org/licenses/> for more details. 11 12 package rpc 13 14 import ( 15 "context" 16 "crypto/tls" 17 "fmt" 18 "net" 19 "net/http" 20 "net/url" 21 "os" 22 "strings" 23 "time" 24 25 "github.com/Sberex/go-sberex/log" 26 "golang.org/x/net/websocket" 27 "gopkg.in/fatih/set.v0" 28 ) 29 30 // WebsocketHandler returns a handler that serves JSON-RPC to WebSocket connections. 31 // 32 // allowedOrigins should be a comma-separated list of allowed origin URLs. 33 // To allow connections with any origin, pass "*". 34 func (srv *Server) WebsocketHandler(allowedOrigins []string) http.Handler { 35 return websocket.Server{ 36 Handshake: wsHandshakeValidator(allowedOrigins), 37 Handler: func(conn *websocket.Conn) { 38 srv.ServeCodec(NewJSONCodec(conn), OptionMethodInvocation|OptionSubscriptions) 39 }, 40 } 41 } 42 43 // NewWSServer creates a new websocket RPC server around an API provider. 44 // 45 // Deprecated: use Server.WebsocketHandler 46 func NewWSServer(allowedOrigins []string, srv *Server) *http.Server { 47 return &http.Server{Handler: srv.WebsocketHandler(allowedOrigins)} 48 } 49 50 // wsHandshakeValidator returns a handler that verifies the origin during the 51 // websocket upgrade process. When a '*' is specified as an allowed origins all 52 // connections are accepted. 53 func wsHandshakeValidator(allowedOrigins []string) func(*websocket.Config, *http.Request) error { 54 origins := set.New() 55 allowAllOrigins := false 56 57 for _, origin := range allowedOrigins { 58 if origin == "*" { 59 allowAllOrigins = true 60 } 61 if origin != "" { 62 origins.Add(strings.ToLower(origin)) 63 } 64 } 65 66 // allow localhost if no allowedOrigins are specified. 67 if len(origins.List()) == 0 { 68 origins.Add("http://localhost") 69 if hostname, err := os.Hostname(); err == nil { 70 origins.Add("http://" + strings.ToLower(hostname)) 71 } 72 } 73 74 log.Debug(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v\n", origins.List())) 75 76 f := func(cfg *websocket.Config, req *http.Request) error { 77 origin := strings.ToLower(req.Header.Get("Origin")) 78 if allowAllOrigins || origins.Has(origin) { 79 return nil 80 } 81 log.Warn(fmt.Sprintf("origin '%s' not allowed on WS-RPC interface\n", origin)) 82 return fmt.Errorf("origin %s not allowed", origin) 83 } 84 85 return f 86 } 87 88 // DialWebsocket creates a new RPC client that communicates with a JSON-RPC server 89 // that is listening on the given endpoint. 90 // 91 // The context is used for the initial connection establishment. It does not 92 // affect subsequent interactions with the client. 93 func DialWebsocket(ctx context.Context, endpoint, origin string) (*Client, error) { 94 if origin == "" { 95 var err error 96 if origin, err = os.Hostname(); err != nil { 97 return nil, err 98 } 99 if strings.HasPrefix(endpoint, "wss") { 100 origin = "https://" + strings.ToLower(origin) 101 } else { 102 origin = "http://" + strings.ToLower(origin) 103 } 104 } 105 config, err := websocket.NewConfig(endpoint, origin) 106 if err != nil { 107 return nil, err 108 } 109 110 return newClient(ctx, func(ctx context.Context) (net.Conn, error) { 111 return wsDialContext(ctx, config) 112 }) 113 } 114 115 func wsDialContext(ctx context.Context, config *websocket.Config) (*websocket.Conn, error) { 116 var conn net.Conn 117 var err error 118 switch config.Location.Scheme { 119 case "ws": 120 conn, err = dialContext(ctx, "tcp", wsDialAddress(config.Location)) 121 case "wss": 122 dialer := contextDialer(ctx) 123 conn, err = tls.DialWithDialer(dialer, "tcp", wsDialAddress(config.Location), config.TlsConfig) 124 default: 125 err = websocket.ErrBadScheme 126 } 127 if err != nil { 128 return nil, err 129 } 130 ws, err := websocket.NewClient(config, conn) 131 if err != nil { 132 conn.Close() 133 return nil, err 134 } 135 return ws, err 136 } 137 138 var wsPortMap = map[string]string{"ws": "80", "wss": "443"} 139 140 func wsDialAddress(location *url.URL) string { 141 if _, ok := wsPortMap[location.Scheme]; ok { 142 if _, _, err := net.SplitHostPort(location.Host); err != nil { 143 return net.JoinHostPort(location.Host, wsPortMap[location.Scheme]) 144 } 145 } 146 return location.Host 147 } 148 149 func dialContext(ctx context.Context, network, addr string) (net.Conn, error) { 150 d := &net.Dialer{KeepAlive: tcpKeepAliveInterval} 151 return d.DialContext(ctx, network, addr) 152 } 153 154 func contextDialer(ctx context.Context) *net.Dialer { 155 dialer := &net.Dialer{Cancel: ctx.Done(), KeepAlive: tcpKeepAliveInterval} 156 if deadline, ok := ctx.Deadline(); ok { 157 dialer.Deadline = deadline 158 } else { 159 dialer.Deadline = time.Now().Add(defaultDialTimeout) 160 } 161 return dialer 162 }