github.com/gambol99/goproxy@v0.0.0-20170612135454-e713e5909438/websocket.go (about)

     1  package goproxy
     2  
     3  import (
     4  	"bufio"
     5  	"crypto/tls"
     6  	"io"
     7  	"net/http"
     8  	"net/url"
     9  	"strings"
    10  )
    11  
    12  func headerContains(header http.Header, name string, value string) bool {
    13  	for _, v := range header[name] {
    14  		for _, s := range strings.Split(v, ",") {
    15  			if strings.EqualFold(value, strings.TrimSpace(s)) {
    16  				return true
    17  			}
    18  		}
    19  	}
    20  	return false
    21  }
    22  
    23  func isWebSocketRequest(r *http.Request) bool {
    24  	return headerContains(r.Header, "Connection", "upgrade") &&
    25  		headerContains(r.Header, "Upgrade", "websocket")
    26  }
    27  
    28  func (proxy *ProxyHttpServer) serveWebsocketTLS(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request, tlsConfig *tls.Config, clientConn *tls.Conn) {
    29  	targetURL := url.URL{Scheme: "wss", Host: req.URL.Host, Path: req.URL.Path}
    30  
    31  	// Connect to upstream
    32  	targetConn, err := tls.Dial("tcp", targetURL.Host, tlsConfig)
    33  	if err != nil {
    34  		ctx.Warnf("Error dialing target site: %v", err)
    35  		return
    36  	}
    37  	defer targetConn.Close()
    38  
    39  	// Perform handshake
    40  	if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil {
    41  		ctx.Warnf("Websocket handshake error: %v", err)
    42  		return
    43  	}
    44  
    45  	// Proxy wss connection
    46  	proxy.proxyWebsocket(ctx, targetConn, clientConn)
    47  }
    48  
    49  func (proxy *ProxyHttpServer) serveWebsocket(ctx *ProxyCtx, w http.ResponseWriter, req *http.Request) {
    50  	targetURL := url.URL{Scheme: "ws", Host: req.URL.Host, Path: req.URL.Path}
    51  
    52  	targetConn, err := proxy.connectDial("tcp", targetURL.Host)
    53  	if err != nil {
    54  		ctx.Warnf("Error dialing target site: %v", err)
    55  		return
    56  	}
    57  	defer targetConn.Close()
    58  
    59  	// Connect to Client
    60  	hj, ok := w.(http.Hijacker)
    61  	if !ok {
    62  		panic("httpserver does not support hijacking")
    63  	}
    64  	clientConn, _, err := hj.Hijack()
    65  	if err != nil {
    66  		ctx.Warnf("Hijack error: %v", err)
    67  		return
    68  	}
    69  
    70  	// Perform handshake
    71  	if err := proxy.websocketHandshake(ctx, req, targetConn, clientConn); err != nil {
    72  		ctx.Warnf("Websocket handshake error: %v", err)
    73  		return
    74  	}
    75  
    76  	// Proxy ws connection
    77  	proxy.proxyWebsocket(ctx, targetConn, clientConn)
    78  }
    79  
    80  func (proxy *ProxyHttpServer) websocketHandshake(ctx *ProxyCtx, req *http.Request, targetSiteConn io.ReadWriter, clientConn io.ReadWriter) error {
    81  	// write handshake request to target
    82  	err := req.Write(targetSiteConn)
    83  	if err != nil {
    84  		ctx.Warnf("Error writing upgrade request: %v", err)
    85  		return err
    86  	}
    87  
    88  	targetTLSReader := bufio.NewReader(targetSiteConn)
    89  
    90  	// Read handshake response from target
    91  	resp, err := http.ReadResponse(targetTLSReader, req)
    92  	if err != nil {
    93  		ctx.Warnf("Error reading handhsake response  %v", err)
    94  		return err
    95  	}
    96  
    97  	// Run response through handlers
    98  	resp = proxy.filterResponse(resp, ctx)
    99  
   100  	// Proxy handshake back to client
   101  	err = resp.Write(clientConn)
   102  	if err != nil {
   103  		ctx.Warnf("Error writing handshake response: %v", err)
   104  		return err
   105  	}
   106  	return nil
   107  }
   108  
   109  func (proxy *ProxyHttpServer) proxyWebsocket(ctx *ProxyCtx, dest io.ReadWriter, source io.ReadWriter) {
   110  	errChan := make(chan error, 2)
   111  	cp := func(dst io.Writer, src io.Reader) {
   112  		_, err := io.Copy(dst, src)
   113  		ctx.Warnf("Websocket error: %v", err)
   114  		errChan <- err
   115  	}
   116  
   117  	// Start proxying websocket data
   118  	go cp(dest, source)
   119  	go cp(source, dest)
   120  	<-errChan
   121  }