github.com/quickfeed/quickfeed@v0.0.0-20240507093252-ed8ca812a09c/web/server.go (about)

     1  package web
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"log"
     9  	"net/http"
    10  	"time"
    11  
    12  	"github.com/quickfeed/quickfeed/internal/cert"
    13  	"github.com/quickfeed/quickfeed/internal/env"
    14  	"github.com/quickfeed/quickfeed/internal/multierr"
    15  	"github.com/quickfeed/quickfeed/metrics"
    16  	"golang.org/x/crypto/acme/autocert"
    17  )
    18  
    19  // hardcoded metrics server address
    20  const metricsServerAddr = "127.0.0.1:9097"
    21  
    22  type Server struct {
    23  	httpServer     *http.Server
    24  	redirectServer *http.Server
    25  	metricsServer  *http.Server
    26  	keyFile        string
    27  	certFile       string
    28  }
    29  
    30  type ServerType func(addr string, handler http.Handler) (*Server, error)
    31  
    32  func NewProductionServer(addr string, handler http.Handler) (*Server, error) {
    33  	whitelist, err := env.Whitelist()
    34  	if err != nil {
    35  		return nil, fmt.Errorf("failed to get whitelist: %w", err)
    36  	}
    37  	certManager := autocert.Manager{
    38  		Prompt: autocert.AcceptTOS,
    39  		Cache:  autocert.DirCache(env.CertPath()),
    40  		HostPolicy: autocert.HostWhitelist(
    41  			whitelist...,
    42  		),
    43  	}
    44  
    45  	httpServer := &http.Server{
    46  		Handler:           handler,
    47  		Addr:              addr,
    48  		ReadHeaderTimeout: 3 * time.Second, // to prevent Slowloris (CWE-400)
    49  		WriteTimeout:      2 * time.Minute,
    50  		ReadTimeout:       2 * time.Minute,
    51  		TLSConfig:         certManager.TLSConfig(),
    52  	}
    53  
    54  	redirectServer := &http.Server{
    55  		Handler:           certManager.HTTPHandler(nil),
    56  		Addr:              ":http",
    57  		ReadHeaderTimeout: 3 * time.Second, // to prevent Slowloris (CWE-400)
    58  	}
    59  
    60  	return &Server{
    61  		httpServer:     httpServer,
    62  		redirectServer: redirectServer,
    63  		metricsServer:  metricsServer(),
    64  	}, nil
    65  }
    66  
    67  func NewDevelopmentServer(addr string, handler http.Handler) (*Server, error) {
    68  	certificate, err := tls.LoadX509KeyPair(env.CertFile(), env.KeyFile())
    69  	if err != nil {
    70  		// Couldn't load credentials; generate self-signed certificates.
    71  		log.Println("Generating self-signed certificates.")
    72  		if err := cert.GenerateSelfSignedCert(cert.Options{
    73  			KeyFile:  env.KeyFile(),
    74  			CertFile: env.CertFile(),
    75  			Hosts:    env.Domain(),
    76  		}); err != nil {
    77  			return nil, fmt.Errorf("failed to generate self-signed certificates: %w", err)
    78  		}
    79  		log.Printf("Certificates successfully generated at: %s", env.CertPath())
    80  		log.Print("Adding certificate to local keychain (requires sudo access)")
    81  		if err := cert.AddTrustedCert(env.CertFile()); err != nil {
    82  			return nil, fmt.Errorf("failed to install self-signed certificate: %w", err)
    83  		}
    84  	} else {
    85  		log.Println("Existing credentials successfully loaded.")
    86  	}
    87  
    88  	httpServer := &http.Server{
    89  		Handler:           handler,
    90  		Addr:              addr,
    91  		ReadHeaderTimeout: 3 * time.Second, // to prevent Slowloris (CWE-400)
    92  		WriteTimeout:      2 * time.Minute,
    93  		ReadTimeout:       2 * time.Minute,
    94  		TLSConfig: &tls.Config{
    95  			Certificates: []tls.Certificate{certificate},
    96  			MinVersion:   tls.VersionTLS13,
    97  			MaxVersion:   tls.VersionTLS13,
    98  		},
    99  	}
   100  
   101  	return &Server{
   102  		httpServer:    httpServer,
   103  		metricsServer: metricsServer(),
   104  		keyFile:       env.KeyFile(),
   105  		certFile:      env.CertFile(),
   106  	}, nil
   107  }
   108  
   109  func metricsServer() *http.Server {
   110  	return &http.Server{
   111  		Handler:           metrics.Handler(),
   112  		Addr:              metricsServerAddr,
   113  		ReadHeaderTimeout: 3 * time.Second, // to prevent Slowloris (CWE-400)
   114  	}
   115  }
   116  
   117  // Serve starts the underlying http server and redirect server, if any.
   118  // This is a blocking call and must be called last.
   119  func (srv *Server) Serve() error {
   120  	if srv.redirectServer != nil {
   121  		// Redirect all HTTP traffic to HTTPS.
   122  		go func() {
   123  			if err := srv.redirectServer.ListenAndServe(); err != nil {
   124  				if !errors.Is(err, http.ErrServerClosed) {
   125  					log.Printf("Redirect server exited with unexpected error: %v", err)
   126  				}
   127  			}
   128  		}()
   129  	}
   130  	if srv.metricsServer != nil {
   131  		// Start HTTP server for Prometheus metrics collection.
   132  		go func() {
   133  			if err := srv.metricsServer.ListenAndServe(); err != nil {
   134  				if !errors.Is(err, http.ErrServerClosed) {
   135  					log.Printf("Metrics server exited with unexpected error: %v", err)
   136  				}
   137  			}
   138  		}()
   139  	}
   140  	// Start the HTTPS server.
   141  	// For production, the certFile and keyFile are empty and managed by autocert.
   142  	if err := srv.httpServer.ListenAndServeTLS(srv.certFile, srv.keyFile); err != nil {
   143  		if !errors.Is(err, http.ErrServerClosed) {
   144  			return fmt.Errorf("server exited with unexpected error: %w", err)
   145  		}
   146  	}
   147  	// Exit with nil means graceful shutdown
   148  	return nil
   149  }
   150  
   151  // Shutdown gracefully shuts down the server.
   152  func (srv *Server) Shutdown(ctx context.Context) error {
   153  	var redirectShutdownErr, metricsShutdownErr error
   154  	if srv.redirectServer != nil {
   155  		redirectShutdownErr = srv.redirectServer.Shutdown(ctx)
   156  	}
   157  	if srv.metricsServer != nil {
   158  		metricsShutdownErr = srv.metricsServer.Shutdown(ctx)
   159  	}
   160  	srvShutdownErr := srv.httpServer.Shutdown(ctx)
   161  	return multierr.Join(redirectShutdownErr, metricsShutdownErr, srvShutdownErr)
   162  }