github.com/MerlinKodo/quic-go@v0.39.2/integrationtests/self/resumption_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"net"
     8  	"time"
     9  
    10  	"github.com/MerlinKodo/quic-go"
    11  
    12  	. "github.com/onsi/ginkgo/v2"
    13  	. "github.com/onsi/gomega"
    14  )
    15  
    16  type clientSessionCache struct {
    17  	cache tls.ClientSessionCache
    18  
    19  	gets chan<- string
    20  	puts chan<- string
    21  }
    22  
    23  func newClientSessionCache(cache tls.ClientSessionCache, gets, puts chan<- string) *clientSessionCache {
    24  	return &clientSessionCache{
    25  		cache: cache,
    26  		gets:  gets,
    27  		puts:  puts,
    28  	}
    29  }
    30  
    31  var _ tls.ClientSessionCache = &clientSessionCache{}
    32  
    33  func (c *clientSessionCache) Get(sessionKey string) (*tls.ClientSessionState, bool) {
    34  	session, ok := c.cache.Get(sessionKey)
    35  	c.gets <- sessionKey
    36  	return session, ok
    37  }
    38  
    39  func (c *clientSessionCache) Put(sessionKey string, cs *tls.ClientSessionState) {
    40  	c.cache.Put(sessionKey, cs)
    41  	c.puts <- sessionKey
    42  }
    43  
    44  var _ = Describe("TLS session resumption", func() {
    45  	It("uses session resumption", func() {
    46  		server, err := quic.ListenAddr("localhost:0", getTLSConfig(), getQuicConfig(nil))
    47  		Expect(err).ToNot(HaveOccurred())
    48  		defer server.Close()
    49  
    50  		gets := make(chan string, 100)
    51  		puts := make(chan string, 100)
    52  		cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
    53  		tlsConf := getTLSClientConfig()
    54  		tlsConf.ClientSessionCache = cache
    55  		conn, err := quic.DialAddr(
    56  			context.Background(),
    57  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
    58  			tlsConf,
    59  			getQuicConfig(nil),
    60  		)
    61  		Expect(err).ToNot(HaveOccurred())
    62  		var sessionKey string
    63  		Eventually(puts).Should(Receive(&sessionKey))
    64  		Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
    65  
    66  		serverConn, err := server.Accept(context.Background())
    67  		Expect(err).ToNot(HaveOccurred())
    68  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
    69  
    70  		conn, err = quic.DialAddr(
    71  			context.Background(),
    72  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
    73  			tlsConf,
    74  			getQuicConfig(nil),
    75  		)
    76  		Expect(err).ToNot(HaveOccurred())
    77  		Expect(gets).To(Receive(Equal(sessionKey)))
    78  		Expect(conn.ConnectionState().TLS.DidResume).To(BeTrue())
    79  
    80  		serverConn, err = server.Accept(context.Background())
    81  		Expect(err).ToNot(HaveOccurred())
    82  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeTrue())
    83  	})
    84  
    85  	It("doesn't use session resumption, if the config disables it", func() {
    86  		sConf := getTLSConfig()
    87  		sConf.SessionTicketsDisabled = true
    88  		server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
    89  		Expect(err).ToNot(HaveOccurred())
    90  		defer server.Close()
    91  
    92  		gets := make(chan string, 100)
    93  		puts := make(chan string, 100)
    94  		cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
    95  		tlsConf := getTLSClientConfig()
    96  		tlsConf.ClientSessionCache = cache
    97  		conn, err := quic.DialAddr(
    98  			context.Background(),
    99  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   100  			tlsConf,
   101  			getQuicConfig(nil),
   102  		)
   103  		Expect(err).ToNot(HaveOccurred())
   104  		Consistently(puts).ShouldNot(Receive())
   105  		Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
   106  
   107  		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   108  		defer cancel()
   109  		serverConn, err := server.Accept(ctx)
   110  		Expect(err).ToNot(HaveOccurred())
   111  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   112  
   113  		conn, err = quic.DialAddr(
   114  			context.Background(),
   115  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   116  			tlsConf,
   117  			getQuicConfig(nil),
   118  		)
   119  		Expect(err).ToNot(HaveOccurred())
   120  		Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
   121  
   122  		serverConn, err = server.Accept(context.Background())
   123  		Expect(err).ToNot(HaveOccurred())
   124  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   125  	})
   126  
   127  	It("doesn't use session resumption, if the config returned by GetConfigForClient disables it", func() {
   128  		sConf := &tls.Config{
   129  			GetConfigForClient: func(*tls.ClientHelloInfo) (*tls.Config, error) {
   130  				conf := getTLSConfig()
   131  				conf.SessionTicketsDisabled = true
   132  				return conf, nil
   133  			},
   134  		}
   135  
   136  		server, err := quic.ListenAddr("localhost:0", sConf, getQuicConfig(nil))
   137  		Expect(err).ToNot(HaveOccurred())
   138  		defer server.Close()
   139  
   140  		gets := make(chan string, 100)
   141  		puts := make(chan string, 100)
   142  		cache := newClientSessionCache(tls.NewLRUClientSessionCache(10), gets, puts)
   143  		tlsConf := getTLSClientConfig()
   144  		tlsConf.ClientSessionCache = cache
   145  		conn, err := quic.DialAddr(
   146  			context.Background(),
   147  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   148  			tlsConf,
   149  			getQuicConfig(nil),
   150  		)
   151  		Expect(err).ToNot(HaveOccurred())
   152  		Consistently(puts).ShouldNot(Receive())
   153  		Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
   154  
   155  		ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
   156  		defer cancel()
   157  		serverConn, err := server.Accept(ctx)
   158  		Expect(err).ToNot(HaveOccurred())
   159  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   160  
   161  		conn, err = quic.DialAddr(
   162  			context.Background(),
   163  			fmt.Sprintf("localhost:%d", server.Addr().(*net.UDPAddr).Port),
   164  			tlsConf,
   165  			getQuicConfig(nil),
   166  		)
   167  		Expect(err).ToNot(HaveOccurred())
   168  		Expect(conn.ConnectionState().TLS.DidResume).To(BeFalse())
   169  
   170  		serverConn, err = server.Accept(context.Background())
   171  		Expect(err).ToNot(HaveOccurred())
   172  		Expect(serverConn.ConnectionState().TLS.DidResume).To(BeFalse())
   173  	})
   174  })