github.com/dbernstein1/tyk@v2.9.0-beta9-dl-apic+incompatible/gateway/handler_websocket.go (about) 1 package gateway 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "io" 8 "net" 9 "net/http" 10 "net/url" 11 "strings" 12 13 "github.com/sirupsen/logrus" 14 15 "github.com/TykTechnologies/tyk/headers" 16 "github.com/TykTechnologies/tyk/request" 17 18 "github.com/TykTechnologies/tyk/config" 19 ) 20 21 func canonicalAddr(url *url.URL) string { 22 addr := url.Host 23 // If the addr has a port number attached 24 if !(strings.LastIndex(addr, ":") > strings.LastIndex(addr, "]")) { 25 return addr + ":80" 26 } 27 return addr 28 } 29 30 type WSDialer struct { 31 *http.Transport 32 RW http.ResponseWriter 33 TLSConfig *tls.Config 34 } 35 36 func (ws *WSDialer) RoundTrip(req *http.Request) (*http.Response, error) { 37 38 if !config.Global().HttpServerOptions.EnableWebSockets { 39 return nil, errors.New("WebSockets has been disabled on this host") 40 } 41 42 target := canonicalAddr(req.URL) 43 ip := request.RealIP(req) 44 45 // TLS 46 dial := ws.DialContext 47 if dial == nil { 48 var d net.Dialer 49 dial = d.DialContext 50 } 51 52 // We do not get this WSS scheme, need another way to identify it 53 switch req.URL.Scheme { 54 case "wss", "https": 55 var tlsConfig *tls.Config 56 if ws.TLSClientConfig == nil { 57 tlsConfig = &tls.Config{} 58 } else { 59 tlsConfig = ws.TLSClientConfig 60 } 61 dial = func(_ context.Context, network, address string) (net.Conn, error) { 62 return tls.Dial("tcp", target, tlsConfig) 63 } 64 } 65 66 d, err := dial(context.TODO(), "tcp", target) 67 if err != nil { 68 http.Error(ws.RW, "Error contacting backend server.", http.StatusInternalServerError) 69 log.WithFields(logrus.Fields{ 70 "path": target, 71 "origin": ip, 72 }).Error("Error dialing websocket backend", target, ": ", err) 73 return nil, err 74 } 75 defer d.Close() 76 77 hj, ok := ws.RW.(http.Hijacker) 78 if !ok { 79 http.Error(ws.RW, "Not a hijacker?", http.StatusInternalServerError) 80 return nil, errors.New("Not a hjijacker?") 81 } 82 83 nc, _, err := hj.Hijack() 84 if err != nil { 85 log.WithFields(logrus.Fields{ 86 "path": req.URL.Path, 87 "origin": ip, 88 }).Errorf("Hijack error: %v", err) 89 return nil, err 90 } 91 defer nc.Close() 92 93 if err := req.Write(d); err != nil { 94 log.WithFields(logrus.Fields{ 95 "path": req.URL.Path, 96 "origin": ip, 97 }).Errorf("Error copying request to target: %v", err) 98 return nil, err 99 } 100 101 errc := make(chan error, 2) 102 cp := func(dst io.Writer, src io.Reader) { 103 _, err := io.Copy(dst, src) 104 errc <- err 105 } 106 go cp(d, nc) 107 go cp(nc, d) 108 109 for i := 0; i < 2; i++ { 110 cerr := <-errc 111 if cerr == nil { 112 continue 113 } 114 err = cerr 115 log.WithFields(logrus.Fields{ 116 "path": req.URL.Path, 117 "origin": ip, 118 }).Errorf("Error transmitting request: %v", err) 119 } 120 121 return nil, err 122 } 123 124 func IsWebsocket(req *http.Request) bool { 125 if !config.Global().HttpServerOptions.EnableWebSockets { 126 return false 127 } 128 129 contentType := strings.ToLower(strings.TrimSpace(req.Header.Get(headers.Accept))) 130 if contentType == "text/event-stream" { 131 return true 132 } 133 134 connection := strings.ToLower(strings.TrimSpace(req.Header.Get(headers.Connection))) 135 if connection != "upgrade" { 136 return false 137 } 138 139 upgrade := strings.ToLower(strings.TrimSpace(req.Header.Get("Upgrade"))) 140 return upgrade == "websocket" 141 }