github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/integrationtests/self/hotswap_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"net"
     7  	"net/http"
     8  	"strconv"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/danielpfeifer02/quic-go-prio-packs"
    13  	"github.com/danielpfeifer02/quic-go-prio-packs/http3"
    14  
    15  	. "github.com/onsi/ginkgo/v2"
    16  	. "github.com/onsi/gomega"
    17  	"github.com/onsi/gomega/gbytes"
    18  )
    19  
    20  type listenerWrapper struct {
    21  	http3.QUICEarlyListener
    22  	listenerClosed bool
    23  	count          atomic.Int32
    24  }
    25  
    26  func (ln *listenerWrapper) Close() error {
    27  	ln.listenerClosed = true
    28  	return ln.QUICEarlyListener.Close()
    29  }
    30  
    31  func (ln *listenerWrapper) Faker() *fakeClosingListener {
    32  	ln.count.Add(1)
    33  	ctx, cancel := context.WithCancel(context.Background())
    34  	return &fakeClosingListener{
    35  		listenerWrapper: ln,
    36  		ctx:             ctx,
    37  		cancel:          cancel,
    38  	}
    39  }
    40  
    41  type fakeClosingListener struct {
    42  	*listenerWrapper
    43  	closed atomic.Bool
    44  	ctx    context.Context
    45  	cancel context.CancelFunc
    46  }
    47  
    48  func (ln *fakeClosingListener) Accept(ctx context.Context) (quic.EarlyConnection, error) {
    49  	Expect(ctx).To(Equal(context.Background()))
    50  	return ln.listenerWrapper.Accept(ln.ctx)
    51  }
    52  
    53  func (ln *fakeClosingListener) Close() error {
    54  	if ln.closed.CompareAndSwap(false, true) {
    55  		ln.cancel()
    56  		if ln.listenerWrapper.count.Add(-1) == 0 {
    57  			ln.listenerWrapper.Close()
    58  		}
    59  	}
    60  	return nil
    61  }
    62  
    63  var _ = Describe("HTTP3 Server hotswap test", func() {
    64  	var (
    65  		mux1    *http.ServeMux
    66  		mux2    *http.ServeMux
    67  		client  *http.Client
    68  		rt      *http3.RoundTripper
    69  		server1 *http3.Server
    70  		server2 *http3.Server
    71  		ln      *listenerWrapper
    72  		port    string
    73  	)
    74  
    75  	BeforeEach(func() {
    76  		mux1 = http.NewServeMux()
    77  		mux1.HandleFunc("/hello1", func(w http.ResponseWriter, r *http.Request) {
    78  			defer GinkgoRecover()
    79  			io.WriteString(w, "Hello, World 1!\n") // don't check the error here. Stream may be reset.
    80  		})
    81  
    82  		mux2 = http.NewServeMux()
    83  		mux2.HandleFunc("/hello2", func(w http.ResponseWriter, r *http.Request) {
    84  			defer GinkgoRecover()
    85  			io.WriteString(w, "Hello, World 2!\n") // don't check the error here. Stream may be reset.
    86  		})
    87  
    88  		server1 = &http3.Server{
    89  			Handler:    mux1,
    90  			QuicConfig: getQuicConfig(nil),
    91  		}
    92  		server2 = &http3.Server{
    93  			Handler:    mux2,
    94  			QuicConfig: getQuicConfig(nil),
    95  		}
    96  
    97  		tlsConf := http3.ConfigureTLSConfig(getTLSConfig())
    98  		quicln, err := quic.ListenAddrEarly("0.0.0.0:0", tlsConf, getQuicConfig(nil))
    99  		ln = &listenerWrapper{QUICEarlyListener: quicln}
   100  		Expect(err).NotTo(HaveOccurred())
   101  		port = strconv.Itoa(ln.Addr().(*net.UDPAddr).Port)
   102  	})
   103  
   104  	AfterEach(func() {
   105  		Expect(rt.Close()).NotTo(HaveOccurred())
   106  		Expect(ln.Close()).NotTo(HaveOccurred())
   107  	})
   108  
   109  	BeforeEach(func() {
   110  		rt = &http3.RoundTripper{
   111  			TLSClientConfig:    getTLSClientConfig(),
   112  			DisableCompression: true,
   113  			QuicConfig:         getQuicConfig(&quic.Config{MaxIdleTimeout: 10 * time.Second}),
   114  		}
   115  		client = &http.Client{Transport: rt}
   116  	})
   117  
   118  	It("hotswap works", func() {
   119  		// open first server and make single request to it
   120  		fake1 := ln.Faker()
   121  		stoppedServing1 := make(chan struct{})
   122  		go func() {
   123  			defer GinkgoRecover()
   124  			server1.ServeListener(fake1)
   125  			close(stoppedServing1)
   126  		}()
   127  
   128  		resp, err := client.Get("https://localhost:" + port + "/hello1")
   129  		Expect(err).ToNot(HaveOccurred())
   130  		Expect(resp.StatusCode).To(Equal(200))
   131  		body, err := io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
   132  		Expect(err).ToNot(HaveOccurred())
   133  		Expect(string(body)).To(Equal("Hello, World 1!\n"))
   134  
   135  		// open second server with same underlying listener,
   136  		// make sure it opened and both servers are currently running
   137  		fake2 := ln.Faker()
   138  		stoppedServing2 := make(chan struct{})
   139  		go func() {
   140  			defer GinkgoRecover()
   141  			server2.ServeListener(fake2)
   142  			close(stoppedServing2)
   143  		}()
   144  
   145  		Consistently(stoppedServing1).ShouldNot(BeClosed())
   146  		Consistently(stoppedServing2).ShouldNot(BeClosed())
   147  
   148  		// now close first server, no errors should occur here
   149  		// and only the fake listener should be closed
   150  		Expect(server1.Close()).NotTo(HaveOccurred())
   151  		Eventually(stoppedServing1).Should(BeClosed())
   152  		Expect(fake1.closed.Load()).To(BeTrue())
   153  		Expect(fake2.closed.Load()).To(BeFalse())
   154  		Expect(ln.listenerClosed).ToNot(BeTrue())
   155  		Expect(client.Transport.(*http3.RoundTripper).Close()).NotTo(HaveOccurred())
   156  
   157  		// verify that new connections are being initiated from the second server now
   158  		resp, err = client.Get("https://localhost:" + port + "/hello2")
   159  		Expect(err).ToNot(HaveOccurred())
   160  		Expect(resp.StatusCode).To(Equal(200))
   161  		body, err = io.ReadAll(gbytes.TimeoutReader(resp.Body, 3*time.Second))
   162  		Expect(err).ToNot(HaveOccurred())
   163  		Expect(string(body)).To(Equal("Hello, World 2!\n"))
   164  
   165  		// close the other server - both the fake and the actual listeners must close now
   166  		Expect(server2.Close()).NotTo(HaveOccurred())
   167  		Eventually(stoppedServing2).Should(BeClosed())
   168  		Expect(fake2.closed.Load()).To(BeTrue())
   169  		Expect(ln.listenerClosed).To(BeTrue())
   170  	})
   171  })