github.com/tumi8/quic-go@v0.37.4-tum/integrationtests/self/hotswap_test.go (about) 1 package self_test 2 3 import ( 4 "context" 5 "github.com/tumi8/quic-go" 6 "io" 7 "net" 8 "net/http" 9 "strconv" 10 "sync/atomic" 11 "time" 12 13 "github.com/tumi8/quic-go/http3" 14 15 . "github.com/onsi/ginkgo/v2" 16 . "github.com/onsi/gomega" 17 "github.com/onsi/gomega/gbytes" 18 ) 19 20 type listenerWrapper struct { 21 http3.QUICEarlyListener 22 listenerClosed bool 23 count int32 24 } 25 26 func (ln *listenerWrapper) Close() error { 27 ln.listenerClosed = true 28 return ln.QUICEarlyListener.Close() 29 } 30 31 func (ln *listenerWrapper) Faker() *fakeClosingListener { 32 atomic.AddInt32(&ln.count, 1) 33 ctx, cancel := context.WithCancel(context.Background()) 34 return &fakeClosingListener{ln, 0, ctx, cancel} 35 } 36 37 type fakeClosingListener struct { 38 *listenerWrapper 39 closed int32 40 ctx context.Context 41 cancel context.CancelFunc 42 } 43 44 func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) { 45 Expect(ctx).To(Equal(context.Background())) 46 return ln.listenerWrapper.Accept(ln.ctx) 47 } 48 49 func (ln *fakeClosingListener) Close() error { 50 if atomic.CompareAndSwapInt32(&ln.closed, 0, 1) { 51 ln.cancel() 52 if atomic.AddInt32(&ln.listenerWrapper.count, -1) == 0 { 53 ln.listenerWrapper.Close() 54 } 55 } 56 return nil 57 } 58 59 var _ = Describe("HTTP3 Server hotswap test", func() { 60 var ( 61 mux1 *http.ServeMux 62 mux2 *http.ServeMux 63 client *http.Client 64 rt *http3.RoundTripper 65 server1 *http3.Server 66 server2 *http3.Server 67 ln *listenerWrapper 68 port string 69 ) 70 71 BeforeEach(func() { 72 mux1 = http.NewServeMux() 73 mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) { 74 defer GinkgoRecover() 75 io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset. 76 }) 77 78 mux2 = http.NewServeMux() 79 mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) { 80 defer GinkgoRecover() 81 io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset. 82 }) 83 84 server1 = &http3.Server{ 85 Handler: mux1, 86 QuicConfig: getQuicConfig(nil), 87 } 88 server2 = &http3.Server{ 89 Handler: mux2, 90 QuicConfig: getQuicConfig(nil), 91 } 92 93 tlsConf := http3.ConfigureTLSConfig(getTLSConfig()) 94 quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil)) 95 ln = &listenerWrapper{QUICEarlyListener: quicln} 96 Expect(err).NotTo(HaveOccurred()) 97 port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port) 98 }) 99 100 AfterEach(func() { 101 Expect(rt.Close()).NotTo(HaveOccurred()) 102 Expect(ln.Close()).NotTo(HaveOccurred()) 103 }) 104 105 BeforeEach(func() { 106 rt = &http3.RoundTripper{ 107 TLSClientConfig: getTLSClientConfig(), 108 DisableCompression: true, 109 QuicConfig: getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}), 110 } 111 client = &http.Client{Transport: rt} 112 }) 113 114 It("hotswap works", func() { 115 // open first server and make single request to it 116 fake1 := ln.Faker() 117 stoppedServing1 := make(chan struct{}) 118 go func() { 119 defer GinkgoRecover() 120 server1.ServeListener(fake1) 121 close(stoppedServing1) 122 }() 123 124 resp, err := client.Get("https://localhost:" + port + "/hello1") 125 Expect(err).ToNot(HaveOccurred()) 126 Expect(resp.StatusCode).To(Equal(200)) 127 body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) 128 Expect(err).ToNot(HaveOccurred()) 129 Expect(string(body)).To(Equal("Hello, World 1!\n")) 130 131 // open second server with same underlying listener, 132 // make sure it opened and both servers are currently running 133 fake2 := ln.Faker() 134 stoppedServing2 := make(chan struct{}) 135 go func() { 136 defer GinkgoRecover() 137 server2.ServeListener(fake2) 138 close(stoppedServing2) 139 }() 140 141 Consistently(stoppedServing1).ShouldNot(BeClosed()) 142 Consistently(stoppedServing2).ShouldNot(BeClosed()) 143 144 // now close first server, no errors should occur here 145 // and only the fake listener should be closed 146 Expect(server1.Close()).NotTo(HaveOccurred()) 147 Eventually(stoppedServing1).Should(BeClosed()) 148 Expect(fake1.closed).To(Equal(int32(1))) 149 Expect(fake2.closed).To(Equal(int32(0))) 150 Expect(ln.listenerClosed).ToNot(BeTrue()) 151 Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred()) 152 153 // verify that new connections are being initiated from the second server now 154 resp, err = client.Get("https://localhost:" + port + "/hello2") 155 Expect(err).ToNot(HaveOccurred()) 156 Expect(resp.StatusCode).To(Equal(200)) 157 body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second)) 158 Expect(err).ToNot(HaveOccurred()) 159 Expect(string(body)).To(Equal("Hello, World 2!\n")) 160 161 // close the other server - both the fake and the actual listeners must close now 162 Expect(server2.Close()).NotTo(HaveOccurred()) 163 Eventually(stoppedServing2).Should(BeClosed()) 164 Expect(fake2.closed).To(Equal(int32(1))) 165 Expect(ln.listenerClosed).To(BeTrue()) 166 }) 167 })