github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/router/reverse_proxy.go (about)

     1  package router
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"net"
     7  	"net/http"
     8  	"net/http/httputil"
     9  	"net/url"
    10  	"strings"
    11  
    12  	"github.com/volts-dev/volts/selector"
    13  )
    14  
    15  // TODO 改名称
    16  func RpcReverseProxy(ctx *TRpcContext) {}
    17  
    18  // TODO 改名称
    19  func HttpReverseProxy(ctx *THttpContext) {
    20  	service, err := getService(ctx)
    21  	if err != nil {
    22  		ctx.WriteHeader(500)
    23  		return
    24  	}
    25  
    26  	if len(service) == 0 {
    27  		ctx.WriteHeader(404)
    28  		return
    29  	}
    30  
    31  	rp, err := url.Parse(service)
    32  	if err != nil {
    33  		ctx.WriteHeader(500)
    34  		return
    35  	}
    36  
    37  	if isWebSocket(ctx) {
    38  		serveWebSocket(rp.Host, ctx.Response(), ctx.Request().Request)
    39  		return
    40  	}
    41  
    42  	httputil.NewSingleHostReverseProxy(rp).ServeHTTP(ctx.Response(), ctx.Request().Request)
    43  }
    44  
    45  // getService returns the service for this request from the selector
    46  func getService(ctx *THttpContext) (string, error) {
    47  	// create a random selector
    48  	next := selector.Random(ctx.Handler().Services)
    49  
    50  	// get the next service node
    51  	s, err := next()
    52  	if err != nil {
    53  		return "", nil
    54  	}
    55  
    56  	// FIXME http/https
    57  	return fmt.Sprintf("http://%s", s.Address), nil
    58  }
    59  
    60  // serveWebSocket used to serve a web socket proxied connection
    61  func serveWebSocket(host string, w http.ResponseWriter, r *http.Request) {
    62  	req := new(http.Request)
    63  	*req = *r
    64  
    65  	if len(host) == 0 {
    66  		http.Error(w, "invalid host", 500)
    67  		return
    68  	}
    69  
    70  	// set x-forward-for
    71  	if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
    72  		if ips, ok := req.Header["X-Forwarded-For"]; ok {
    73  			clientIP = strings.Join(ips, ", ") + ", " + clientIP
    74  		}
    75  		req.Header.Set("X-Forwarded-For", clientIP)
    76  	}
    77  
    78  	// connect to the backend host
    79  	conn, err := net.Dial("tcp", host)
    80  	if err != nil {
    81  		http.Error(w, err.Error(), 500)
    82  		return
    83  	}
    84  
    85  	// hijack the connection
    86  	hj, ok := w.(http.Hijacker)
    87  	if !ok {
    88  		http.Error(w, "failed to connect", 500)
    89  		return
    90  	}
    91  
    92  	nc, _, err := hj.Hijack()
    93  	if err != nil {
    94  		return
    95  	}
    96  
    97  	defer nc.Close()
    98  	defer conn.Close()
    99  
   100  	if err = req.Write(conn); err != nil {
   101  		return
   102  	}
   103  
   104  	errCh := make(chan error, 2)
   105  
   106  	cp := func(dst io.Writer, src io.Reader) {
   107  		_, err := io.Copy(dst, src)
   108  		errCh <- err
   109  	}
   110  
   111  	go cp(conn, nc)
   112  	go cp(nc, conn)
   113  
   114  	<-errCh
   115  }
   116  
   117  func isWebSocket(ctx *THttpContext) bool {
   118  	contains := func(key, val string) bool {
   119  		vv := strings.Split(ctx.Request().Header().Get(key), ",")
   120  		for _, v := range vv {
   121  			if val == strings.ToLower(strings.TrimSpace(v)) {
   122  				return true
   123  			}
   124  		}
   125  		return false
   126  	}
   127  
   128  	if contains("Connection", "upgrade") && contains("Upgrade", "websocket") {
   129  		return true
   130  	}
   131  
   132  	return false
   133  }