github.com/richardwilkes/toolbox@v1.121.0/xio/network/xhttp/server.go (about)

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  package xhttp
    11  
    12  import (
    13  	"context"
    14  	"errors"
    15  	"fmt"
    16  	"log/slog"
    17  	"net"
    18  	"net/http"
    19  	"strconv"
    20  	"strings"
    21  	"time"
    22  
    23  	"github.com/richardwilkes/toolbox/atexit"
    24  	"github.com/richardwilkes/toolbox/errs"
    25  	"github.com/richardwilkes/toolbox/xio/network"
    26  )
    27  
// Constants for protocols the server can provide. Server.Protocol returns one of these: ProtocolHTTPS when both
// CertFile and KeyFile are set, otherwise ProtocolHTTP. The value is also used as the URL scheme by LocalBaseURL.
const (
	// ProtocolHTTP identifies a plain (non-TLS) HTTP server.
	ProtocolHTTP  = "http"
	// ProtocolHTTPS identifies a server that serves over TLS.
	ProtocolHTTPS = "https"
)
    33  
// ctxKey is the unexported type used for this package's context keys, so they cannot collide with context keys
// defined by other packages.
type ctxKey int

// metadataKey is the context key under which ServeHTTP stores each request's *Metadata and from which
// MetadataFromRequest retrieves it.
const metadataKey ctxKey = 1
    37  
// Metadata holds auxiliary information for a request. Server.ServeHTTP attaches one to every request's context;
// retrieve it with MetadataFromRequest.
type Metadata struct {
	// Logger is the per-request logger, derived by ServeHTTP from Server.Logger with "method" and "url" attributes.
	// ServeHTTP reads this field again when the handler returns to write the access log entry, so a handler that
	// replaces it (e.g. to add attributes) affects that entry too.
	Logger *slog.Logger
	// User is not populated anywhere in this file. NOTE(review): presumably set by authentication middleware for
	// use by later handlers — confirm against callers.
	User   string
}
    43  
