github.com/avenga/couper@v1.12.2/server/http.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"log"
     7  	"net"
     8  	"net/http"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/sirupsen/logrus"
    13  	"go.opentelemetry.io/otel/metric/instrument"
    14  	"go.opentelemetry.io/otel/metric/unit"
    15  
    16  	"github.com/avenga/couper/config"
    17  	"github.com/avenga/couper/config/env"
    18  	"github.com/avenga/couper/config/request"
    19  	"github.com/avenga/couper/config/runtime"
    20  	"github.com/avenga/couper/errors"
    21  	"github.com/avenga/couper/eval"
    22  	"github.com/avenga/couper/handler"
    23  	"github.com/avenga/couper/handler/middleware"
    24  	"github.com/avenga/couper/logging"
    25  	"github.com/avenga/couper/telemetry/instrumentation"
    26  	"github.com/avenga/couper/telemetry/provider"
    27  )
    28  
    29  type muxers map[string]*Mux
    30  
    31  // HTTPServer represents a configured HTTP server.
    32  type HTTPServer struct {
    33  	commandCtx context.Context
    34  	evalCtx    *eval.Context
    35  	listener   net.Listener
    36  	log        logrus.FieldLogger
    37  	muxers     muxers
    38  	port       string
    39  	settings   *config.Settings
    40  	shutdownCh chan struct{}
    41  	srv        *http.Server
    42  	timings    *runtime.HTTPTimings
    43  }
    44  
    45  // NewServers returns a list of the created and configured HTTP(s) servers.
    46  func NewServers(cmdCtx, evalCtx context.Context, log logrus.FieldLogger, settings *config.Settings,
    47  	timings *runtime.HTTPTimings, srvConf runtime.ServerConfiguration) ([]*HTTPServer, func(), error) {
    48  
    49  	var list []*HTTPServer
    50  
    51  	for port, hosts := range srvConf {
    52  		srv, err := New(cmdCtx, evalCtx, log, settings, timings, port, hosts)
    53  		if err != nil {
    54  			return nil, nil, err
    55  		}
    56  		list = append(list, srv)
    57  	}
    58  
    59  	handleShutdownFn := func() {
    60  		<-cmdCtx.Done()
    61  		time.Sleep(timings.ShutdownDelay + timings.ShutdownTimeout) // wait for max amount, TODO: feedback per server
    62  	}
    63  
    64  	return list, handleShutdownFn, nil
    65  }
    66  
    67  // New creates an HTTP(S) server with configured router and middlewares.
    68  func New(cmdCtx, evalCtx context.Context, log logrus.FieldLogger, settings *config.Settings,
    69  	timings *runtime.HTTPTimings, p runtime.Port, hosts runtime.Hosts) (*HTTPServer, error) {
    70  
    71  	logConf := *logging.DefaultConfig
    72  	logConf.TypeFieldKey = "couper_access"
    73  	env.DecodeWithPrefix(&logConf, "ACCESS_")
    74  
    75  	shutdownCh := make(chan struct{})
    76  
    77  	muxersList := make(muxers)
    78  	var serverTLS *config.ServerTLS
    79  	for host, muxOpts := range hosts {
    80  		mux := NewMux(muxOpts)
    81  		registerHandler(mux.endpointRoot, []string{http.MethodGet}, settings.HealthPath, handler.NewHealthCheck(settings.HealthPath, shutdownCh))
    82  		mux.RegisterConfigured()
    83  		muxersList[host] = mux
    84  
    85  		// TODO: refactor (hosts,muxOpts, etc) format type and usage
    86  		// serverOpts are all the same, pick first
    87  		if serverTLS == nil && muxOpts.ServerOptions != nil && muxOpts.ServerOptions.TLS != nil {
    88  			serverTLS = muxOpts.ServerOptions.TLS
    89  		}
    90  	}
    91  
    92  	httpSrv := &HTTPServer{
    93  		evalCtx:    evalCtx.Value(request.ContextType).(*eval.Context),
    94  		commandCtx: cmdCtx,
    95  		log:        log,
    96  		muxers:     muxersList,
    97  		port:       p.String(),
    98  		settings:   settings,
    99  		shutdownCh: shutdownCh,
   100  		timings:    timings,
   101  	}
   102  
   103  	accessLog := logging.NewAccessLog(&logConf, log)
   104  
   105  	// order matters
   106  	telemetryHandler := middleware.NewHandler(httpSrv, nil) // fallback to plain wrapper without telemetry options
   107  	if settings.TelemetryMetrics {
   108  		telemetryHandler = middleware.NewMetricsHandler()(httpSrv)
   109  	}
   110  	if settings.TelemetryTraces {
   111  		telemetryHandler = middleware.NewTraceHandler()(telemetryHandler)
   112  	}
   113  
   114  	uidHandler := middleware.NewUIDHandler(settings, httpsDevProxyIDField)(telemetryHandler)
   115  	logHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   116  		uidHandler.ServeHTTP(rw, req)
   117  		accessLog.Do(rw, req)
   118  	})
   119  	recordHandler := middleware.NewRecordHandler(settings.SecureCookies)(logHandler)
   120  	startTimeHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   121  		recordHandler.ServeHTTP(rw, r.WithContext(
   122  			context.WithValue(r.Context(), request.StartTime, time.Now())))
   123  	})
   124  
   125  	srv := &http.Server{
   126  		Addr:              ":" + p.String(),
   127  		ErrorLog:          newErrorLogWrapper(log),
   128  		Handler:           startTimeHandler,
   129  		IdleTimeout:       timings.IdleTimeout,
   130  		ReadHeaderTimeout: timings.ReadHeaderTimeout,
   131  	}
   132  
   133  	if settings.TelemetryMetrics {
   134  		srv.ConnState = httpSrv.onConnState
   135  	}
   136  
   137  	if serverTLS != nil {
   138  		tlsConfig, err := newTLSConfig(serverTLS, log)
   139  		if err != nil {
   140  			return nil, err
   141  		}
   142  		srv.TLSConfig = tlsConfig
   143  	}
   144  
   145  	httpSrv.srv = srv
   146  
   147  	return httpSrv, nil
   148  }
   149  
   150  // Addr returns the listener address.
   151  func (s *HTTPServer) Addr() string {
   152  	if s.listener != nil {
   153  		return s.listener.Addr().String()
   154  	}
   155  	return ""
   156  }
   157  
   158  // Listen initiates the configured http handler and start listing on given port.
   159  func (s *HTTPServer) Listen() error {
   160  	if s.srv.Addr == "" {
   161  		s.srv.Addr = ":http"
   162  		if s.srv.TLSConfig != nil {
   163  			s.srv.Addr += "s"
   164  		}
   165  	}
   166  
   167  	ln, err := net.Listen("tcp4", s.srv.Addr)
   168  	if err != nil {
   169  		return err
   170  	}
   171  
   172  	s.listener = ln
   173  	s.log.Infof("couper is serving: %s", ln.Addr().String())
   174  
   175  	go s.listenForCtx()
   176  
   177  	go func() {
   178  		var serveErr error
   179  		if s.srv.TLSConfig != nil {
   180  			serveErr = s.srv.ServeTLS(s.listener, "", "")
   181  		} else {
   182  			serveErr = s.srv.Serve(ln)
   183  		}
   184  
   185  		if serveErr != nil {
   186  			if serveErr == http.ErrServerClosed {
   187  				s.log.Infof("%v: %s", serveErr, ln.Addr().String())
   188  			} else {
   189  				s.log.Errorf("%s: %v", ln.Addr().String(), serveErr)
   190  			}
   191  		}
   192  	}()
   193  	return nil
   194  }
   195  
   196  // Close closes the listener
   197  func (s *HTTPServer) Close() error {
   198  	return s.listener.Close()
   199  }
   200  
   201  func (s *HTTPServer) listenForCtx() {
   202  	<-s.commandCtx.Done()
   203  
   204  	logFields := logrus.Fields{
   205  		"delay":    s.timings.ShutdownDelay.String(),
   206  		"deadline": s.timings.ShutdownTimeout.String(),
   207  	}
   208  
   209  	s.log.WithFields(logFields).Warn("shutting down")
   210  	close(s.shutdownCh)
   211  
   212  	time.Sleep(s.timings.ShutdownDelay)
   213  	ctx := context.Background()
   214  	if s.timings.ShutdownTimeout > 0 {
   215  		c, cancel := context.WithTimeout(ctx, s.timings.ShutdownTimeout)
   216  		defer cancel()
   217  		ctx = c
   218  	}
   219  
   220  	if err := s.srv.Shutdown(ctx); err != nil {
   221  		s.log.WithFields(logFields).Error(err)
   222  	}
   223  }
   224  
   225  func (s *HTTPServer) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
   226  	var h http.Handler
   227  
   228  	req.Host = s.getHost(req)
   229  	host, _, err := runtime.GetHostPort(req.Host)
   230  	if err != nil {
   231  		h = errors.DefaultHTML.WithError(errors.ClientRequest)
   232  	}
   233  
   234  	mux, ok := s.muxers[host]
   235  	if !ok {
   236  		mux, ok = s.muxers["*"]
   237  		if !ok && h == nil {
   238  			h = errors.DefaultHTML.WithError(errors.Configuration)
   239  		}
   240  	}
   241  
   242  	if h == nil {
   243  		// mux.FindHandler() exchanges the req: *req = *req.WithContext(ctx)
   244  		h = mux.FindHandler(req)
   245  	}
   246  
   247  	ctx := context.WithValue(req.Context(), request.LogEntry, s.log)
   248  	ctx = context.WithValue(ctx, request.XFF, req.Header.Get("X-Forwarded-For"))
   249  
   250  	// set innermost handler name for logging purposes
   251  	if hs, stringer := getChildHandler(h).(fmt.Stringer); stringer {
   252  		ctx = context.WithValue(ctx, request.Handler, hs.String())
   253  	}
   254  
   255  	if err = s.setGetBody(h, req); err != nil {
   256  		h = mux.opts.ServerOptions.ServerErrTpl.WithError(err)
   257  	}
   258  
   259  	req.URL.Host = req.Host
   260  	req.URL.Scheme = "http"
   261  	if req.TLS != nil && req.TLS.HandshakeComplete {
   262  		req.URL.Scheme += "s"
   263  	}
   264  
   265  	if s.settings.AcceptsForwardedProtocol() {
   266  		if xfpr := req.Header.Get("X-Forwarded-Proto"); xfpr != "" {
   267  			req.URL.Scheme = xfpr
   268  			req.URL.Host = req.URL.Hostname()
   269  		}
   270  	}
   271  	if s.settings.AcceptsForwardedHost() {
   272  		if xfh := req.Header.Get("X-Forwarded-Host"); xfh != "" {
   273  			portToAppend := req.URL.Port()
   274  			req.URL.Host = xfh
   275  			if portToAppend != "" && req.URL.Port() == "" {
   276  				req.URL.Host += ":" + portToAppend
   277  			}
   278  		}
   279  	}
   280  	if s.settings.AcceptsForwardedPort() {
   281  		if xfpo := req.Header.Get("X-Forwarded-Port"); xfpo != "" {
   282  			req.URL.Host = req.URL.Hostname() + ":" + xfpo
   283  		}
   284  	}
   285  
   286  	// due to the middleware callee stack we have to update the 'req' value.
   287  	*req = *req.WithContext(s.evalCtx.WithClientRequest(req.WithContext(ctx)))
   288  
   289  	h.ServeHTTP(rw, req)
   290  }
   291  
   292  func (s *HTTPServer) setGetBody(h http.Handler, req *http.Request) error {
   293  	inner := getChildHandler(h)
   294  
   295  	var err error
   296  	if limitHandler, ok := inner.(handler.BodyLimit); ok {
   297  		err = eval.SetGetBody(req, limitHandler.BufferOptions(), limitHandler.RequestLimit())
   298  	}
   299  	return err
   300  }
   301  
   302  // getHost configures the host from the incoming request host based on
   303  // the xfh setting and listener port to be prepared for the http multiplexer.
   304  func (s *HTTPServer) getHost(req *http.Request) string {
   305  	host := req.Host
   306  	if s.settings.XForwardedHost {
   307  		if xfh := req.Header.Get("X-Forwarded-Host"); xfh != "" {
   308  			host = xfh
   309  		}
   310  	}
   311  
   312  	host = strings.ToLower(host)
   313  
   314  	if !strings.Contains(host, ":") {
   315  		return s.cleanHostAppendPort(host)
   316  	}
   317  
   318  	h, _, err := net.SplitHostPort(host)
   319  	if err != nil {
   320  		return s.cleanHostAppendPort(host)
   321  	}
   322  
   323  	return s.cleanHostAppendPort(h)
   324  }
   325  
   326  func (s *HTTPServer) cleanHostAppendPort(host string) string {
   327  	return strings.TrimSuffix(host, ".") + ":" + s.port
   328  }
   329  
   330  func (s *HTTPServer) onConnState(_ net.Conn, state http.ConnState) {
   331  	meter := provider.Meter("couper/server")
   332  	counter, _ := meter.SyncInt64().
   333  		Counter(instrumentation.ClientConnectionsTotal, instrument.WithDescription(string(unit.Dimensionless)))
   334  	gauge, _ := meter.SyncFloat64().UpDownCounter(
   335  		instrumentation.ClientConnections,
   336  		instrument.WithDescription(string(unit.Dimensionless)),
   337  	)
   338  
   339  	if state == http.StateNew {
   340  		counter.Add(context.Background(), 1)
   341  		gauge.Add(context.Background(), 1)
   342  		// we have no callback for closing a hijacked one, so count them down too.
   343  		// TODO: if required we COULD override given conn ptr value with own obj.
   344  	} else if state == http.StateClosed || state == http.StateHijacked {
   345  		gauge.Add(context.Background(), -1)
   346  	}
   347  }
   348  
   349  // getChildHandler returns the innermost handler which supports the Child interface.
   350  func getChildHandler(handler http.Handler) http.Handler {
   351  	outer := handler
   352  	for {
   353  		if inner, ok := outer.(interface{ Child() http.Handler }); ok {
   354  			outer = inner.Child()
   355  			continue
   356  		}
   357  		break
   358  	}
   359  	return outer
   360  }
   361  
   362  // ErrorWrapper logs incoming Write bytes with the context filled logrus.FieldLogger.
   363  type ErrorWrapper struct{ l logrus.FieldLogger }
   364  
   365  func (e *ErrorWrapper) Write(p []byte) (n int, err error) {
   366  	msg := string(p)
   367  	if strings.HasSuffix(msg, " tls: unknown certificate") {
   368  		return len(p), nil // triggered on first browser connect for self signed certs; skip
   369  	}
   370  
   371  	e.l.Error(strings.TrimSpace(msg))
   372  	return len(p), nil
   373  }
   374  func newErrorLogWrapper(logger logrus.FieldLogger) *log.Logger {
   375  	return log.New(&ErrorWrapper{logger}, "", log.Lmsgprefix)
   376  }