github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/router/reverse_proxy.go (about) 1 package router 2 3 import ( 4 "fmt" 5 "io" 6 "net" 7 "net/http" 8 "net/http/httputil" 9 "net/url" 10 "strings" 11 12 "github.com/volts-dev/volts/selector" 13 ) 14 15 // TODO 改名称 16 func RpcReverseProxy(ctx *TRpcContext) {} 17 18 // TODO 改名称 19 func HttpReverseProxy(ctx *THttpContext) { 20 service, err := getService(ctx) 21 if err != nil { 22 ctx.WriteHeader(500) 23 return 24 } 25 26 if len(service) == 0 { 27 ctx.WriteHeader(404) 28 return 29 } 30 31 rp, err := url.Parse(service) 32 if err != nil { 33 ctx.WriteHeader(500) 34 return 35 } 36 37 if isWebSocket(ctx) { 38 serveWebSocket(rp.Host, ctx.Response(), ctx.Request().Request) 39 return 40 } 41 42 httputil.NewSingleHostReverseProxy(rp).ServeHTTP(ctx.Response(), ctx.Request().Request) 43 } 44 45 // getService returns the service for this request from the selector 46 func getService(ctx *THttpContext) (string, error) { 47 // create a random selector 48 next := selector.Random(ctx.Handler().Services) 49 50 // get the next service node 51 s, err := next() 52 if err != nil { 53 return "", nil 54 } 55 56 // FIXME http/https 57 return fmt.Sprintf("http://%s", s.Address), nil 58 } 59 60 // serveWebSocket used to serve a web socket proxied connection 61 func serveWebSocket(host string, w http.ResponseWriter, r *http.Request) { 62 req := new(http.Request) 63 *req = *r 64 65 if len(host) == 0 { 66 http.Error(w, "invalid host", 500) 67 return 68 } 69 70 // set x-forward-for 71 if clientIP, _, err := net.SplitHostPort(r.RemoteAddr); err == nil { 72 if ips, ok := req.Header["X-Forwarded-For"]; ok { 73 clientIP = strings.Join(ips, ", ") + ", " + clientIP 74 } 75 req.Header.Set("X-Forwarded-For", clientIP) 76 } 77 78 // connect to the backend host 79 conn, err := net.Dial("tcp", host) 80 if err != nil { 81 http.Error(w, err.Error(), 500) 82 return 83 } 84 85 // hijack the connection 86 hj, ok := w.(http.Hijacker) 87 if !ok { 88 http.Error(w, "failed to connect", 500) 89 return 90 } 91 92 nc, _, err := hj.Hijack() 93 if err != nil { 94 return 95 } 96 97 defer nc.Close() 98 defer conn.Close() 99 100 if err = req.Write(conn); err != nil { 101 return 102 } 103 104 errCh := make(chan error, 2) 105 106 cp := func(dst io.Writer, src io.Reader) { 107 _, err := io.Copy(dst, src) 108 errCh <- err 109 } 110 111 go cp(conn, nc) 112 go cp(nc, conn) 113 114 <-errCh 115 } 116 117 func isWebSocket(ctx *THttpContext) bool { 118 contains := func(key, val string) bool { 119 vv := strings.Split(ctx.Request().Header().Get(key), ",") 120 for _, v := range vv { 121 if val == strings.ToLower(strings.TrimSpace(v)) { 122 return true 123 } 124 } 125 return false 126 } 127 128 if contains("Connection", "upgrade") && contains("Upgrade", "websocket") { 129 return true 130 } 131 132 return false 133 }