github.com/MerlinKodo/quic-go@v0.39.2/integrationtests/self/hotswap_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  	"net/http"
     8  	"strconv"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/MerlinKodo/quic-go"
    13  	"github.com/MerlinKodo/quic-go/http3"
    14  
    15  	. "github.com/onsi/ginkgo/v2"
    16  	. "github.com/onsi/gomega"
    17  	"github.com/onsi/gomega/gbytes"
    18  )
    19  
    20  type listenerWrapper struct {
    21  	http3.QUICEarlyListener
    22  	listenerClosed bool
    23  	count          int32
    24  }
    25  
    26  func (ln *listenerWrapper) Close() error {
    27  	ln.listenerClosed = true
    28  	return ln.QUICEarlyListener.Close()
    29  }
    30  
    31  func (ln *listenerWrapper) Faker() *fakeClosingListener {
    32  	atomic.AddInt32(&ln.count, 1)
    33  	ctx, cancel := context.WithCancel(context.Background())
    34  	return &fakeClosingListener{ln, 0, ctx, cancel}
    35  }
    36  
// fakeClosingListener is a per-server view of a shared listenerWrapper.
// Closing it only closes the underlying listener when it is the last
// open fake (see fakeClosingListener.Close).
type fakeClosingListener struct {
	// Shared listener; Accept and the reference count live here.
	*listenerWrapper
	// Set to 1 via atomic.CompareAndSwapInt32 by the first Close call.
	closed int32
	// Context passed to the underlying Accept; cancelled by Close.
	ctx context.Context
	// Cancels ctx.
	cancel context.CancelFunc
}
    43  
    44  func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) {
    45  	Expect(ctx).To(Equal(context.Background()))
    46  	return ln.listenerWrapper.Accept(ln.ctx)
    47  }
    48  
    49  func (ln *fakeClosingListener) Close() error {
    50  	if atomic.CompareAndSwapInt32(&ln.closed, 0, 1) {
    51  		ln.cancel()
    52  		if atomic.AddInt32(&ln.listenerWrapper.count, -1) == 0 {
    53  			ln.listenerWrapper.Close()
    54  		}
    55  	}
    56  	return nil
    57  }
    58  
    59  var _ = Describe("HTTP3 Server hotswap test", func() {
    60  	var (
    61  		mux1    *http.ServeMux
    62  		mux2    *http.ServeMux
    63  		client  *http.Client
    64  		rt      *http3.RoundTripper
    65  		server1 *http3.Server
    66  		server2 *http3.Server
    67  		ln      *listenerWrapper
    68  		port    string
    69  	)
    70  
    71  	BeforeEach(func() {
    72  		mux1 = http.NewServeMux()
    73  		mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) {
    74  			defer GinkgoRecover()
    75  			io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset.
    76  		})
    77  
    78  		mux2 = http.NewServeMux()
    79  		mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) {
    80  			defer GinkgoRecover()
    81  			io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset.
    82  		})
    83  
    84  		server1 = &http3.Server{
    85  			Handler:    mux1,
    86  			QuicConfig: getQuicConfig(nil),
    87  		}
    88  		server2 = &http3.Server{
    89  			Handler:    mux2,
    90  			QuicConfig: getQuicConfig(nil),
    91  		}
    92  
    93  		tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
    94  		quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil))
    95  		ln = &listenerWrapper{QUICEarlyListener: quicln}
    96  		Expect(err).NotTo(HaveOccurred())
    97  		port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
    98  	})
    99  
   100  	AfterEach(func() {
   101  		Expect(rt.Close()).NotTo(HaveOccurred())
   102  		Expect(ln.Close()).NotTo(HaveOccurred())
   103  	})
   104  
   105  	BeforeEach(func() {
   106  		rt = &http3.RoundTripper{
   107  			TLSClientConfig:    getTLSClientConfig(),
   108  			DisableCompression: true,
   109  			QuicConfig:         getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
   110  		}
   111  		client = &http.Client{Transport: rt}
   112  	})
   113  
   114  	It("hotswap works", func() {
   115  		// open first server and make single request to it
   116  		fake1 := ln.Faker()
   117  		stoppedServing1 := make(chan struct{})
   118  		go func() {
   119  			defer GinkgoRecover()
   120  			server1.ServeListener(fake1)
   121  			close(stoppedServing1)
   122  		}()
   123  
   124  		resp, err := client.Get("https://localhost:" + port + "/hello1")
   125  		Expect(err).ToNot(HaveOccurred())
   126  		Expect(resp.StatusCode).To(Equal(200))
   127  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
   128  		Expect(err).ToNot(HaveOccurred())
   129  		Expect(string(body)).To(Equal("Hello, World 1!\n"))
   130  
   131  		// open second server with same underlying listener,
   132  		// make sure it opened and both servers are currently running
   133  		fake2 := ln.Faker()
   134  		stoppedServing2 := make(chan struct{})
   135  		go func() {
   136  			defer GinkgoRecover()
   137  			server2.ServeListener(fake2)
   138  			close(stoppedServing2)
   139  		}()
   140  
   141  		Consistently(stoppedServing1).ShouldNot(BeClosed())
   142  		Consistently(stoppedServing2).ShouldNot(BeClosed())
   143  
   144  		// now close first server, no errors should occur here
   145  		// and only the fake listener should be closed
   146  		Expect(server1.Close()).NotTo(HaveOccurred())
   147  		Eventually(stoppedServing1).Should(BeClosed())
   148  		Expect(fake1.closed).To(Equal(int32(1)))
   149  		Expect(fake2.closed).To(Equal(int32(0)))
   150  		Expect(ln.listenerClosed).ToNot(BeTrue())
   151  		Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred())
   152  
   153  		// verify that new connections are being initiated from the second server now
   154  		resp, err = client.Get("https://localhost:" + port + "/hello2")
   155  		Expect(err).ToNot(HaveOccurred())
   156  		Expect(resp.StatusCode).To(Equal(200))
   157  		body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
   158  		Expect(err).ToNot(HaveOccurred())
   159  		Expect(string(body)).To(Equal("Hello, World 2!\n"))
   160  
   161  		// close the other server - both the fake and the actual listeners must close now
   162  		Expect(server2.Close()).NotTo(HaveOccurred())
   163  		Eventually(stoppedServing2).Should(BeClosed())
   164  		Expect(fake2.closed).To(Equal(int32(1)))
   165  		Expect(ln.listenerClosed).To(BeTrue())
   166  	})
   167  })