github.com/olivere/camlistore@v0.0.0-20140121221811-1b7ac2da0199/pkg/webserver/webserver.go (about)

     1  /*
     2  Copyright 2011 Google Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8       http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package webserver implements a superset wrapper of http.Server.
    18  //
    19  // Among other things, it can throttle its connections, inherit its
    20  // listening socket from a file descriptor in the environment, and
    21  // log all activity.
    22  package webserver
    23  
    24  import (
    25  	"bufio"
    26  	"crypto/rand"
    27  	"crypto/tls"
    28  	"fmt"
    29  	"io"
    30  	"log"
    31  	"net"
    32  	"net/http"
    33  	"os"
    34  	"strconv"
    35  	"strings"
    36  	"sync"
    37  	"time"
    38  
    39  	"camlistore.org/pkg/throttle"
    40  	"camlistore.org/third_party/github.com/bradfitz/runsit/listen"
    41  )
    42  
    43  type Server struct {
    44  	mux      *http.ServeMux
    45  	listener net.Listener
    46  	verbose  bool // log HTTP requests and response codes
    47  
    48  	enableTLS               bool
    49  	tlsCertFile, tlsKeyFile string
    50  
    51  	mu   sync.Mutex
    52  	reqs int64
    53  }
    54  
    55  func New() *Server {
    56  	verbose, _ := strconv.ParseBool(os.Getenv("CAMLI_HTTP_DEBUG"))
    57  	return &Server{
    58  		mux:     http.NewServeMux(),
    59  		verbose: verbose,
    60  	}
    61  }
    62  
    63  func (s *Server) SetTLS(certFile, keyFile string) {
    64  	s.enableTLS = true
    65  	s.tlsCertFile = certFile
    66  	s.tlsKeyFile = keyFile
    67  }
    68  
    69  func (s *Server) ListenURL() string {
    70  	scheme := "http"
    71  	if s.enableTLS {
    72  		scheme = "https"
    73  	}
    74  	if s.listener != nil {
    75  		if taddr, ok := s.listener.Addr().(*net.TCPAddr); ok {
    76  			if taddr.IP.IsUnspecified() {
    77  				return fmt.Sprintf("%s://localhost:%d", scheme, taddr.Port)
    78  			}
    79  			return fmt.Sprintf("%s://%s", scheme, s.listener.Addr())
    80  		}
    81  	}
    82  	return ""
    83  }
    84  
// HandleFunc registers fn as the handler for pattern on the server's
// request multiplexer.
func (s *Server) HandleFunc(pattern string, fn func(http.ResponseWriter, *http.Request)) {
	s.mux.HandleFunc(pattern, fn)
}
    88  
// Handle registers handler for pattern on the server's request multiplexer.
func (s *Server) Handle(pattern string, handler http.Handler) {
	s.mux.Handle(pattern, handler)
}
    92  
    93  func (s *Server) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
    94  	var n int64
    95  	if s.verbose {
    96  		s.mu.Lock()
    97  		s.reqs++
    98  		n = s.reqs
    99  		s.mu.Unlock()
   100  		log.Printf("Request #%d: %s %s ...", n, req.Method, req.RequestURI)
   101  		rw = &trackResponseWriter{ResponseWriter: rw}
   102  	}
   103  	s.mux.ServeHTTP(rw, req)
   104  	if s.verbose {
   105  		tw := rw.(*trackResponseWriter)
   106  		log.Printf("Request #%d: %s %s = code %d, %d bytes", n, req.Method, req.RequestURI, tw.code, tw.resSize)
   107  	}
   108  }
   109  
   110  type trackResponseWriter struct {
   111  	http.ResponseWriter
   112  	code    int
   113  	resSize int64
   114  }
   115  
   116  func (tw *trackResponseWriter) WriteHeader(code int) {
   117  	tw.code = code
   118  	tw.ResponseWriter.WriteHeader(code)
   119  }
   120  
   121  func (tw *trackResponseWriter) Write(p []byte) (int, error) {
   122  	if tw.code == 0 {
   123  		tw.code = 200
   124  	}
   125  	tw.resSize += int64(len(p))
   126  	return tw.ResponseWriter.Write(p)
   127  }
   128  
   129  // Listen starts listening on the given host:port addr.
   130  func (s *Server) Listen(addr string) error {
   131  	if s.listener != nil {
   132  		return nil
   133  	}
   134  
   135  	doLog := os.Getenv("TESTING_PORT_WRITE_FD") == "" // Don't make noise during unit tests
   136  	if addr == "" {
   137  		return fmt.Errorf("<host>:<port> needs to be provided to start listening")
   138  	}
   139  
   140  	var err error
   141  	s.listener, err = listen.Listen(addr)
   142  	if err != nil {
   143  		return fmt.Errorf("Failed to listen on %s: %v", addr, err)
   144  	}
   145  	base := s.ListenURL()
   146  	if doLog {
   147  		log.Printf("Starting to listen on %s\n", base)
   148  	}
   149  
   150  	if s.enableTLS {
   151  		config := &tls.Config{
   152  			Rand:       rand.Reader,
   153  			Time:       time.Now,
   154  			NextProtos: []string{"http/1.1"},
   155  		}
   156  		config.Certificates = make([]tls.Certificate, 1)
   157  		config.Certificates[0], err = tls.LoadX509KeyPair(s.tlsCertFile, s.tlsKeyFile)
   158  		if err != nil {
   159  			return fmt.Errorf("Failed to load TLS cert: %v", err)
   160  		}
   161  		s.listener = tls.NewListener(s.listener, config)
   162  	}
   163  
   164  	if doLog && strings.HasSuffix(base, ":0") {
   165  		log.Printf("Now listening on %s\n", s.ListenURL())
   166  	}
   167  
   168  	return nil
   169  }
   170  
   171  func (s *Server) throttleListener() net.Listener {
   172  	kBps, _ := strconv.Atoi(os.Getenv("DEV_THROTTLE_KBPS"))
   173  	ms, _ := strconv.Atoi(os.Getenv("DEV_THROTTLE_LATENCY_MS"))
   174  	if kBps == 0 && ms == 0 {
   175  		return s.listener
   176  	}
   177  	rate := throttle.Rate{
   178  		KBps:    kBps,
   179  		Latency: time.Duration(ms) * time.Millisecond,
   180  	}
   181  	return &throttle.Listener{
   182  		Listener: s.listener,
   183  		Down:     rate,
   184  		Up:       rate, // TODO: separate rates?
   185  	}
   186  }
   187  
   188  func (s *Server) Serve() {
   189  	if err := s.Listen(""); err != nil {
   190  		log.Fatalf("Listen error: %v", err)
   191  	}
   192  	go runTestHarnessIntegration(s.listener)
   193  	err := http.Serve(s.throttleListener(), s)
   194  	if err != nil {
   195  		log.Printf("Error in http server: %v\n", err)
   196  		os.Exit(1)
   197  	}
   198  }
   199  
   200  // Signals the test harness that we've started listening.
   201  // TODO: write back the port number that we randomly selected?
   202  // For now just writes back a single byte.
   203  func runTestHarnessIntegration(listener net.Listener) {
   204  	writePipe, err := pipeFromEnvFd("TESTING_PORT_WRITE_FD")
   205  	if err != nil {
   206  		return
   207  	}
   208  	readPipe, _ := pipeFromEnvFd("TESTING_CONTROL_READ_FD")
   209  
   210  	if writePipe != nil {
   211  		writePipe.Write([]byte(listener.Addr().String() + "\n"))
   212  	}
   213  
   214  	if readPipe != nil {
   215  		bufr := bufio.NewReader(readPipe)
   216  		for {
   217  			line, err := bufr.ReadString('\n')
   218  			if err == io.EOF || line == "EXIT\n" {
   219  				os.Exit(0)
   220  			}
   221  			return
   222  		}
   223  	}
   224  }