github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/integrationtests/self/resumption_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"net"
     8  	"time"
     9  
    10  	"github.com/daeuniverse/quic-go"
    11  
    12  	. "github.com/onsi/ginkgo/v2"
    13  	. "github.com/onsi/gomega"
    14  )
    15  
    16  type clientSessionCache struct {
    17  	cache tls.ClientSessionCache
    18  
    19  	gets chan<- string
    20  	puts chan<- string
    21  }
    22  
    23  func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache {
    24  	return &clientSessionCache{
    25  		cache: cache,
    26  		gets:  gets,
    27  		puts:  puts,
    28  	}
    29  }
    30  
    31  var _ tls.ClientSessionCache = &clientSessionCache{}
    32  
    33  func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
    34  	session, ok := c.cache.Get(sessionKey)
    35  	c.gets <- sessionKey
    36  	return session, ok
    37  }
    38  
    39  func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
    40  	c.cache.Put(sessionKey, cs)
    41  	c.puts <- sessionKey
    42  }
    43  
    44  var _ = Describe("TLS session resumption", func() {
    45  	It("uses session resumption", func() {
    46  		server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
    47  		Expect(err).ToNot(HaveOccurred())
    48  		defer server.Close()
    49  
    50  		gets := make(chan string, 100)
    51  		puts := make(chan string, 100)
    52  		cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
    53  		tlsConf := getTLSClientConfig()
    54  		tlsConf.ClientSessionCache = cache
    55  		conn1, err := quic.DialAddr(
    56  			context.Background(),
    57  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
    58  			tlsConf,
    59  			getQuicConfig(nil),
    60  		)
    61  		Expect(err).ToNot(HaveOccurred())
    62  		defer conn1.CloseWithError(0, "")
    63  		var sessionKey string
    64  		Eventually(puts).Should(Receive(&sessionKey))
    65  		Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
    66  
    67  		serverConn, err := server.Accept(context.Background())
    68  		Expect(err).ToNot(HaveOccurred())
    69  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
    70  
    71  		conn2, err := quic.DialAddr(
    72  			context.Background(),
    73  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
    74  			tlsConf,
    75  			getQuicConfig(nil),
    76  		)
    77  		Expect(err).ToNot(HaveOccurred())
    78  		Expect(gets).To(Receive(Equal(sessionKey)))
    79  		Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())
    80  
    81  		serverConn, err = server.Accept(context.Background())
    82  		Expect(err).ToNot(HaveOccurred())
    83  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
    84  		conn2.CloseWithError(0, "")
    85  	})
    86  
    87  	It("doesn't use session resumption, if the config disables it", func() {
    88  		sConf := getTLSConfig()
    89  		sConf.SessionTicketsDisabled = true
    90  		server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
    91  		Expect(err).ToNot(HaveOccurred())
    92  		defer server.Close()
    93  
    94  		gets := make(chan string, 100)
    95  		puts := make(chan string, 100)
    96  		cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
    97  		tlsConf := getTLSClientConfig()
    98  		tlsConf.ClientSessionCache = cache
    99  		conn1, err := quic.DialAddr(
   100  			context.Background(),
   101  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   102  			tlsConf,
   103  			getQuicConfig(nil),
   104  		)
   105  		Expect(err).ToNot(HaveOccurred())
   106  		defer conn1.CloseWithError(0, "")
   107  		Consistently(puts).ShouldNot(Receive())
   108  		Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
   109  
   110  		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   111  		defer cancel()
   112  		serverConn, err := server.Accept(ctx)
   113  		Expect(err).ToNot(HaveOccurred())
   114  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   115  
   116  		conn2, err := quic.DialAddr(
   117  			context.Background(),
   118  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   119  			tlsConf,
   120  			getQuicConfig(nil),
   121  		)
   122  		Expect(err).ToNot(HaveOccurred())
   123  		Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
   124  		defer conn2.CloseWithError(0, "")
   125  
   126  		serverConn, err = server.Accept(context.Background())
   127  		Expect(err).ToNot(HaveOccurred())
   128  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   129  	})
   130  
   131  	It("doesn't use session resumption, if the config returned by GetConfigForClient disables it", func() {
   132  		sConf := &tls.Config{
   133  			GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   134  				conf := getTLSConfig()
   135  				conf.SessionTicketsDisabled = true
   136  				return conf, nil
   137  			},
   138  		}
   139  
   140  		server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
   141  		Expect(err).ToNot(HaveOccurred())
   142  		defer server.Close()
   143  
   144  		gets := make(chan string, 100)
   145  		puts := make(chan string, 100)
   146  		cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
   147  		tlsConf := getTLSClientConfig()
   148  		tlsConf.ClientSessionCache = cache
   149  		conn1, err := quic.DialAddr(
   150  			context.Background(),
   151  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   152  			tlsConf,
   153  			getQuicConfig(nil),
   154  		)
   155  		Expect(err).ToNot(HaveOccurred())
   156  		Consistently(puts).ShouldNot(Receive())
   157  		Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
   158  		defer conn1.CloseWithError(0, "")
   159  
   160  		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   161  		defer cancel()
   162  		serverConn, err := server.Accept(ctx)
   163  		Expect(err).ToNot(HaveOccurred())
   164  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   165  
   166  		conn2, err := quic.DialAddr(
   167  			context.Background(),
   168  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   169  			tlsConf,
   170  			getQuicConfig(nil),
   171  		)
   172  		Expect(err).ToNot(HaveOccurred())
   173  		Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
   174  		defer conn2.CloseWithError(0, "")
   175  
   176  		serverConn, err = server.Accept(context.Background())
   177  		Expect(err).ToNot(HaveOccurred())
   178  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   179  	})
   180  })