github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/server.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"errors"
    21  	"net"
    22  	"net/http"
    23  	"time"
    24  )
    25  
    26  // Server is a safe wrapper for a standard HTTP server.
    27  // The zero value is safe and ready to use and will apply safe defaults on serving.
    28  // Changing any of the fields after the server has been started is a no-op.
    29  type Server struct {
    30  	// Addr optionally specifies the TCP address for the server to listen on,
    31  	// in the form "host:port". If empty, ":http" (port 80) is used.
    32  	// The service names are defined in RFC 6335 and assigned by IANA.
    33  	// See net.Dial for details of the address format.
    34  	Addr string
    35  
    36  	// Mux is the ServeMux to use for the current server. A nil Mux is invalid.
    37  	Mux *ServeMux
    38  
    39  	// TODO(empijei): potentially consider exposing ReadHeaderTimeout for
    40  	// fine-grained handling (e.g. websocket endpoints).
    41  
    42  	// ReadTimeout is the maximum duration for reading the entire
    43  	// request, including the body.
    44  	ReadTimeout time.Duration
    45  
    46  	// WriteTimeout is the maximum duration before timing out
    47  	// writes of the response. It is reset whenever a new
    48  	// request's header is read.
    49  	WriteTimeout time.Duration
    50  
    51  	// IdleTimeout is the maximum amount of time to wait for the
    52  	// next request when keep-alives are enabled.
    53  	IdleTimeout time.Duration
    54  
    55  	// MaxHeaderBytes controls the maximum number of bytes the
    56  	// server will read parsing the request header's keys and
    57  	// values, including the request line. It does not limit the
    58  	// size of the request body.
    59  	MaxHeaderBytes int
    60  
    61  	// TLSConfig optionally provides a TLS configuration for use
    62  	// by ServeTLS and ListenAndServeTLS. Note that this value is
    63  	// cloned on serving, so it's not possible to modify the
    64  	// configuration with methods like tls.Config.SetSessionTicketKeys.
    65  	//
    66  	// When the server is started the cloned configuration will be changed
    67  	// to set the minimum TLS version to 1.2 and to prefer Server Ciphers.
    68  	TLSConfig *tls.Config
    69  
    70  	// OnShutdown is a slice of functions to call on Shutdown.
    71  	// This can be used to gracefully shutdown connections that have undergone
    72  	// ALPN protocol upgrade or that have been hijacked.
    73  	// These functions should start protocol-specific graceful shutdown, but
    74  	// should not wait for shutdown to complete.
    75  	OnShudown []func()
    76  
    77  	// DisableKeepAlives controls whether HTTP keep-alives should be disabled.
    78  	DisableKeepAlives bool
    79  
    80  	srv     *http.Server
    81  	started bool
    82  }
    83  
    84  func (s *Server) buildStd() error {
    85  	if s.started {
    86  		return errors.New("server already started")
    87  	}
    88  	if s.srv != nil {
    89  		// Server was already built
    90  		return nil
    91  	}
    92  	if s.Mux == nil {
    93  		return errors.New("building server without a mux")
    94  	}
    95  
    96  	srv := &http.Server{
    97  		Addr:           s.Addr,
    98  		Handler:        s.Mux,
    99  		ReadTimeout:    5 * time.Second,
   100  		WriteTimeout:   5 * time.Second,
   101  		IdleTimeout:    120 * time.Second,
   102  		MaxHeaderBytes: 10 * 1024,
   103  	}
   104  	if s.ReadTimeout != 0 {
   105  		srv.ReadTimeout = s.ReadTimeout
   106  	}
   107  	if s.WriteTimeout != 0 {
   108  		srv.WriteTimeout = s.WriteTimeout
   109  	}
   110  	if s.IdleTimeout != 0 {
   111  		srv.IdleTimeout = s.IdleTimeout
   112  	}
   113  	if s.MaxHeaderBytes != 0 {
   114  		srv.MaxHeaderBytes = s.MaxHeaderBytes
   115  	}
   116  	if s.TLSConfig != nil {
   117  		cfg := s.TLSConfig.Clone()
   118  		cfg.MinVersion = tls.VersionTLS12
   119  		cfg.PreferServerCipherSuites = true
   120  		srv.TLSConfig = cfg
   121  	}
   122  	for _, f := range s.OnShudown {
   123  		srv.RegisterOnShutdown(f)
   124  	}
   125  	if s.DisableKeepAlives {
   126  		srv.SetKeepAlivesEnabled(false)
   127  	}
   128  	s.srv = srv
   129  	return nil
   130  }
   131  
   132  // Clone returns an unstarted deep copy of Server that can be re-configured and re-started.
   133  func (s *Server) Clone() *Server {
   134  	cln := *s
   135  	cln.started = false
   136  	cln.TLSConfig = s.TLSConfig.Clone()
   137  	cln.srv = nil
   138  	return &cln
   139  }
   140  
   141  // ListenAndServe is a wrapper for https://golang.org/pkg/net/http/#Server.ListenAndServe
   142  func (s *Server) ListenAndServe() error {
   143  	if err := s.buildStd(); err != nil {
   144  		return err
   145  	}
   146  	s.started = true
   147  	return s.srv.ListenAndServe()
   148  }
   149  
   150  // ListenAndServeTLS is a wrapper for https://golang.org/pkg/net/http/#Server.ListenAndServeTLS
   151  func (s *Server) ListenAndServeTLS(certFile, keyFile string) error {
   152  	if err := s.buildStd(); err != nil {
   153  		return err
   154  	}
   155  	s.started = true
   156  	return s.srv.ListenAndServeTLS(certFile, keyFile)
   157  }
   158  
   159  // Serve is a wrapper for https://golang.org/pkg/net/http/#Server.Serve
   160  func (s *Server) Serve(l net.Listener) error {
   161  	if err := s.buildStd(); err != nil {
   162  		return err
   163  	}
   164  	s.started = true
   165  	return s.srv.Serve(l)
   166  }
   167  
   168  // ServeTLS is a wrapper for https://golang.org/pkg/net/http/#Server.ServeTLS
   169  func (s *Server) ServeTLS(l net.Listener, certFile, keyFile string) error {
   170  	if err := s.buildStd(); err != nil {
   171  		return err
   172  	}
   173  	s.started = true
   174  	return s.srv.ServeTLS(l, certFile, keyFile)
   175  }
   176  
   177  // Shutdown is a wrapper for https://golang.org/pkg/net/http/#Server.Shutdown
   178  func (s *Server) Shutdown(ctx context.Context) error {
   179  	if !s.started {
   180  		return errors.New("shutting down unstarted server")
   181  	}
   182  	s.srv.SetKeepAlivesEnabled(false)
   183  	return s.srv.Shutdown(ctx)
   184  }
   185  
   186  // Close is a wrapper for https://golang.org/pkg/net/http/#Server.Close
   187  func (s *Server) Close() error {
   188  	if !s.started {
   189  		return errors.New("closing unstarted server")
   190  	}
   191  	return s.srv.Close()
   192  }