github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/apiserverhttp/mux_test.go (about) 1 // Copyright 2018 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package apiserverhttp_test 5 6 import ( 7 "net/http" 8 "net/http/httptest" 9 "sync" 10 "time" 11 12 "github.com/juju/testing" 13 jc "github.com/juju/testing/checkers" 14 gc "gopkg.in/check.v1" 15 16 "github.com/juju/juju/apiserver/apiserverhttp" 17 coretesting "github.com/juju/juju/testing" 18 ) 19 20 type MuxSuite struct { 21 testing.IsolationSuite 22 mux *apiserverhttp.Mux 23 server *httptest.Server 24 client *http.Client 25 } 26 27 var _ = gc.Suite(&MuxSuite{}) 28 29 func (s *MuxSuite) SetUpTest(c *gc.C) { 30 s.IsolationSuite.SetUpTest(c) 31 s.mux = apiserverhttp.NewMux() 32 s.server = httptest.NewServer(s.mux) 33 s.client = s.server.Client() 34 s.AddCleanup(func(c *gc.C) { 35 s.server.Close() 36 }) 37 } 38 39 func (s *MuxSuite) TestNotFound(c *gc.C) { 40 resp, err := s.client.Get(s.server.URL + "/") 41 c.Assert(err, jc.ErrorIsNil) 42 defer resp.Body.Close() 43 44 c.Assert(resp.StatusCode, gc.Equals, http.StatusNotFound) 45 } 46 47 func (s *MuxSuite) TestAddHandler(c *gc.C) { 48 err := s.mux.AddHandler("GET", "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 49 c.Assert(err, jc.ErrorIsNil) 50 51 resp, err := s.client.Get(s.server.URL + "/") 52 c.Assert(err, jc.ErrorIsNil) 53 defer resp.Body.Close() 54 55 c.Assert(resp.StatusCode, gc.Equals, http.StatusOK) 56 } 57 58 func (s *MuxSuite) TestAddRemoveNotFound(c *gc.C) { 59 s.mux.AddHandler("GET", "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 60 s.mux.RemoveHandler("GET", "/") 61 62 resp, err := s.client.Get(s.server.URL + "/") 63 c.Assert(err, jc.ErrorIsNil) 64 defer resp.Body.Close() 65 66 c.Assert(resp.StatusCode, gc.Equals, http.StatusNotFound) 67 } 68 69 func (s *MuxSuite) TestAddHandlerExists(c *gc.C) { 70 s.mux.AddHandler("GET", "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 71 err := s.mux.AddHandler("GET", "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 72 c.Assert(err, gc.ErrorMatches, `handler for GET "/" already exists`) 73 } 74 75 func (s *MuxSuite) TestRemoveHandlerMissing(c *gc.C) { 76 s.mux.RemoveHandler("GET", "/") // no-op 77 } 78 79 func (s *MuxSuite) TestMethodNotSupported(c *gc.C) { 80 s.mux.AddHandler("POST", "/", http.NotFoundHandler()) 81 resp, err := s.client.Get(s.server.URL + "/") 82 c.Assert(err, jc.ErrorIsNil) 83 defer resp.Body.Close() 84 85 c.Assert(resp.StatusCode, gc.Equals, http.StatusMethodNotAllowed) 86 } 87 88 func (s *MuxSuite) TestConcurrentAddHandler(c *gc.C) { 89 err := s.mux.AddHandler("GET", "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) 90 c.Assert(err, jc.ErrorIsNil) 91 92 // Concurrently add and remove another handler to show that 93 // adding and removing handlers will not race with request 94 // handling. 95 const N = 1000 96 var wg sync.WaitGroup 97 wg.Add(1) 98 go func() { 99 defer wg.Done() 100 for i := 0; i < N; i++ { 101 s.mux.AddHandler("POST", "/", http.NotFoundHandler()) 102 s.mux.RemoveHandler("POST", "/") 103 } 104 }() 105 defer wg.Wait() 106 107 for i := 0; i < N; i++ { 108 resp, err := s.client.Get(s.server.URL + "/") 109 c.Assert(err, jc.ErrorIsNil) 110 resp.Body.Close() 111 c.Assert(resp.StatusCode, gc.Equals, http.StatusOK) 112 } 113 } 114 115 func (s *MuxSuite) TestConcurrentRemoveHandler(c *gc.C) { 116 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) 117 118 // Concurrently add and remove another handler to show that 119 // adding and removing handlers will not race with request 120 // handling. 121 const N = 500 122 var wg sync.WaitGroup 123 wg.Add(1) 124 done := make(chan struct{}) 125 go func() { 126 defer wg.Done() 127 defer close(done) 128 for i := 0; i < N; i++ { 129 s.mux.AddHandler("GET", "/", h) 130 // Sleep to give the client a 131 // chance to hit the endpoint. 132 time.Sleep(time.Millisecond) 133 s.mux.RemoveHandler("GET", "/") 134 time.Sleep(time.Millisecond) 135 } 136 }() 137 defer wg.Wait() 138 139 var ok, notfound int 140 out: 141 for { 142 select { 143 case _, _ = <-done: 144 break out 145 default: 146 } 147 resp, err := s.client.Get(s.server.URL + "/") 148 c.Assert(err, jc.ErrorIsNil) 149 resp.Body.Close() 150 switch resp.StatusCode { 151 case http.StatusOK: 152 ok++ 153 case http.StatusNotFound: 154 notfound++ 155 default: 156 c.Fatalf( 157 "got status %d, expected %d or %d", 158 resp.StatusCode, 159 http.StatusOK, 160 http.StatusNotFound, 161 ) 162 } 163 } 164 c.Assert(ok, gc.Not(gc.Equals), 0) 165 c.Assert(notfound, gc.Not(gc.Equals), 0) 166 } 167 168 func (s *MuxSuite) TestWait(c *gc.C) { 169 // Check that mux.Wait() blocks until clients are all finished 170 // with it. 171 s.mux.AddClient() 172 s.mux.AddClient() 173 finished := make(chan struct{}) 174 go func() { 175 defer close(finished) 176 s.mux.Wait() 177 }() 178 179 select { 180 case <-finished: 181 c.Fatalf("should wait when there are clients") 182 case <-time.After(coretesting.ShortWait): 183 } 184 185 s.mux.ClientDone() 186 select { 187 case <-finished: 188 c.Fatalf("should wait when there is still a client") 189 case <-time.After(coretesting.ShortWait): 190 } 191 192 s.mux.ClientDone() 193 select { 194 case <-finished: 195 case <-time.After(coretesting.LongWait): 196 c.Fatalf("should finish once clients are done") 197 } 198 }