github.com/avenga/couper@v1.12.2/handler/transport/transport.go (about)

     1  package transport
     2  
import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/http"
	"net/url"
	"os"
	"time"

	"github.com/sirupsen/logrus"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"

	"github.com/avenga/couper/config"
	"github.com/avenga/couper/config/request"
	"github.com/avenga/couper/handler/ratelimit"
	coupertls "github.com/avenga/couper/internal/tls"
	"github.com/avenga/couper/telemetry"
	"golang.org/x/net/http/httpproxy"
)
    24  
// Config represents the transport <Config> object.
//
// NewTransport builds an *http.Transport from a Config. WithTarget and
// WithTimings return modified copies; they never mutate the receiver.
type Config struct {
	// Static backend settings:
	//   - BackendName is used in connection error messages.
	//   - DisableCertValidation skips server certificate verification, but only
	//     when no CACertificate is configured.
	//   - DisableConnectionReuse turns off keep-alives on the transport.
	//   - HTTP2 enables HTTP/2 on the transport (it is explicitly disabled
	//     otherwise); WithTarget also defaults an empty scheme to "https" when set.
	//   - MaxConnections becomes http.Transport.MaxConnsPerHost.
	//   - NoProxyFromEnv disables the environment-proxy fallback used when
	//     Proxy is empty.
	//   - Proxy is an explicit proxy URL used for both http and https requests.
	//   - RateLimits is not read in this file. NOTE(review): presumably consumed
	//     by the rate limiting layer — confirm.
	BackendName            string
	DisableCertValidation  bool
	DisableConnectionReuse bool
	HTTP2                  bool
	MaxConnections         int
	NoProxyFromEnv         bool
	Proxy                  string
	RateLimits             ratelimit.RateLimits

	// Timings, overridable from their string form via WithTimings.
	// NOTE(review): NewTransport reads the connect timeout from the request
	// context rather than from ConnectTimeout — confirm where these are applied.
	ConnectTimeout time.Duration
	TTFBTimeout    time.Duration
	Timeout        time.Duration

	// TLS settings
	// Certificate is passed to all backends from the related cli option.
	// NewTransport appends it (PEM encoded) to the root CA pool.
	Certificate []byte
	// CACertificate contains a per backend configured one.
	// When its Leaf is set, it is added to the root CA pool and
	// DisableCertValidation is ignored.
	CACertificate tls.Certificate
	// ClientCertificate holds the one the backend will send during tls handshake if required.
	// It is only used when its Leaf is set.
	ClientCertificate tls.Certificate

	// Dynamic values
	// Scheme, Origin and Hostname are filled in by WithTarget. Origin is the
	// host:port that is dialed when no proxy is configured; Hostname is used as
	// the TLS server name when it differs from Origin.
	Context  context.Context
	Hostname string
	Origin   string
	Scheme   string
}
    54  
    55  // NewTransport creates a new <*http.Transport> object by the given <*Config>.
    56  func NewTransport(conf *Config, log *logrus.Entry) *http.Transport {
    57  	tlsConf := coupertls.DefaultTLSConfig()
    58  	if len(conf.Certificate) > 0 {
    59  		tlsConf.RootCAs.AppendCertsFromPEM(conf.Certificate)
    60  	}
    61  	if conf.CACertificate.Leaf == nil {
    62  		tlsConf.InsecureSkipVerify = conf.DisableCertValidation
    63  	} else {
    64  		tlsConf.RootCAs.AddCert(conf.CACertificate.Leaf)
    65  	}
    66  
    67  	if conf.ClientCertificate.Leaf != nil {
    68  		clientCert := &conf.ClientCertificate
    69  		tlsConf.GetClientCertificate = func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) {
    70  			return clientCert, nil
    71  		}
    72  	}
    73  
    74  	if conf.Origin != conf.Hostname {
    75  		tlsConf.ServerName = conf.Hostname
    76  	}
    77  
    78  	d := &net.Dialer{
    79  		KeepAlive: 60 * time.Second,
    80  	}
    81  
    82  	var proxyFunc func(req *http.Request) (*url.URL, error)
    83  	if conf.Proxy != "" {
    84  		proxyFunc = func(req *http.Request) (*url.URL, error) {
    85  			proxyConf := &httpproxy.Config{
    86  				HTTPProxy:  conf.Proxy,
    87  				HTTPSProxy: conf.Proxy,
    88  			}
    89  
    90  			return proxyConf.ProxyFunc()(req.URL)
    91  		}
    92  	} else if !conf.NoProxyFromEnv {
    93  		proxyFunc = http.ProxyFromEnvironment
    94  	}
    95  
    96  	// This is the documented way to disable http2. However, if a custom tls.Config or
    97  	// DialContext is used h2 will also be disabled. To enable h2 the transport must be
    98  	// explicitly configured, this can be done with the 'ForceAttemptHTTP2' below.
    99  	var nextProto map[string]func(authority string, c *tls.Conn) http.RoundTripper
   100  	if !conf.HTTP2 {
   101  		nextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
   102  		tlsConf.NextProtos = nil
   103  	}
   104  
   105  	logEntry := log.WithField("type", "couper_connection")
   106  
   107  	transport := &http.Transport{
   108  		DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
   109  			address := addr
   110  			if proxyFunc == nil {
   111  				address = conf.Origin
   112  			} // Otherwise, proxy connect will use this dial method and addr could be a proxy one.
   113  
   114  			stx, span := telemetry.NewSpanFromContext(ctx, "connect", trace.WithAttributes(attribute.String("couper.address", addr)))
   115  			defer span.End()
   116  
   117  			connectTimeout, _ := ctx.Value(request.ConnectTimeout).(time.Duration)
   118  			if connectTimeout > 0 {
   119  				dtx, cancel := context.WithDeadline(stx, time.Now().Add(connectTimeout))
   120  				stx = dtx
   121  				defer cancel()
   122  			}
   123  
   124  			conn, cerr := d.DialContext(stx, network, address)
   125  			if cerr != nil {
   126  				host, port, _ := net.SplitHostPort(conf.Origin)
   127  				if port != "80" && port != "443" {
   128  					host = conf.Origin
   129  				}
   130  				if os.IsTimeout(cerr) || cerr == context.DeadlineExceeded {
   131  					return nil, fmt.Errorf("connecting to %s '%s' failed: i/o timeout", conf.BackendName, host)
   132  				}
   133  				return nil, fmt.Errorf("connecting to %s '%s' failed: %w", conf.BackendName, conf.Origin, cerr)
   134  			}
   135  			return NewOriginConn(stx, conn, conf, logEntry), nil
   136  		},
   137  		DisableCompression: true,
   138  		DisableKeepAlives:  conf.DisableConnectionReuse,
   139  		ForceAttemptHTTP2:  conf.HTTP2,
   140  		MaxConnsPerHost:    conf.MaxConnections,
   141  		Proxy:              proxyFunc,
   142  		TLSClientConfig:    tlsConf,
   143  		TLSNextProto:       nextProto,
   144  	}
   145  
   146  	return transport
   147  }
   148  
   149  func (c *Config) WithTarget(scheme, origin, hostname, proxyURL string) *Config {
   150  	const defaultScheme = "http"
   151  	conf := *c
   152  	if scheme != "" {
   153  		conf.Scheme = scheme
   154  	} else {
   155  		conf.Scheme = defaultScheme
   156  		if conf.HTTP2 {
   157  			conf.Scheme += "s"
   158  		}
   159  	}
   160  
   161  	conf.Origin = origin
   162  	conf.Hostname = hostname
   163  
   164  	// Port required by transport.DialContext
   165  	_, p, _ := net.SplitHostPort(origin)
   166  	if p == "" {
   167  		const port, tlsPort = "80", "443"
   168  		if conf.Scheme == defaultScheme {
   169  			conf.Origin += ":" + port
   170  		} else {
   171  			conf.Origin += ":" + tlsPort
   172  		}
   173  	}
   174  
   175  	if proxyURL != "" {
   176  		conf.Proxy = proxyURL
   177  	}
   178  
   179  	return &conf
   180  }
   181  
   182  func (c *Config) WithTimings(connect, ttfb, timeout string, logger *logrus.Entry) *Config {
   183  	conf := *c
   184  	parseDuration(connect, &conf.ConnectTimeout, "connect_timeout", logger)
   185  	parseDuration(ttfb, &conf.TTFBTimeout, "ttfb_timeout", logger)
   186  	parseDuration(timeout, &conf.Timeout, "timeout", logger)
   187  	return &conf
   188  }
   189  
   190  // parseDuration sets the target value if the given duration string is not empty.
   191  func parseDuration(src string, target *time.Duration, attributeName string, logger *logrus.Entry) {
   192  	d, err := config.ParseDuration(attributeName, src, *target)
   193  	if err != nil {
   194  		logger.WithError(err).Warning("using default timing of ", target, " because an error occurred")
   195  	}
   196  	if src != "" && err != nil {
   197  		return
   198  	}
   199  	*target = d
   200  }