github.com/mholt/caddy-l4@v0.0.0-20241104153248-ec8fae209322/modules/l4proxy/proxy.go

// Copyright 2020 Matthew Holt
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package l4proxy

import (
	"crypto/tls"
	"fmt"
	"io"
	"log"
	"net"
	"runtime/debug"
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/caddyserver/caddy/v2"
	"github.com/caddyserver/caddy/v2/caddyconfig"
	"github.com/caddyserver/caddy/v2/caddyconfig/caddyfile"
	"github.com/mastercactapus/proxyprotocol"
	"go.uber.org/zap"

	"github.com/mholt/caddy-l4/layer4"
	"github.com/mholt/caddy-l4/modules/l4proxyprotocol"
	"github.com/mholt/caddy-l4/modules/l4tls"
)

func init() {
	caddy.RegisterModule(&Handler{})
}

// Handler is a handler that can proxy connections.
type Handler struct {
	// Upstreams is the list of backends to proxy to.
	Upstreams UpstreamPool `json:"upstreams,omitempty"`

	// Health checks update the status of backends, whether they are
	// up or down. Down backends will not be proxied to.
	HealthChecks *HealthChecks `json:"health_checks,omitempty"`

	// Load balancing distributes load/connections between backends.
	LoadBalancing *LoadBalancing `json:"load_balancing,omitempty"`

	// Specifies the version of the Proxy Protocol header to add, either "v1" or "v2".
	// Ref: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
	ProxyProtocol string `json:"proxy_protocol,omitempty"`

	proxyProtocolVersion uint8

	ctx    caddy.Context
	logger *zap.Logger
}

// CaddyModule returns the Caddy module information.
func (*Handler) CaddyModule() caddy.ModuleInfo {
	return caddy.ModuleInfo{
		ID:  "layer4.handlers.proxy",
		New: func() caddy.Module { return new(Handler) },
	}
}
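// For orientation, a rough sketch of how this handler might appear inside a
// layer4 route in Caddy's JSON config. The address is made up, and the
// Upstream "dial" field is defined in upstream.go, so treat the exact keys of
// the upstream object as assumptions rather than a reference:
//
//	{
//		"handler": "proxy",
//		"proxy_protocol": "v1",
//		"upstreams": [
//			{"dial": ["localhost:5000"]}
//		]
//	}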
// Provision sets up the handler.
func (h *Handler) Provision(ctx caddy.Context) error {
	h.ctx = ctx
	h.logger = ctx.Logger(h)

	// start by loading modules
	if h.LoadBalancing != nil && h.LoadBalancing.SelectionPolicyRaw != nil {
		mod, err := ctx.LoadModule(h.LoadBalancing, "SelectionPolicyRaw")
		if err != nil {
			return fmt.Errorf("loading load balancing selection policy: %s", err)
		}
		h.LoadBalancing.SelectionPolicy = mod.(Selector)
	}

	repl := caddy.NewReplacer()
	proxyProtocol := repl.ReplaceAll(h.ProxyProtocol, "")
	if proxyProtocol == "v1" {
		h.proxyProtocolVersion = 1
	} else if proxyProtocol == "v2" {
		h.proxyProtocolVersion = 2
	} else if proxyProtocol != "" {
		return fmt.Errorf("proxy_protocol: \"%s\" should be empty, or one of \"v1\" \"v2\"", proxyProtocol)
	}

	// prepare upstreams
	if len(h.Upstreams) == 0 {
		return fmt.Errorf("no upstreams defined")
	}
	for i, ups := range h.Upstreams {
		err := ups.provision(ctx, h)
		if err != nil {
			return fmt.Errorf("upstream %d: %v", i, err)
		}
	}

	// health checks
	if h.HealthChecks != nil {
		// set defaults on passive health checks, if necessary
		if h.HealthChecks.Passive != nil {
			if h.HealthChecks.Passive.FailDuration > 0 && h.HealthChecks.Passive.MaxFails == 0 {
				h.HealthChecks.Passive.MaxFails = 1
			}
		}

		// if active health checks are enabled, configure them and start a worker
		if h.HealthChecks.Active != nil {
			h.HealthChecks.Active.logger = h.logger.Named("health_checker.active")

			if h.HealthChecks.Active.Timeout == 0 {
				h.HealthChecks.Active.Timeout = caddy.Duration(5 * time.Second)
			}
			if h.HealthChecks.Active.Interval == 0 {
				h.HealthChecks.Active.Interval = caddy.Duration(30 * time.Second)
			}

			go h.activeHealthChecker()
		}
	}

	// set up load balancing; it must not be nil, even if there's just one backend
	if h.LoadBalancing == nil {
		h.LoadBalancing = new(LoadBalancing)
	}
	if h.LoadBalancing.SelectionPolicy == nil {
		h.LoadBalancing.SelectionPolicy = &RandomSelection{}
	}
	if h.LoadBalancing.TryDuration > 0 && h.LoadBalancing.TryInterval == 0 {
		// a non-zero try_duration with a zero try_interval
		// will always spin the CPU for try_duration if the
		// upstream is local or low-latency; avoid that by
		// defaulting to a sane wait period between attempts
		h.LoadBalancing.TryInterval = caddy.Duration(250 * time.Millisecond)
	}

	return nil
}
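// A worked example of the retry pacing set up above (numbers are illustrative
// only): with lb_try_duration set to 5s and the defaulted lb_try_interval of
// 250ms, a failed selection or dial in Handle below is retried roughly every
// 250ms until about 5s have elapsed, after which the last error is returned.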
// Handle handles the downstream connection.
func (h *Handler) Handle(down *layer4.Connection, _ layer4.Handler) error {
	repl := down.Context.Value(layer4.ReplacerCtxKey).(*caddy.Replacer)

	start := time.Now()

	var upConns []net.Conn
	var proxyErr error

	for {
		// choose an available upstream
		upstream := h.LoadBalancing.SelectionPolicy.Select(h.Upstreams, down)
		if upstream == nil {
			if proxyErr == nil {
				proxyErr = fmt.Errorf("no upstreams available")
			}
			if !h.LoadBalancing.tryAgain(h.ctx, start) {
				return proxyErr
			}
			continue
		}

		// establish all upstream connections
		upConns, proxyErr = h.dialPeers(upstream, repl, down)
		if proxyErr != nil {
			// we might be able to try again
			if !h.LoadBalancing.tryAgain(h.ctx, start) {
				return proxyErr
			}
			continue
		}

		break
	}

	// make sure upstream connections all get closed
	defer func() {
		for _, conn := range upConns {
			_ = conn.Close()
		}
	}()

	// finally, proxy the connection
	h.proxy(down, upConns)

	return nil
}

func (h *Handler) dialPeers(upstream *Upstream, repl *caddy.Replacer, down *layer4.Connection) ([]net.Conn, error) {
	var upConns []net.Conn

	for _, p := range upstream.peers {
		hostPort := repl.ReplaceAll(p.address.JoinHostPort(0), "")

		var up net.Conn
		var err error

		if upstream.TLS == nil {
			up, err = net.Dial(p.address.Network, hostPort)
		} else {
			// the prepared config could be nil if the user enabled but did not customize TLS,
			// in which case we adopt the downstream client's TLS ClientHello for ours;
			// i.e. by default, make the client's TLS config as transparent as possible
			tlsCfg := upstream.tlsConfig
			if tlsCfg == nil {
				tlsCfg = new(tls.Config)
				if hellos := l4tls.GetClientHelloInfos(down); len(hellos) > 0 {
					hellos[0].FillTLSClientConfig(tlsCfg)
				}
			}
			up, err = tls.Dial(p.address.Network, hostPort, tlsCfg)
		}
		h.logger.Debug("dial upstream",
			zap.String("remote", down.RemoteAddr().String()),
			zap.String("upstream", hostPort),
			zap.Error(err))

		// Send the PROXY protocol header.
		if err == nil {
			downConn := l4proxyprotocol.GetConn(down)
			switch h.proxyProtocolVersion {
			case 1:
				var h proxyprotocol.HeaderV1
				h.FromConn(downConn, false)
				_, err = h.WriteTo(up)
			case 2:
				var h proxyprotocol.HeaderV2
				h.FromConn(downConn, false)
				_, err = h.WriteTo(up)
			}
		}

		if err != nil {
			h.countFailure(p)
			for _, conn := range upConns {
				_ = conn.Close()
			}
			return nil, err
		}

		upConns = append(upConns, up)
	}

	return upConns, nil
}
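// For reference, the version 1 header written by dialPeers above is a single
// text line of the form shown below (addresses and ports are hypothetical),
// while version 2 is an equivalent binary encoding; see the spec linked on
// the ProxyProtocol field:
//
//	PROXY TCP4 203.0.113.7 10.0.0.1 54321 8080\r\n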
// proxy proxies the downstream connection to all upstream connections.
func (h *Handler) proxy(down *layer4.Connection, upConns []net.Conn) {
	// every time we read from downstream, we write
	// the same to each upstream; this is half of
	// the proxy duplex
	var downTee io.Reader = down
	for _, up := range upConns {
		downTee = io.TeeReader(downTee, up)
	}

	var wg sync.WaitGroup
	var downClosed atomic.Bool

	for _, up := range upConns {
		wg.Add(1)

		go func(up net.Conn) {
			defer wg.Done()

			if _, err := io.Copy(down, up); err != nil {
				// If the downstream connection has been closed, we can assume this is
				// the reason io.Copy() errored. That's normal operation for UDP
				// connections after idle timeout, so don't log an error in that case.
				if !downClosed.Load() {
					h.logger.Error("upstream connection",
						zap.String("local_address", up.LocalAddr().String()),
						zap.String("remote_address", up.RemoteAddr().String()),
						zap.Error(err),
					)
				}
			}
		}(up)
	}

	downConnClosedCh := make(chan struct{}, 1)

	go func() {
		// read from downstream until connection is closed;
		// TODO: this pumps the reader, but writing into discard is a weird way to do it; could be avoided if we used io.Pipe - see _gitignore/oldtee.go.txt
		_, _ = io.Copy(io.Discard, downTee)
		downConnClosedCh <- struct{}{}

		// Shut down the writing side of all upstream connections, in case
		// the downstream connection is half closed. (issue #40)
		//
		// UDP connections meanwhile don't implement CloseWrite(), but in order
		// to ensure io.Copy() in the per-upstream goroutines (above) returns,
		// we need to close the socket. This will cause io.Copy() to return an
		// error, which in this particular case is expected, so we signal the
		// intentional closure by setting this flag.
		downClosed.Store(true)
		for _, up := range upConns {
			if conn, ok := up.(closeWriter); ok {
				_ = conn.CloseWrite()
			} else {
				_ = up.Close()
			}
		}
	}()

	// wait for reading from all upstream connections
	wg.Wait()

	// Shut down the writing side of the downstream connection, in case
	// the upstream connections are all half closed.
	if downConn, ok := down.Conn.(closeWriter); ok {
		_ = downConn.CloseWrite()
	}

	// Wait for reading from the downstream connection, if possible.
	<-downConnClosedCh
}

// countFailure is used with passive health checks. It
// remembers 1 failure for the upstream for the configured
// duration. If passive health checks are disabled or
// failure expiry is 0, this is a no-op.
func (h *Handler) countFailure(p *peer) {
	// only count failures if passive health checking is enabled
	// and if failures are configured to have a non-zero expiry
	if h.HealthChecks == nil || h.HealthChecks.Passive == nil {
		return
	}
	failDuration := time.Duration(h.HealthChecks.Passive.FailDuration)
	if failDuration == 0 {
		return
	}

	// count failure immediately
	err := p.countFail(1)
	if err != nil {
		h.HealthChecks.Passive.logger.Error("could not count failure",
			zap.String("peer_address", p.address.String()),
			zap.Error(err))
		return
	}

	// forget it later
	go func(failDuration time.Duration) {
		defer func() {
			if err := recover(); err != nil {
				log.Printf("[PANIC] health check failure forgetter: %v\n%s", err, debug.Stack())
			}
		}()
		time.Sleep(failDuration)
		err := p.countFail(-1)
		if err != nil {
			h.HealthChecks.Passive.logger.Error("could not forget failure",
				zap.String("peer_address", p.address.String()),
				zap.Error(err))
		}
	}(failDuration)
}

// Cleanup cleans up the resources made by h during provisioning.
func (h *Handler) Cleanup() error {
	// remove hosts from our config from the pool
	for _, upstream := range h.Upstreams {
		for _, dialAddr := range upstream.Dial {
			_, _ = peers.Delete(dialAddr)
		}
	}
	return nil
}
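// An illustrative Caddyfile counterpart to the Syntax block below (a sketch
// only: the addresses and durations are made up, and "random" is assumed to
// be the Caddyfile name of the default RandomSelection policy):
//
//	proxy localhost:5000 localhost:5001 {
//		lb_policy random
//		lb_try_duration 5s
//		lb_try_interval 250ms
//		fail_duration 30s
//		proxy_protocol v2
//	}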
// UnmarshalCaddyfile sets up the Handler from Caddyfile tokens. Syntax:
//
//	proxy [<upstreams...>] {
//		# active health check options
//		health_interval <duration>
//		health_port <int>
//		health_timeout <duration>
//
//		# passive health check options
//		fail_duration <duration>
//		max_fails <int>
//		unhealthy_connection_count <int>
//
//		# load balancing options
//		lb_policy <name> [<args...>]
//		lb_try_duration <duration>
//		lb_try_interval <duration>
//
//		proxy_protocol <v1|v2>
//
//		# multiple upstream options are supported
//		upstream [<args...>] {
//			...
//		}
//		upstream [<args...>]
//	}
func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
	_, wrapper := d.Next(), d.Val() // consume wrapper name

	// Treat all same-line options as upstream addresses
	for d.NextArg() {
		h.Upstreams = append(h.Upstreams, &Upstream{Dial: []string{d.Val()}})
	}

	var (
		hasHealthInterval, hasHealthPort, hasHealthTimeout   bool // active health check options
		hasFailDuration, hasMaxFails, hasUnhealthyConnCount  bool // passive health check options
		hasLBPolicy, hasLBTryDuration, hasLBTryInterval      bool // load balancing options
		hasProxyProtocol                                     bool
	)
	for nesting := d.Nesting(); d.NextBlock(nesting); {
		optionName := d.Val()
		switch optionName {
		case "health_interval":
			if hasHealthInterval {
				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
			}
			if d.CountRemainingArgs() != 1 {
				return d.ArgErr()
			}
			d.NextArg()
			dur, err := caddy.ParseDuration(d.Val())
			if err != nil {
				return d.Errf("parsing %s option '%s' duration: %v", wrapper, optionName, err)
			}
			if h.HealthChecks == nil {
				h.HealthChecks = &HealthChecks{Active: &ActiveHealthChecks{}}
			} else if h.HealthChecks.Active == nil {
				h.HealthChecks.Active = &ActiveHealthChecks{}
			}
			h.HealthChecks.Active.Interval, hasHealthInterval = caddy.Duration(dur), true
		case "health_port":
			if hasHealthPort {
				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
			}
			if d.CountRemainingArgs() != 1 {
				return d.ArgErr()
			}
			d.NextArg()
			val, err := strconv.ParseInt(d.Val(), 10, 32)
			if err != nil {
				return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err)
			}
			if h.HealthChecks == nil {
				h.HealthChecks = &HealthChecks{Active: &ActiveHealthChecks{}}
			} else if h.HealthChecks.Active == nil {
				h.HealthChecks.Active = &ActiveHealthChecks{}
			}
			h.HealthChecks.Active.Port, hasHealthPort = int(val), true
		case "health_timeout":
			if hasHealthTimeout {
				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
			}
			if d.CountRemainingArgs() != 1 {
				return d.ArgErr()
			}
			d.NextArg()
			dur, err := caddy.ParseDuration(d.Val())
			if err != nil {
				return d.Errf("parsing %s option '%s' duration: %v", wrapper, optionName, err)
			}
			if h.HealthChecks == nil {
				h.HealthChecks = &HealthChecks{Active: &ActiveHealthChecks{}}
			} else if h.HealthChecks.Active == nil {
				h.HealthChecks.Active = &ActiveHealthChecks{}
			}
			h.HealthChecks.Active.Timeout, hasHealthTimeout = caddy.Duration(dur), true
		case "fail_duration":
			if hasFailDuration {
				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
			}
			if d.CountRemainingArgs() != 1 {
				return d.ArgErr()
			}
			d.NextArg()
			dur, err := caddy.ParseDuration(d.Val())
			if err != nil {
				return d.Errf("parsing %s option '%s' duration: %v", wrapper, optionName, err)
			}
'%s' duration: %v", wrapper, optionName, err) 490 } 491 if h.HealthChecks == nil { 492 h.HealthChecks = &HealthChecks{Passive: &PassiveHealthChecks{}} 493 } else if h.HealthChecks.Passive == nil { 494 h.HealthChecks.Passive = &PassiveHealthChecks{} 495 } 496 h.HealthChecks.Passive.FailDuration, hasFailDuration = caddy.Duration(dur), true 497 case "max_fails": 498 if hasMaxFails { 499 return d.Errf("duplicate %s option '%s'", wrapper, optionName) 500 } 501 if d.CountRemainingArgs() != 1 { 502 return d.ArgErr() 503 } 504 d.NextArg() 505 val, err := strconv.ParseInt(d.Val(), 10, 32) 506 if err != nil { 507 return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err) 508 } 509 if h.HealthChecks == nil { 510 h.HealthChecks = &HealthChecks{Passive: &PassiveHealthChecks{}} 511 } else if h.HealthChecks.Passive == nil { 512 h.HealthChecks.Passive = &PassiveHealthChecks{} 513 } 514 h.HealthChecks.Passive.MaxFails, hasMaxFails = int(val), true 515 case "unhealthy_connection_count": 516 if hasUnhealthyConnCount { 517 return d.Errf("duplicate %s option '%s'", wrapper, optionName) 518 } 519 if d.CountRemainingArgs() != 1 { 520 return d.ArgErr() 521 } 522 d.NextArg() 523 val, err := strconv.ParseInt(d.Val(), 10, 32) 524 if err != nil { 525 return d.Errf("parsing %s option '%s': %v", wrapper, optionName, err) 526 } 527 if h.HealthChecks == nil { 528 h.HealthChecks = &HealthChecks{Passive: &PassiveHealthChecks{}} 529 } else if h.HealthChecks.Passive == nil { 530 h.HealthChecks.Passive = &PassiveHealthChecks{} 531 } 532 h.HealthChecks.Passive.UnhealthyConnectionCount, hasUnhealthyConnCount = int(val), true 533 case "lb_policy": 534 if hasLBPolicy { 535 return d.Errf("duplicate proxy load_balancing option '%s'", optionName) 536 } 537 if !d.NextArg() { 538 return d.ArgErr() 539 } 540 policyName := d.Val() 541 542 unm, err := caddyfile.UnmarshalModule(d, "layer4.proxy.selection_policies."+policyName) 543 if err != nil { 544 return err 545 } 546 us, ok := unm.(Selector) 547 if !ok { 548 return d.Errf("policy module '%s' is not an upstream selector", policyName) 549 } 550 policyRaw := caddyconfig.JSON(us, nil) 551 552 policyRaw, err = layer4.SetModuleNameInline("policy", policyName, policyRaw) 553 if err != nil { 554 return d.Errf("re-encoding module '%s' configuration: %v", policyName, err) 555 } 556 if h.LoadBalancing == nil { 557 h.LoadBalancing = &LoadBalancing{} 558 } 559 h.LoadBalancing.SelectionPolicyRaw, hasLBPolicy = policyRaw, true 560 case "lb_try_duration": 561 if hasLBTryDuration { 562 return d.Errf("duplicate %s option '%s'", wrapper, optionName) 563 } 564 if d.CountRemainingArgs() != 1 { 565 return d.ArgErr() 566 } 567 d.NextArg() 568 dur, err := caddy.ParseDuration(d.Val()) 569 if err != nil { 570 return d.Errf("parsing %s option '%s' duration: %v", wrapper, optionName, err) 571 } 572 if h.LoadBalancing == nil { 573 h.LoadBalancing = &LoadBalancing{} 574 } 575 h.LoadBalancing.TryDuration, hasLBTryDuration = caddy.Duration(dur), true 576 case "lb_try_interval": 577 if hasLBTryInterval { 578 return d.Errf("duplicate %s option '%s'", wrapper, optionName) 579 } 580 if d.CountRemainingArgs() != 1 { 581 return d.ArgErr() 582 } 583 d.NextArg() 584 dur, err := caddy.ParseDuration(d.Val()) 585 if err != nil { 586 return d.Errf("parsing %s option '%s' duration: %v", wrapper, optionName, err) 587 } 588 if h.LoadBalancing == nil { 589 h.LoadBalancing = &LoadBalancing{} 590 } 591 h.LoadBalancing.TryInterval, hasLBTryInterval = caddy.Duration(dur), true 592 case "proxy_protocol": 593 if 
			if hasProxyProtocol {
				return d.Errf("duplicate %s option '%s'", wrapper, optionName)
			}
			_, h.ProxyProtocol, hasProxyProtocol = d.NextArg(), d.Val(), true
		case "upstream":
			u := &Upstream{}
			if err := u.UnmarshalCaddyfile(d.NewFromNextSegment()); err != nil {
				return err
			}
			h.Upstreams = append(h.Upstreams, u)
		default:
			return d.ArgErr()
		}

		// No nested blocks are supported
		if d.NextBlock(nesting + 1) {
			return d.Errf("malformed %s option '%s': blocks are not supported", wrapper, optionName)
		}
	}

	return nil
}

// peers is the global repository for peers that are
// currently in use by active configuration(s). This
// allows the state of remote hosts to be preserved
// through config reloads.
var peers = caddy.NewUsagePool()

// Interface guards
var (
	_ caddy.CleanerUpper    = (*Handler)(nil)
	_ caddy.Provisioner     = (*Handler)(nil)
	_ caddyfile.Unmarshaler = (*Handler)(nil)
	_ layer4.NextHandler    = (*Handler)(nil)
)

// Used to properly shut down half-closed connections (see PR #73).
// Implemented by net.TCPConn, net.UnixConn, tls.Conn, qtls.Conn.
type closeWriter interface {
	// CloseWrite shuts down the writing side of the connection.
	CloseWrite() error
}

// Ensure we notice if CloseWrite changes for these important connections
var (
	_ closeWriter = (*net.TCPConn)(nil)
	_ closeWriter = (*net.UnixConn)(nil)
	_ closeWriter = (*tls.Conn)(nil)
)