// Server holds the data necessary for the server. Configure the exported fields, then call Run, which also registers
// Shutdown with atexit. Call Shutdown to stop serving.
type Server struct {
	// WebServer is the underlying HTTP server. Run listens on its Addr and replaces its Handler with this Server,
	// which forwards each request to the original Handler.
	WebServer           *http.Server
	// Logger receives server-level log output and is the parent of each request's logger. Run sets it to
	// slog.Default() if it is nil.
	Logger              *slog.Logger
	// clientHandler is WebServer's original Handler, captured by Run and invoked by ServeHTTP.
	clientHandler       http.Handler
	StartedChan         chan struct{} // If not nil, will be closed once the server is ready to accept connections
	// ShutdownCallback, if not nil, is called at the end of Shutdown with the logger used for shutdown messages.
	ShutdownCallback    func(*slog.Logger)
	// CertFile and KeyFile are passed to WebServer.ServeTLS; when both are non-empty the server uses HTTPS.
	CertFile            string
	KeyFile             string
	// addresses are the host addresses the listener is bound to; set by Run.
	addresses           []string
	// ShutdownGracePeriod bounds how long Shutdown waits for WebServer.Shutdown; values <= 0 mean one minute.
	ShutdownGracePeriod time.Duration
	// port is the TCP port the listener is bound to; set by Run.
	port                int
	// shutdownID is the atexit registration made by Run and removed by Shutdown.
	shutdownID          int
}
    58  
    59  // Protocol returns the protocol this server is handling.
    60  func (s *Server) Protocol() string {
    61  	if s.CertFile != "" && s.KeyFile != "" {
    62  		return ProtocolHTTPS
    63  	}
    64  	return ProtocolHTTP
    65  }
    66  
    67  // Addresses returns the host addresses being listened to.
    68  func (s *Server) Addresses() []string {
    69  	return s.addresses
    70  }
    71  
    72  // Port returns the port being listened to.
    73  func (s *Server) Port() int {
    74  	return s.port
    75  }
    76  
    77  // LocalBaseURL returns the local base URL that will reach the server.
    78  func (s *Server) LocalBaseURL() string {
    79  	return fmt.Sprintf("%s://%s:%d", s.Protocol(), network.IPv4LoopbackAddress, s.port)
    80  }
    81  
    82  func (s *Server) String() string {
    83  	var buffer strings.Builder
    84  	buffer.WriteString(s.Protocol())
    85  	buffer.WriteString(" on ")
    86  	for i, addr := range s.addresses {
    87  		if i != 0 {
    88  			buffer.WriteString(", ")
    89  		}
    90  		fmt.Fprintf(&buffer, "%s:%d", addr, s.port)
    91  	}
    92  	return buffer.String()
    93  }
    94  
    95  // Run the server. Does not return until the server is shutdown.
    96  func (s *Server) Run() error {
    97  	s.shutdownID = atexit.Register(s.Shutdown)
    98  	if s.Logger == nil {
    99  		s.Logger = slog.Default()
   100  	}
   101  	s.clientHandler = s.WebServer.Handler
   102  	s.WebServer.Handler = s
   103  	var listener net.Listener
   104  	_, _, err := net.SplitHostPort(s.WebServer.Addr)
   105  	if err == nil {
   106  		listener, err = net.Listen("tcp", s.WebServer.Addr)
   107  	} else {
   108  		listener, err = net.Listen("tcp", net.JoinHostPort(s.WebServer.Addr, "0"))
   109  	}
   110  	if err != nil {
   111  		return errs.Wrap(err)
   112  	}
   113  	var host, portStr string
   114  	if host, portStr, err = net.SplitHostPort(listener.Addr().String()); err != nil {
   115  		return errs.Wrap(err)
   116  	}
   117  	if s.port, err = strconv.Atoi(portStr); err != nil {
   118  		return errs.Wrap(err)
   119  	}
   120  	s.addresses = network.AddressesForHost(host)
   121  	s.Logger.Info("listening", "protocol", s.Protocol(), "addresses", s.addresses, "port", s.port)
   122  	if s.StartedChan != nil {
   123  		go func() { close(s.StartedChan) }()
   124  	}
   125  	if s.Protocol() == ProtocolHTTPS {
   126  		err = s.WebServer.ServeTLS(listener, s.CertFile, s.KeyFile)
   127  	} else {
   128  		err = s.WebServer.Serve(listener)
   129  	}
   130  	if err != nil && !errors.Is(err, http.ErrServerClosed) {
   131  		return errs.Wrap(err)
   132  	}
   133  	return nil
   134  }
   135  
   136  // ServeHTTP implements the http.Handler interface.
   137  func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   138  	started := time.Now()
   139  	sw := &StatusResponseWriter{
   140  		Original: w,
   141  		Head:     r.Method == http.MethodHead,
   142  	}
   143  	md := &Metadata{Logger: s.Logger.With("method", r.Method, "url", r.URL)}
   144  	r = r.WithContext(context.WithValue(r.Context(), metadataKey, md))
   145  	defer func() {
   146  		if recovered := recover(); recovered != nil {
   147  			err, ok := recovered.(error)
   148  			if !ok {
   149  				err = errs.Newf("%+v", recovered)
   150  			}
   151  			errs.LogTo(md.Logger, errs.NewWithCause("recovered from panic in handler", err))
   152  			ErrorStatus(sw, http.StatusInternalServerError)
   153  		}
   154  		since := time.Since(started)
   155  		millis := int64(since / time.Millisecond)
   156  		micros := int64(since/time.Microsecond) - millis*1000
   157  		written := sw.BytesWritten()
   158  		md.Logger.Info("web", "status", sw.Status(), "bytes", written, "elapsed",
   159  			fmt.Sprintf("%d.%03dms", millis, micros))
   160  	}()
   161  	s.clientHandler.ServeHTTP(sw, r)
   162  }
   163  
   164  // Shutdown the server gracefully.
   165  func (s *Server) Shutdown() {
   166  	atexit.Unregister(s.shutdownID)
   167  	startedAt := time.Now()
   168  	logger := s.Logger.With("protocol", s.Protocol(), "addresses", s.addresses, "port", s.port)
   169  	logger.Info("starting shutdown")
   170  	defer func() { logger.Info("finished shutdown", "elapsed", time.Since(startedAt)) }()
   171  	gracePeriod := s.ShutdownGracePeriod
   172  	if gracePeriod <= 0 {
   173  		gracePeriod = time.Minute
   174  	}
   175  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(gracePeriod))
   176  	defer cancel()
   177  	if err := s.WebServer.Shutdown(ctx); err != nil {
   178  		errs.LogTo(logger, errs.NewWithCause("unable to shutdown gracefully", err))
   179  	}
   180  	if s.ShutdownCallback != nil {
   181  		s.ShutdownCallback(logger)
   182  	}
   183  }
   184  
   185  // MetadataFromRequest returns the Metadata from the request.
   186  func MetadataFromRequest(req *http.Request) *Metadata {
   187  	if md, ok := req.Context().Value(metadataKey).(*Metadata); ok {
   188  		return md
   189  	}
   190  	return nil
   191  }