github.com/gofiber/fiber/v2@v2.47.0/middleware/proxy/proxy.go (about)

     1  package proxy
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/tls"
     6  	"log"
     7  	"net/url"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/gofiber/fiber/v2"
    13  	"github.com/gofiber/fiber/v2/utils"
    14  
    15  	"github.com/valyala/fasthttp"
    16  )
    17  
    18  // New is deprecated
    19  func New(config Config) fiber.Handler {
    20  	log.Printf("[Warning] - [PROXY] proxy.New is deprecated, please use proxy.Balancer instead\n")
    21  	return Balancer(config)
    22  }
    23  
    24  // Balancer creates a load balancer among multiple upstream servers
    25  func Balancer(config Config) fiber.Handler {
    26  	// Set default config
    27  	cfg := configDefault(config)
    28  
    29  	// Load balanced client
    30  	lbc := &fasthttp.LBClient{}
    31  	// Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig
    32  	// will not be used if the client are set.
    33  	if config.Client == nil {
    34  		// Set timeout
    35  		lbc.Timeout = cfg.Timeout
    36  		// Scheme must be provided, falls back to http
    37  		for _, server := range cfg.Servers {
    38  			if !strings.HasPrefix(server, "http") {
    39  				server = "http://" + server
    40  			}
    41  
    42  			u, err := url.Parse(server)
    43  			if err != nil {
    44  				panic(err)
    45  			}
    46  
    47  			client := &fasthttp.HostClient{
    48  				NoDefaultUserAgentHeader: true,
    49  				DisablePathNormalizing:   true,
    50  				Addr:                     u.Host,
    51  
    52  				ReadBufferSize:  config.ReadBufferSize,
    53  				WriteBufferSize: config.WriteBufferSize,
    54  
    55  				TLSConfig: config.TlsConfig,
    56  			}
    57  
    58  			lbc.Clients = append(lbc.Clients, client)
    59  		}
    60  	} else {
    61  		// Set custom client
    62  		lbc = config.Client
    63  	}
    64  
    65  	// Return new handler
    66  	return func(c *fiber.Ctx) error {
    67  		// Don't execute middleware if Next returns true
    68  		if cfg.Next != nil && cfg.Next(c) {
    69  			return c.Next()
    70  		}
    71  
    72  		// Set request and response
    73  		req := c.Request()
    74  		res := c.Response()
    75  
    76  		// Don't proxy "Connection" header
    77  		req.Header.Del(fiber.HeaderConnection)
    78  
    79  		// Modify request
    80  		if cfg.ModifyRequest != nil {
    81  			if err := cfg.ModifyRequest(c); err != nil {
    82  				return err
    83  			}
    84  		}
    85  
    86  		req.SetRequestURI(utils.UnsafeString(req.RequestURI()))
    87  
    88  		// Forward request
    89  		if err := lbc.Do(req, res); err != nil {
    90  			return err
    91  		}
    92  
    93  		// Don't proxy "Connection" header
    94  		res.Header.Del(fiber.HeaderConnection)
    95  
    96  		// Modify response
    97  		if cfg.ModifyResponse != nil {
    98  			if err := cfg.ModifyResponse(c); err != nil {
    99  				return err
   100  			}
   101  		}
   102  
   103  		// Return nil to end proxying if no error
   104  		return nil
   105  	}
   106  }
   107  
   108  var client = &fasthttp.Client{
   109  	NoDefaultUserAgentHeader: true,
   110  	DisablePathNormalizing:   true,
   111  }
   112  
   113  var lock sync.RWMutex
   114  
   115  // WithTlsConfig update http client with a user specified tls.config
   116  // This function should be called before Do and Forward.
   117  // Deprecated: use WithClient instead.
   118  //
   119  //nolint:stylecheck,revive // TODO: Rename to "WithTLSConfig" in v3
   120  func WithTlsConfig(tlsConfig *tls.Config) {
   121  	client.TLSConfig = tlsConfig
   122  }
   123  
   124  // WithClient sets the global proxy client.
   125  // This function should be called before Do and Forward.
   126  func WithClient(cli *fasthttp.Client) {
   127  	lock.Lock()
   128  	defer lock.Unlock()
   129  	client = cli
   130  }
   131  
   132  // Forward performs the given http request and fills the given http response.
   133  // This method will return an fiber.Handler
   134  func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler {
   135  	return func(c *fiber.Ctx) error {
   136  		return Do(c, addr, clients...)
   137  	}
   138  }
   139  
   140  // Do performs the given http request and fills the given http response.
   141  // This method can be used within a fiber.Handler
   142  func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error {
   143  	return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
   144  		return cli.Do(req, resp)
   145  	}, clients...)
   146  }
   147  
   148  // DoRedirects performs the given http request and fills the given http response, following up to maxRedirectsCount redirects.
   149  // When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is returned.
   150  // This method can be used within a fiber.Handler
   151  func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error {
   152  	return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
   153  		return cli.DoRedirects(req, resp, maxRedirectsCount)
   154  	}, clients...)
   155  }
   156  
   157  // DoDeadline performs the given request and waits for response until the given deadline.
   158  // This method can be used within a fiber.Handler
   159  func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error {
   160  	return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
   161  		return cli.DoDeadline(req, resp, deadline)
   162  	}, clients...)
   163  }
   164  
   165  // DoTimeout performs the given request and waits for response during the given timeout duration.
   166  // This method can be used within a fiber.Handler
   167  func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error {
   168  	return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error {
   169  		return cli.DoTimeout(req, resp, timeout)
   170  	}, clients...)
   171  }
   172  
   173  func doAction(
   174  	c *fiber.Ctx,
   175  	addr string,
   176  	action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error,
   177  	clients ...*fasthttp.Client,
   178  ) error {
   179  	var cli *fasthttp.Client
   180  
   181  	// set local or global client
   182  	if len(clients) != 0 {
   183  		cli = clients[0]
   184  	} else {
   185  		lock.RLock()
   186  		cli = client
   187  		lock.RUnlock()
   188  	}
   189  
   190  	req := c.Request()
   191  	res := c.Response()
   192  	originalURL := utils.CopyString(c.OriginalURL())
   193  	defer req.SetRequestURI(originalURL)
   194  
   195  	copiedURL := utils.CopyString(addr)
   196  	req.SetRequestURI(copiedURL)
   197  	// NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https.
   198  	// Reference: https://github.com/gofiber/fiber/issues/1762
   199  	if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 {
   200  		req.URI().SetSchemeBytes(scheme)
   201  	}
   202  
   203  	req.Header.Del(fiber.HeaderConnection)
   204  	if err := action(cli, req, res); err != nil {
   205  		return err
   206  	}
   207  	res.Header.Del(fiber.HeaderConnection)
   208  	return nil
   209  }
   210  
   211  func getScheme(uri []byte) []byte {
   212  	i := bytes.IndexByte(uri, '/')
   213  	if i < 1 || uri[i-1] != ':' || i == len(uri)-1 || uri[i+1] != '/' {
   214  		return nil
   215  	}
   216  	return uri[:i-1]
   217  }
   218  
   219  // DomainForward performs an http request based on the given domain and populates the given http response.
   220  // This method will return an fiber.Handler
   221  func DomainForward(hostname, addr string, clients ...*fasthttp.Client) fiber.Handler {
   222  	return func(c *fiber.Ctx) error {
   223  		host := string(c.Request().Host())
   224  		if host == hostname {
   225  			return Do(c, addr+c.OriginalURL(), clients...)
   226  		}
   227  		return nil
   228  	}
   229  }
   230  
   231  type roundrobin struct {
   232  	sync.Mutex
   233  
   234  	current int
   235  	pool    []string
   236  }
   237  
   238  // this method will return a string of addr server from list server.
   239  func (r *roundrobin) get() string {
   240  	r.Lock()
   241  	defer r.Unlock()
   242  
   243  	if r.current >= len(r.pool) {
   244  		r.current %= len(r.pool)
   245  	}
   246  
   247  	result := r.pool[r.current]
   248  	r.current++
   249  	return result
   250  }
   251  
   252  // BalancerForward Forward performs the given http request with round robin algorithm to server and fills the given http response.
   253  // This method will return an fiber.Handler
   254  func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler {
   255  	r := &roundrobin{
   256  		current: 0,
   257  		pool:    servers,
   258  	}
   259  	return func(c *fiber.Ctx) error {
   260  		server := r.get()
   261  		if !strings.HasPrefix(server, "http") {
   262  			server = "http://" + server
   263  		}
   264  		c.Request().Header.Add("X-Real-IP", c.IP())
   265  		return Do(c, server+c.OriginalURL(), clients...)
   266  	}
   267  }