github.com/gofiber/fiber/v2@v2.47.0/middleware/proxy/proxy.go (about) 1 package proxy 2 3 import ( 4 "bytes" 5 "crypto/tls" 6 "log" 7 "net/url" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/gofiber/fiber/v2" 13 "github.com/gofiber/fiber/v2/utils" 14 15 "github.com/valyala/fasthttp" 16 ) 17 18 // New is deprecated 19 func New(config Config) fiber.Handler { 20 log.Printf("[Warning] - [PROXY] proxy.New is deprecated, please use proxy.Balancer instead\n") 21 return Balancer(config) 22 } 23 24 // Balancer creates a load balancer among multiple upstream servers 25 func Balancer(config Config) fiber.Handler { 26 // Set default config 27 cfg := configDefault(config) 28 29 // Load balanced client 30 lbc := &fasthttp.LBClient{} 31 // Note that Servers, Timeout, WriteBufferSize, ReadBufferSize and TlsConfig 32 // will not be used if the client are set. 33 if config.Client == nil { 34 // Set timeout 35 lbc.Timeout = cfg.Timeout 36 // Scheme must be provided, falls back to http 37 for _, server := range cfg.Servers { 38 if !strings.HasPrefix(server, "http") { 39 server = "http://" + server 40 } 41 42 u, err := url.Parse(server) 43 if err != nil { 44 panic(err) 45 } 46 47 client := &fasthttp.HostClient{ 48 NoDefaultUserAgentHeader: true, 49 DisablePathNormalizing: true, 50 Addr: u.Host, 51 52 ReadBufferSize: config.ReadBufferSize, 53 WriteBufferSize: config.WriteBufferSize, 54 55 TLSConfig: config.TlsConfig, 56 } 57 58 lbc.Clients = append(lbc.Clients, client) 59 } 60 } else { 61 // Set custom client 62 lbc = config.Client 63 } 64 65 // Return new handler 66 return func(c *fiber.Ctx) error { 67 // Don't execute middleware if Next returns true 68 if cfg.Next != nil && cfg.Next(c) { 69 return c.Next() 70 } 71 72 // Set request and response 73 req := c.Request() 74 res := c.Response() 75 76 // Don't proxy "Connection" header 77 req.Header.Del(fiber.HeaderConnection) 78 79 // Modify request 80 if cfg.ModifyRequest != nil { 81 if err := cfg.ModifyRequest(c); err != nil { 82 return err 83 } 84 } 85 86 req.SetRequestURI(utils.UnsafeString(req.RequestURI())) 87 88 // Forward request 89 if err := lbc.Do(req, res); err != nil { 90 return err 91 } 92 93 // Don't proxy "Connection" header 94 res.Header.Del(fiber.HeaderConnection) 95 96 // Modify response 97 if cfg.ModifyResponse != nil { 98 if err := cfg.ModifyResponse(c); err != nil { 99 return err 100 } 101 } 102 103 // Return nil to end proxying if no error 104 return nil 105 } 106 } 107 108 var client = &fasthttp.Client{ 109 NoDefaultUserAgentHeader: true, 110 DisablePathNormalizing: true, 111 } 112 113 var lock sync.RWMutex 114 115 // WithTlsConfig update http client with a user specified tls.config 116 // This function should be called before Do and Forward. 117 // Deprecated: use WithClient instead. 118 // 119 //nolint:stylecheck,revive // TODO: Rename to "WithTLSConfig" in v3 120 func WithTlsConfig(tlsConfig *tls.Config) { 121 client.TLSConfig = tlsConfig 122 } 123 124 // WithClient sets the global proxy client. 125 // This function should be called before Do and Forward. 126 func WithClient(cli *fasthttp.Client) { 127 lock.Lock() 128 defer lock.Unlock() 129 client = cli 130 } 131 132 // Forward performs the given http request and fills the given http response. 133 // This method will return an fiber.Handler 134 func Forward(addr string, clients ...*fasthttp.Client) fiber.Handler { 135 return func(c *fiber.Ctx) error { 136 return Do(c, addr, clients...) 137 } 138 } 139 140 // Do performs the given http request and fills the given http response. 141 // This method can be used within a fiber.Handler 142 func Do(c *fiber.Ctx, addr string, clients ...*fasthttp.Client) error { 143 return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { 144 return cli.Do(req, resp) 145 }, clients...) 146 } 147 148 // DoRedirects performs the given http request and fills the given http response, following up to maxRedirectsCount redirects. 149 // When the redirect count exceeds maxRedirectsCount, ErrTooManyRedirects is returned. 150 // This method can be used within a fiber.Handler 151 func DoRedirects(c *fiber.Ctx, addr string, maxRedirectsCount int, clients ...*fasthttp.Client) error { 152 return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { 153 return cli.DoRedirects(req, resp, maxRedirectsCount) 154 }, clients...) 155 } 156 157 // DoDeadline performs the given request and waits for response until the given deadline. 158 // This method can be used within a fiber.Handler 159 func DoDeadline(c *fiber.Ctx, addr string, deadline time.Time, clients ...*fasthttp.Client) error { 160 return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { 161 return cli.DoDeadline(req, resp, deadline) 162 }, clients...) 163 } 164 165 // DoTimeout performs the given request and waits for response during the given timeout duration. 166 // This method can be used within a fiber.Handler 167 func DoTimeout(c *fiber.Ctx, addr string, timeout time.Duration, clients ...*fasthttp.Client) error { 168 return doAction(c, addr, func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error { 169 return cli.DoTimeout(req, resp, timeout) 170 }, clients...) 171 } 172 173 func doAction( 174 c *fiber.Ctx, 175 addr string, 176 action func(cli *fasthttp.Client, req *fasthttp.Request, resp *fasthttp.Response) error, 177 clients ...*fasthttp.Client, 178 ) error { 179 var cli *fasthttp.Client 180 181 // set local or global client 182 if len(clients) != 0 { 183 cli = clients[0] 184 } else { 185 lock.RLock() 186 cli = client 187 lock.RUnlock() 188 } 189 190 req := c.Request() 191 res := c.Response() 192 originalURL := utils.CopyString(c.OriginalURL()) 193 defer req.SetRequestURI(originalURL) 194 195 copiedURL := utils.CopyString(addr) 196 req.SetRequestURI(copiedURL) 197 // NOTE: if req.isTLS is true, SetRequestURI keeps the scheme as https. 198 // Reference: https://github.com/gofiber/fiber/issues/1762 199 if scheme := getScheme(utils.UnsafeBytes(copiedURL)); len(scheme) > 0 { 200 req.URI().SetSchemeBytes(scheme) 201 } 202 203 req.Header.Del(fiber.HeaderConnection) 204 if err := action(cli, req, res); err != nil { 205 return err 206 } 207 res.Header.Del(fiber.HeaderConnection) 208 return nil 209 } 210 211 func getScheme(uri []byte) []byte { 212 i := bytes.IndexByte(uri, '/') 213 if i < 1 || uri[i-1] != ':' || i == len(uri)-1 || uri[i+1] != '/' { 214 return nil 215 } 216 return uri[:i-1] 217 } 218 219 // DomainForward performs an http request based on the given domain and populates the given http response. 220 // This method will return an fiber.Handler 221 func DomainForward(hostname, addr string, clients ...*fasthttp.Client) fiber.Handler { 222 return func(c *fiber.Ctx) error { 223 host := string(c.Request().Host()) 224 if host == hostname { 225 return Do(c, addr+c.OriginalURL(), clients...) 226 } 227 return nil 228 } 229 } 230 231 type roundrobin struct { 232 sync.Mutex 233 234 current int 235 pool []string 236 } 237 238 // this method will return a string of addr server from list server. 239 func (r *roundrobin) get() string { 240 r.Lock() 241 defer r.Unlock() 242 243 if r.current >= len(r.pool) { 244 r.current %= len(r.pool) 245 } 246 247 result := r.pool[r.current] 248 r.current++ 249 return result 250 } 251 252 // BalancerForward Forward performs the given http request with round robin algorithm to server and fills the given http response. 253 // This method will return an fiber.Handler 254 func BalancerForward(servers []string, clients ...*fasthttp.Client) fiber.Handler { 255 r := &roundrobin{ 256 current: 0, 257 pool: servers, 258 } 259 return func(c *fiber.Ctx) error { 260 server := r.get() 261 if !strings.HasPrefix(server, "http") { 262 server = "http://" + server 263 } 264 c.Request().Header.Add("X-Real-IP", c.IP()) 265 return Do(c, server+c.OriginalURL(), clients...) 266 } 267 }