github.com/dbernstein1/tyk@v2.9.0-beta9-dl-apic+incompatible/gateway/handler_websocket.go (about)

     1  package gateway
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"strings"
    12  
    13  	"github.com/sirupsen/logrus"
    14  
    15  	"github.com/TykTechnologies/tyk/headers"
    16  	"github.com/TykTechnologies/tyk/request"
    17  
    18  	"github.com/TykTechnologies/tyk/config"
    19  )
    20  
    21  func canonicalAddr(url *url.URL) string {
    22  	addr := url.Host
    23  	// If the addr has a port number attached
    24  	if !(strings.LastIndex(addr, ":") > strings.LastIndex(addr, "]")) {
    25  		return addr + ":80"
    26  	}
    27  	return addr
    28  }
    29  
    30  type WSDialer struct {
    31  	*http.Transport
    32  	RW        http.ResponseWriter
    33  	TLSConfig *tls.Config
    34  }
    35  
    36  func (ws *WSDialer) RoundTrip(req *http.Request) (*http.Response, error) {
    37  
    38  	if !config.Global().HttpServerOptions.EnableWebSockets {
    39  		return nil, errors.New("WebSockets has been disabled on this host")
    40  	}
    41  
    42  	target := canonicalAddr(req.URL)
    43  	ip := request.RealIP(req)
    44  
    45  	// TLS
    46  	dial := ws.DialContext
    47  	if dial == nil {
    48  		var d net.Dialer
    49  		dial = d.DialContext
    50  	}
    51  
    52  	// We do not get this WSS scheme, need another way to identify it
    53  	switch req.URL.Scheme {
    54  	case "wss", "https":
    55  		var tlsConfig *tls.Config
    56  		if ws.TLSClientConfig == nil {
    57  			tlsConfig = &tls.Config{}
    58  		} else {
    59  			tlsConfig = ws.TLSClientConfig
    60  		}
    61  		dial = func(_ context.Context, network, address string) (net.Conn, error) {
    62  			return tls.Dial("tcp", target, tlsConfig)
    63  		}
    64  	}
    65  
    66  	d, err := dial(context.TODO(), "tcp", target)
    67  	if err != nil {
    68  		http.Error(ws.RW, "Error contacting backend server.", http.StatusInternalServerError)
    69  		log.WithFields(logrus.Fields{
    70  			"path":   target,
    71  			"origin": ip,
    72  		}).Error("Error dialing websocket backend", target, ": ", err)
    73  		return nil, err
    74  	}
    75  	defer d.Close()
    76  
    77  	hj, ok := ws.RW.(http.Hijacker)
    78  	if !ok {
    79  		http.Error(ws.RW, "Not a hijacker?", http.StatusInternalServerError)
    80  		return nil, errors.New("Not a hjijacker?")
    81  	}
    82  
    83  	nc, _, err := hj.Hijack()
    84  	if err != nil {
    85  		log.WithFields(logrus.Fields{
    86  			"path":   req.URL.Path,
    87  			"origin": ip,
    88  		}).Errorf("Hijack error: %v", err)
    89  		return nil, err
    90  	}
    91  	defer nc.Close()
    92  
    93  	if err := req.Write(d); err != nil {
    94  		log.WithFields(logrus.Fields{
    95  			"path":   req.URL.Path,
    96  			"origin": ip,
    97  		}).Errorf("Error copying request to target: %v", err)
    98  		return nil, err
    99  	}
   100  
   101  	errc := make(chan error, 2)
   102  	cp := func(dst io.Writer, src io.Reader) {
   103  		_, err := io.Copy(dst, src)
   104  		errc <- err
   105  	}
   106  	go cp(d, nc)
   107  	go cp(nc, d)
   108  
   109  	for i := 0; i < 2; i++ {
   110  		cerr := <-errc
   111  		if cerr == nil {
   112  			continue
   113  		}
   114  		err = cerr
   115  		log.WithFields(logrus.Fields{
   116  			"path":   req.URL.Path,
   117  			"origin": ip,
   118  		}).Errorf("Error transmitting request: %v", err)
   119  	}
   120  
   121  	return nil, err
   122  }
   123  
   124  func IsWebsocket(req *http.Request) bool {
   125  	if !config.Global().HttpServerOptions.EnableWebSockets {
   126  		return false
   127  	}
   128  
   129  	contentType := strings.ToLower(strings.TrimSpace(req.Header.Get(headers.Accept)))
   130  	if contentType == "text/event-stream" {
   131  		return true
   132  	}
   133  
   134  	connection := strings.ToLower(strings.TrimSpace(req.Header.Get(headers.Connection)))
   135  	if connection != "upgrade" {
   136  		return false
   137  	}
   138  
   139  	upgrade := strings.ToLower(strings.TrimSpace(req.Header.Get("Upgrade")))
   140  	return upgrade == "websocket"
   141  }