github.com/blend/go-sdk@v1.20220411.3/cmd/reverseproxy/main.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package main
     9  
    10  import (
    11  	"crypto/tls"
    12  	"flag"
    13  	"fmt"
    14  	"net/http"
    15  	"net/url"
    16  	"os"
    17  	"strings"
    18  
    19  	"github.com/blend/go-sdk/certutil"
    20  	"github.com/blend/go-sdk/graceful"
    21  	"github.com/blend/go-sdk/logger"
    22  	"github.com/blend/go-sdk/proxyprotocol"
    23  	"github.com/blend/go-sdk/reverseproxy"
    24  	"github.com/blend/go-sdk/webutil"
    25  )
    26  
    27  func main() {
    28  	log, err := logger.New(
    29  		logger.OptConfigFromEnv(),
    30  		logger.OptEnabled(webutil.FlagHTTPRequest),
    31  		logger.OptPath("reverse-proxy"),
    32  	)
    33  	if err != nil {
    34  		logger.FatalExit(err)
    35  	}
    36  
    37  	var upstreams Upstreams
    38  	flag.Var(&upstreams, "upstream", "An upstream server to proxy traffic to")
    39  
    40  	var tlsCert string
    41  	flag.StringVar(&tlsCert, "tls-cert", "", "The path to the tls certificate file (--tls-key must also be set)")
    42  
    43  	var tlsKey string
    44  	flag.StringVar(&tlsKey, "tls-key", "", "The path to the tls key file (--tls-cert must also be set)")
    45  
    46  	var addr string
    47  	flag.StringVar(&addr, "addr", reverseproxy.DefaultAddr, "The address to listen on.")
    48  
    49  	var upgradeAddr string
    50  	flag.StringVar(&upgradeAddr, "upgrade-addr", "", "The upgrade address to listen on.")
    51  
    52  	var useProxyProtocol bool
    53  	flag.BoolVar(&useProxyProtocol, "proxyProtocol", false, "If we should decode proxy protocol.")
    54  
    55  	var upstreamHeaders UpstreamHeader
    56  	flag.Var(&upstreamHeaders, "upstream-header", "Upstream heaeders to add for all requests.")
    57  
    58  	flag.Parse()
    59  
    60  	if len(upstreams) == 0 {
    61  		flag.Usage()
    62  		os.Exit(1)
    63  	}
    64  
    65  	var listenerOptions []proxyprotocol.CreateListenerOption
    66  
    67  	proxy, _ := reverseproxy.NewProxy()
    68  	proxy.Log = log
    69  
    70  	var servers []graceful.Graceful
    71  	for _, upstream := range upstreams {
    72  		log.Infof("upstream: %s", upstream)
    73  		target, err := url.Parse(upstream)
    74  		if err != nil {
    75  			log.Fatal(err)
    76  			os.Exit(1)
    77  		}
    78  
    79  		proxyUpstream := reverseproxy.NewUpstream(target)
    80  		proxyUpstream.Log = log
    81  		if err = proxyUpstream.UseHTTP2(); err != nil {
    82  			log.Fatal(err)
    83  			os.Exit(1)
    84  		}
    85  		proxy.Upstreams = append(proxy.Upstreams, proxyUpstream)
    86  	}
    87  
    88  	for _, header := range upstreamHeaders {
    89  		pieces := strings.SplitN(header, "=", 2)
    90  		if len(pieces) < 2 {
    91  			log.Fatal(fmt.Errorf("invalid header; must be in the form key=value"))
    92  			os.Exit(1)
    93  		}
    94  		log.Infof("proxy using upstream header: %s=%s", pieces[0], pieces[1])
    95  		proxy.Headers.Add(pieces[0], pieces[1])
    96  	}
    97  
    98  	if len(tlsCert) > 0 && len(tlsKey) == 0 {
    99  		log.Fatal(fmt.Errorf("`--tls-key` is unset, cannot continue"))
   100  		os.Exit(1)
   101  	}
   102  	if len(tlsCert) == 0 && len(tlsKey) > 0 {
   103  		log.Fatal(fmt.Errorf("`--tls-key` is unset, cannot continue"))
   104  		os.Exit(1)
   105  	}
   106  	if len(tlsCert) > 0 && len(tlsKey) > 0 {
   107  		certFileWatcher, err := certutil.NewCertFileWatcher(certutil.KeyPair{CertPath: tlsCert, KeyPath: tlsKey})
   108  		if err != nil {
   109  			log.Fatal(err)
   110  			os.Exit(1)
   111  		}
   112  
   113  		log.Infof("watching tls cert/key files for changes")
   114  		servers = append(servers, certFileWatcher)
   115  
   116  		rootCAs, err := certutil.ExtendSystemCertPool()
   117  		if err != nil {
   118  			log.Fatal(err)
   119  			os.Exit(1)
   120  		}
   121  
   122  		proxyServerTLSConfig := &tls.Config{
   123  			RootCAs:        rootCAs,
   124  			GetCertificate: certFileWatcher.GetCertificate,
   125  		}
   126  		webutil.TLSSecureCipherSuites(proxyServerTLSConfig)
   127  		listenerOptions = append(listenerOptions, proxyprotocol.OptTLSConfig(proxyServerTLSConfig))
   128  	}
   129  
   130  	proxyServerListener, err := proxyprotocol.CreateListener("tcp", addr, listenerOptions...)
   131  	if err != nil {
   132  		log.Fatal(err)
   133  		os.Exit(1)
   134  	}
   135  
   136  	proxyServer := &http.Server{
   137  		Handler: webutil.NestMiddleware(proxy.ServeHTTP, webutil.HTTPLogged(log)),
   138  	}
   139  	servers = append(servers,
   140  		webutil.NewGracefulHTTPServer(proxyServer, webutil.OptGracefulHTTPServerListener(proxyServerListener)),
   141  	)
   142  
   143  	if upgradeAddr != "" {
   144  		log.Infof("http upgrader listening on: %s", upgradeAddr)
   145  		upgrader := reverseproxy.HTTPRedirect{}
   146  		servers = append(servers, webutil.NewGracefulHTTPServer(&http.Server{
   147  			Addr:    upgradeAddr,
   148  			Handler: webutil.NestMiddleware(upgrader.ServeHTTP, webutil.HTTPLogged(log)),
   149  		}))
   150  	}
   151  
   152  	log.Infof("reverse proxy listening on: %s", addr)
   153  	if err := graceful.Shutdown(
   154  		servers...,
   155  	); err != nil {
   156  		log.Fatal(err)
   157  		os.Exit(1)
   158  	}
   159  }
   160  
   161  // Upstreams is a flag variable for upstreams.
   162  type Upstreams []string
   163  
   164  // String returns a string representation of the upstreams.
   165  func (u *Upstreams) String() string {
   166  	if u == nil {
   167  		return "<nil>"
   168  	}
   169  	return strings.Join(*u, ", ")
   170  }
   171  
   172  // Set adds a flag value.
   173  func (u *Upstreams) Set(value string) error {
   174  	*u = append(*u, value)
   175  	return nil
   176  }
   177  
   178  // UpstreamHeader is a flag variable for upstreams.
   179  type UpstreamHeader []string
   180  
   181  // String returns a string representation of the upstreams.
   182  func (u *UpstreamHeader) String() string {
   183  	if u == nil {
   184  		return "<nil>"
   185  	}
   186  	return strings.Join(*u, ", ")
   187  }
   188  
   189  // Set adds a flag value.
   190  func (u *UpstreamHeader) Set(value string) error {
   191  	*u = append(*u, value)
   192  	return nil
   193  }