github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/integrationtests/self/resumption_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"net"
     8  	"time"
     9  
    10  	"github.com/metacubex/quic-go"
    11  
    12  	. "github.com/onsi/ginkgo/v2"
    13  	. "github.com/onsi/gomega"
    14  )
    15  
    16  type clientSessionCache struct {
    17  	cache tls.ClientSessionCache
    18  
    19  	gets chan<- string
    20  	puts chan<- string
    21  }
    22  
    23  func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache {
    24  	return &clientSessionCache{
    25  		cache: cache,
    26  		gets:  gets,
    27  		puts:  puts,
    28  	}
    29  }
    30  
    31  var _ tls.ClientSessionCache = &clientSessionCache{}
    32  
    33  func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
    34  	session, ok := c.cache.Get(sessionKey)
    35  	if c.gets != nil {
    36  		c.gets <- sessionKey
    37  	}
    38  	return session, ok
    39  }
    40  
    41  func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
    42  	c.cache.Put(sessionKey, cs)
    43  	if c.puts != nil {
    44  		c.puts <- sessionKey
    45  	}
    46  }
    47  
    48  var _ = Describe("TLS session resumption", func() {
    49  	It("uses session resumption", func() {
    50  		server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
    51  		Expect(err).ToNot(HaveOccurred())
    52  		defer server.Close()
    53  
    54  		gets := make(chan string, 100)
    55  		puts := make(chan string, 100)
    56  		cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
    57  		tlsConf := getTLSClientConfig()
    58  		tlsConf.ClientSessionCache = cache
    59  		conn1, err := quic.DialAddr(
    60  			context.Background(),
    61  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
    62  			tlsConf,
    63  			getQuicConfig(nil),
    64  		)
    65  		Expect(err).ToNot(HaveOccurred())
    66  		defer conn1.CloseWithError(0, "")
    67  		var sessionKey string
    68  		Eventually(puts).Should(Receive(&sessionKey))
    69  		Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
    70  
    71  		serverConn, err := server.Accept(context.Background())
    72  		Expect(err).ToNot(HaveOccurred())
    73  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
    74  
    75  		conn2, err := quic.DialAddr(
    76  			context.Background(),
    77  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
    78  			tlsConf,
    79  			getQuicConfig(nil),
    80  		)
    81  		Expect(err).ToNot(HaveOccurred())
    82  		Expect(gets).To(Receive(Equal(sessionKey)))
    83  		Expect(conn2.ConnectionState().TLS.DidResume).To(BeTrue())
    84  
    85  		serverConn, err = server.Accept(context.Background())
    86  		Expect(err).ToNot(HaveOccurred())
    87  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
    88  		conn2.CloseWithError(0, "")
    89  	})
    90  
    91  	It("doesn't use session resumption, if the config disables it", func() {
    92  		sConf := getTLSConfig()
    93  		sConf.SessionTicketsDisabled = true
    94  		server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
    95  		Expect(err).ToNot(HaveOccurred())
    96  		defer server.Close()
    97  
    98  		gets := make(chan string, 100)
    99  		puts := make(chan string, 100)
   100  		cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
   101  		tlsConf := getTLSClientConfig()
   102  		tlsConf.ClientSessionCache = cache
   103  		conn1, err := quic.DialAddr(
   104  			context.Background(),
   105  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   106  			tlsConf,
   107  			getQuicConfig(nil),
   108  		)
   109  		Expect(err).ToNot(HaveOccurred())
   110  		defer conn1.CloseWithError(0, "")
   111  		Consistently(puts).ShouldNot(Receive())
   112  		Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
   113  
   114  		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   115  		defer cancel()
   116  		serverConn, err := server.Accept(ctx)
   117  		Expect(err).ToNot(HaveOccurred())
   118  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   119  
   120  		conn2, err := quic.DialAddr(
   121  			context.Background(),
   122  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   123  			tlsConf,
   124  			getQuicConfig(nil),
   125  		)
   126  		Expect(err).ToNot(HaveOccurred())
   127  		Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
   128  		defer conn2.CloseWithError(0, "")
   129  
   130  		serverConn, err = server.Accept(context.Background())
   131  		Expect(err).ToNot(HaveOccurred())
   132  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   133  	})
   134  
   135  	It("doesn't use session resumption, if the config returned by GetConfigForClient disables it", func() {
   136  		sConf := &tls.Config{
   137  			GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   138  				conf := getTLSConfig()
   139  				conf.SessionTicketsDisabled = true
   140  				return conf, nil
   141  			},
   142  		}
   143  
   144  		server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
   145  		Expect(err).ToNot(HaveOccurred())
   146  		defer server.Close()
   147  
   148  		gets := make(chan string, 100)
   149  		puts := make(chan string, 100)
   150  		cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
   151  		tlsConf := getTLSClientConfig()
   152  		tlsConf.ClientSessionCache = cache
   153  		conn1, err := quic.DialAddr(
   154  			context.Background(),
   155  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   156  			tlsConf,
   157  			getQuicConfig(nil),
   158  		)
   159  		Expect(err).ToNot(HaveOccurred())
   160  		Consistently(puts).ShouldNot(Receive())
   161  		Expect(conn1.ConnectionState().TLS.DidResume).To(BeFalse())
   162  		defer conn1.CloseWithError(0, "")
   163  
   164  		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   165  		defer cancel()
   166  		serverConn, err := server.Accept(ctx)
   167  		Expect(err).ToNot(HaveOccurred())
   168  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   169  
   170  		conn2, err := quic.DialAddr(
   171  			context.Background(),
   172  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   173  			tlsConf,
   174  			getQuicConfig(nil),
   175  		)
   176  		Expect(err).ToNot(HaveOccurred())
   177  		Expect(conn2.ConnectionState().TLS.DidResume).To(BeFalse())
   178  		defer conn2.CloseWithError(0, "")
   179  
   180  		serverConn, err = server.Accept(context.Background())
   181  		Expect(err).ToNot(HaveOccurred())
   182  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   183  	})
   184  })