github.com/mattyw/juju@v0.0.0-20140610034352-732aecd63861/testing/http.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package testing
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"io/ioutil"
    10  	"net"
    11  	"net/http"
    12  	"os"
    13  	"time"
    14  
    15  	gc "launchpad.net/gocheck"
    16  )
    17  
// HTTPSuite is a test-suite fixture built around the package-level Server:
// its SetUpSuite starts the server and its TearDownTest flushes it.
type HTTPSuite struct{}

// Server is the HTTP test server shared by every HTTPSuite. It is created
// with a 5 second timeout and does not listen until Start is called.
var Server = NewHTTPServer(5 * time.Second)
    21  
// SetUpSuite starts the shared Server. Start returns immediately when the
// server has already been started, so several suites can share it.
func (s *HTTPSuite) SetUpSuite(c *gc.C) {
	Server.Start()
}
// TearDownSuite does nothing; the shared Server is not stopped.
func (s *HTTPSuite) TearDownSuite(c *gc.C) {}
    26  
// SetUpTest does nothing; no per-test preparation is needed.
func (s *HTTPSuite) SetUpTest(c *gc.C) {}
    28  
// TearDownTest discards any requests and responses still queued on the
// shared Server so that they cannot leak into the next test.
func (s *HTTPSuite) TearDownTest(c *gc.C) {
	Server.Flush()
}
    32  
// URL returns path appended to the base URL of the shared Server. The two
// strings are concatenated as-is, so path should begin with a slash.
func (s *HTTPSuite) URL(path string) string {
	return Server.URL + path
}
    36  
// HTTPServer is a scriptable HTTP server for tests. Every request it
// receives is queued for later inspection through WaitRequest or
// WaitRequests, and is answered using the next response builder queued
// through Response, Responses, ResponseMap or ResponseFunc.
type HTTPServer struct {
	URL      string             // base URL of the listener; set by Start
	Timeout  time.Duration      // longest wait for a queued response or an incoming request
	started  bool               // true once Start has run
	request  chan *http.Request // requests received by ServeHTTP, oldest first
	response chan ResponseFunc  // builders for upcoming responses, oldest first
}
    44  
    45  func NewHTTPServer(timeout time.Duration) *HTTPServer {
    46  	return &HTTPServer{Timeout: timeout}
    47  }
    48  
// Response describes a canned reply for the test server to send.
type Response struct {
	Status  int               // HTTP status code; when 0, ServeHTTP does not call WriteHeader
	Headers map[string]string // header values applied with Header.Set; may be nil
	Body    []byte            // bytes written as the response body
}
    54  
// ResponseFunc builds the Response for a request; ServeHTTP calls it with
// the URL path of the request being answered.
type ResponseFunc func(path string) Response
    56  
    57  func (s *HTTPServer) Start() {
    58  	if s.started {
    59  		return
    60  	}
    61  	s.started = true
    62  	s.request = make(chan *http.Request, 64)
    63  	s.response = make(chan ResponseFunc, 64)
    64  
    65  	l, err := net.Listen("tcp", "localhost:0")
    66  	if err != nil {
    67  		panic(err)
    68  	}
    69  	port := l.Addr().(*net.TCPAddr).Port
    70  	s.URL = fmt.Sprintf("http://localhost:%d", port)
    71  	go http.Serve(l, s)
    72  
    73  	s.Response(203, nil, nil)
    74  	for {
    75  		// Wait for it to be up.
    76  		resp, err := http.Get(s.URL)
    77  		if err == nil && resp.StatusCode == 203 {
    78  			break
    79  		}
    80  		time.Sleep(1e8)
    81  	}
    82  	s.WaitRequest() // Consume dummy request.
    83  }
    84  
    85  // Flush discards all pending requests and responses.
    86  func (s *HTTPServer) Flush() {
    87  	for {
    88  		select {
    89  		case <-s.request:
    90  		case <-s.response:
    91  		default:
    92  			return
    93  		}
    94  	}
    95  }
    96  
    97  func body(req *http.Request) string {
    98  	data, err := ioutil.ReadAll(req.Body)
    99  	if err != nil {
   100  		panic(err)
   101  	}
   102  	return string(data)
   103  }
   104  
   105  func (s *HTTPServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   106  	req.ParseMultipartForm(1e6)
   107  	data, err := ioutil.ReadAll(req.Body)
   108  	if err != nil {
   109  		panic(err)
   110  	}
   111  	req.Body = ioutil.NopCloser(bytes.NewBuffer(data))
   112  	s.request <- req
   113  	var resp Response
   114  	select {
   115  	case respFunc := <-s.response:
   116  		resp = respFunc(req.URL.Path)
   117  	case <-time.After(s.Timeout):
   118  		const msg = "ERROR: Timeout waiting for test to prepare a response\n"
   119  		fmt.Fprintf(os.Stderr, msg)
   120  		resp = Response{500, nil, []byte(msg)}
   121  	}
   122  	if resp.Headers != nil {
   123  		h := w.Header()
   124  		for k, v := range resp.Headers {
   125  			h.Set(k, v)
   126  		}
   127  	}
   128  	if resp.Status != 0 {
   129  		w.WriteHeader(resp.Status)
   130  	}
   131  	w.Write(resp.Body)
   132  }
   133  
   134  // WaitRequests returns the next n requests made to the http server from
   135  // the queue. If not enough requests were previously made, it waits until
   136  // the timeout value for them to be made.
   137  func (s *HTTPServer) WaitRequests(n int) []*http.Request {
   138  	reqs := make([]*http.Request, 0, n)
   139  	for i := 0; i < n; i++ {
   140  		select {
   141  		case req := <-s.request:
   142  			reqs = append(reqs, req)
   143  		case <-time.After(s.Timeout):
   144  			panic("Timeout waiting for request")
   145  		}
   146  	}
   147  	return reqs
   148  }
   149  
   150  // WaitRequest returns the next request made to the http server from
   151  // the queue. If no requests were previously made, it waits until the
   152  // timeout value for one to be made.
   153  func (s *HTTPServer) WaitRequest() *http.Request {
   154  	return s.WaitRequests(1)[0]
   155  }
   156  
   157  // ResponseFunc prepares the test server to respond the following n
   158  // requests using f to build each response.
   159  func (s *HTTPServer) ResponseFunc(n int, f ResponseFunc) {
   160  	for i := 0; i < n; i++ {
   161  		s.response <- f
   162  	}
   163  }
   164  
   165  // ResponseMap maps request paths to responses.
   166  type ResponseMap map[string]Response
   167  
   168  // ResponseMap prepares the test server to respond the following n
   169  // requests using the m to obtain the responses.
   170  func (s *HTTPServer) ResponseMap(n int, m ResponseMap) {
   171  	f := func(path string) Response {
   172  		for rpath, resp := range m {
   173  			if rpath == path {
   174  				return resp
   175  			}
   176  		}
   177  		body := []byte("Path not found in response map: " + path)
   178  		return Response{Status: 500, Body: body}
   179  	}
   180  	s.ResponseFunc(n, f)
   181  }
   182  
   183  // Responses prepares the test server to respond the following n requests
   184  // using the provided response parameters.
   185  func (s *HTTPServer) Responses(n int, status int, headers map[string]string, body []byte) {
   186  	f := func(path string) Response {
   187  		return Response{status, headers, body}
   188  	}
   189  	s.ResponseFunc(n, f)
   190  }
   191  
   192  // Response prepares the test server to respond the following request
   193  // using the provided response parameters.
   194  func (s *HTTPServer) Response(status int, headers map[string]string, body []byte) {
   195  	s.Responses(1, status, headers, body)
   196  }