github.com/blend/go-sdk@v1.20220411.3/cmd/reverseproxy/main.go (about) 1 /* 2 3 Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file. 5 6 */ 7 8 package main 9 10 import ( 11 "crypto/tls" 12 "flag" 13 "fmt" 14 "net/http" 15 "net/url" 16 "os" 17 "strings" 18 19 "github.com/blend/go-sdk/certutil" 20 "github.com/blend/go-sdk/graceful" 21 "github.com/blend/go-sdk/logger" 22 "github.com/blend/go-sdk/proxyprotocol" 23 "github.com/blend/go-sdk/reverseproxy" 24 "github.com/blend/go-sdk/webutil" 25 ) 26 27 func main() { 28 log, err := logger.New( 29 logger.OptConfigFromEnv(), 30 logger.OptEnabled(webutil.FlagHTTPRequest), 31 logger.OptPath("reverse-proxy"), 32 ) 33 if err != nil { 34 logger.FatalExit(err) 35 } 36 37 var upstreams Upstreams 38 flag.Var(&upstreams, "upstream", "An upstream server to proxy traffic to") 39 40 var tlsCert string 41 flag.StringVar(&tlsCert, "tls-cert", "", "The path to the tls certificate file (--tls-key must also be set)") 42 43 var tlsKey string 44 flag.StringVar(&tlsKey, "tls-key", "", "The path to the tls key file (--tls-cert must also be set)") 45 46 var addr string 47 flag.StringVar(&addr, "addr", reverseproxy.DefaultAddr, "The address to listen on.") 48 49 var upgradeAddr string 50 flag.StringVar(&upgradeAddr, "upgrade-addr", "", "The upgrade address to listen on.") 51 52 var useProxyProtocol bool 53 flag.BoolVar(&useProxyProtocol, "proxyProtocol", false, "If we should decode proxy protocol.") 54 55 var upstreamHeaders UpstreamHeader 56 flag.Var(&upstreamHeaders, "upstream-header", "Upstream heaeders to add for all requests.") 57 58 flag.Parse() 59 60 if len(upstreams) == 0 { 61 flag.Usage() 62 os.Exit(1) 63 } 64 65 var listenerOptions []proxyprotocol.CreateListenerOption 66 67 proxy, _ := reverseproxy.NewProxy() 68 proxy.Log = log 69 70 var servers []graceful.Graceful 71 for _, upstream := range upstreams { 72 log.Infof("upstream: %s", upstream) 73 target, err := url.Parse(upstream) 74 if err != nil { 75 log.Fatal(err) 76 os.Exit(1) 77 } 78 79 proxyUpstream := reverseproxy.NewUpstream(target) 80 proxyUpstream.Log = log 81 if err = proxyUpstream.UseHTTP2(); err != nil { 82 log.Fatal(err) 83 os.Exit(1) 84 } 85 proxy.Upstreams = append(proxy.Upstreams, proxyUpstream) 86 } 87 88 for _, header := range upstreamHeaders { 89 pieces := strings.SplitN(header, "=", 2) 90 if len(pieces) < 2 { 91 log.Fatal(fmt.Errorf("invalid header; must be in the form key=value")) 92 os.Exit(1) 93 } 94 log.Infof("proxy using upstream header: %s=%s", pieces[0], pieces[1]) 95 proxy.Headers.Add(pieces[0], pieces[1]) 96 } 97 98 if len(tlsCert) > 0 && len(tlsKey) == 0 { 99 log.Fatal(fmt.Errorf("`--tls-key` is unset, cannot continue")) 100 os.Exit(1) 101 } 102 if len(tlsCert) == 0 && len(tlsKey) > 0 { 103 log.Fatal(fmt.Errorf("`--tls-key` is unset, cannot continue")) 104 os.Exit(1) 105 } 106 if len(tlsCert) > 0 && len(tlsKey) > 0 { 107 certFileWatcher, err := certutil.NewCertFileWatcher(certutil.KeyPair{CertPath: tlsCert, KeyPath: tlsKey}) 108 if err != nil { 109 log.Fatal(err) 110 os.Exit(1) 111 } 112 113 log.Infof("watching tls cert/key files for changes") 114 servers = append(servers, certFileWatcher) 115 116 rootCAs, err := certutil.ExtendSystemCertPool() 117 if err != nil { 118 log.Fatal(err) 119 os.Exit(1) 120 } 121 122 proxyServerTLSConfig := &tls.Config{ 123 RootCAs: rootCAs, 124 GetCertificate: certFileWatcher.GetCertificate, 125 } 126 webutil.TLSSecureCipherSuites(proxyServerTLSConfig) 127 listenerOptions = append(listenerOptions, proxyprotocol.OptTLSConfig(proxyServerTLSConfig)) 128 } 129 130 proxyServerListener, err := proxyprotocol.CreateListener("tcp", addr, listenerOptions...) 131 if err != nil { 132 log.Fatal(err) 133 os.Exit(1) 134 } 135 136 proxyServer := &http.Server{ 137 Handler: webutil.NestMiddleware(proxy.ServeHTTP, webutil.HTTPLogged(log)), 138 } 139 servers = append(servers, 140 webutil.NewGracefulHTTPServer(proxyServer, webutil.OptGracefulHTTPServerListener(proxyServerListener)), 141 ) 142 143 if upgradeAddr != "" { 144 log.Infof("http upgrader listening on: %s", upgradeAddr) 145 upgrader := reverseproxy.HTTPRedirect{} 146 servers = append(servers, webutil.NewGracefulHTTPServer(&http.Server{ 147 Addr: upgradeAddr, 148 Handler: webutil.NestMiddleware(upgrader.ServeHTTP, webutil.HTTPLogged(log)), 149 })) 150 } 151 152 log.Infof("reverse proxy listening on: %s", addr) 153 if err := graceful.Shutdown( 154 servers..., 155 ); err != nil { 156 log.Fatal(err) 157 os.Exit(1) 158 } 159 } 160 161 // Upstreams is a flag variable for upstreams. 162 type Upstreams []string 163 164 // String returns a string representation of the upstreams. 165 func (u *Upstreams) String() string { 166 if u == nil { 167 return "<nil>" 168 } 169 return strings.Join(*u, ", ") 170 } 171 172 // Set adds a flag value. 173 func (u *Upstreams) Set(value string) error { 174 *u = append(*u, value) 175 return nil 176 } 177 178 // UpstreamHeader is a flag variable for upstreams. 179 type UpstreamHeader []string 180 181 // String returns a string representation of the upstreams. 182 func (u *UpstreamHeader) String() string { 183 if u == nil { 184 return "<nil>" 185 } 186 return strings.Join(*u, ", ") 187 } 188 189 // Set adds a flag value. 190 func (u *UpstreamHeader) Set(value string) error { 191 *u = append(*u, value) 192 return nil 193 }