github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/conn_id_generator_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/apernet/quic-go/internal/protocol"
     7  	"github.com/apernet/quic-go/internal/qerr"
     8  	"github.com/apernet/quic-go/internal/wire"
     9  
    10  	. "github.com/onsi/ginkgo/v2"
    11  	. "github.com/onsi/gomega"
    12  )
    13  
    14  var _ = Describe("Connection ID Generator", func() {
    15  	var (
    16  		addedConnIDs       []protocol.ConnectionID
    17  		retiredConnIDs     []protocol.ConnectionID
    18  		removedConnIDs     []protocol.ConnectionID
    19  		replacedWithClosed []protocol.ConnectionID
    20  		queuedFrames       []wire.Frame
    21  		g                  *connIDGenerator
    22  	)
    23  	initialConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
    24  	initialClientDestConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc, 0xd, 0xe})
    25  
    26  	connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken {
    27  		b := c.Bytes()[0]
    28  		return protocol.StatelessResetToken{b, b, b, b, b, b, b, b, b, b, b, b, b, b, b, b}
    29  	}
    30  
    31  	BeforeEach(func() {
    32  		addedConnIDs = nil
    33  		retiredConnIDs = nil
    34  		removedConnIDs = nil
    35  		queuedFrames = nil
    36  		replacedWithClosed = nil
    37  		g = newConnIDGenerator(
    38  			initialConnID,
    39  			&initialClientDestConnID,
    40  			func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) },
    41  			connIDToToken,
    42  			func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
    43  			func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
    44  			func(cs []protocol.ConnectionID, _ []byte) { replacedWithClosed = append(replacedWithClosed, cs...) },
    45  			func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
    46  			&protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()},
    47  		)
    48  	})
    49  
    50  	It("issues new connection IDs", func() {
    51  		Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
    52  		Expect(retiredConnIDs).To(BeEmpty())
    53  		Expect(addedConnIDs).To(HaveLen(3))
    54  		for i := 0; i < len(addedConnIDs)-1; i++ {
    55  			Expect(addedConnIDs[i]).ToNot(Equal(addedConnIDs[i+1]))
    56  		}
    57  		Expect(queuedFrames).To(HaveLen(3))
    58  		for i := 0; i < 3; i++ {
    59  			f := queuedFrames[i]
    60  			Expect(f).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{}))
    61  			nf := f.(*wire.NewConnectionIDFrame)
    62  			Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1))
    63  			Expect(nf.ConnectionID.Len()).To(Equal(7))
    64  			Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID)))
    65  		}
    66  	})
    67  
    68  	It("limits the number of connection IDs that it issues", func() {
    69  		Expect(g.SetMaxActiveConnIDs(9999999)).To(Succeed())
    70  		Expect(retiredConnIDs).To(BeEmpty())
    71  		Expect(addedConnIDs).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1))
    72  		Expect(queuedFrames).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1))
    73  	})
    74  
    75  	// SetMaxActiveConnIDs is called twice when dialing a 0-RTT connection:
    76  	// once for the restored from the old connections, once when we receive the transport parameters
    77  	Context("dealing with 0-RTT", func() {
    78  		It("doesn't issue new connection IDs when SetMaxActiveConnIDs is called with the same value", func() {
    79  			Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
    80  			Expect(queuedFrames).To(HaveLen(3))
    81  			queuedFrames = nil
    82  			Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
    83  			Expect(queuedFrames).To(BeEmpty())
    84  		})
    85  
    86  		It("issues more connection IDs if the server allows a higher limit on the resumed connection", func() {
    87  			Expect(g.SetMaxActiveConnIDs(3)).To(Succeed())
    88  			Expect(queuedFrames).To(HaveLen(2))
    89  			queuedFrames = nil
    90  			Expect(g.SetMaxActiveConnIDs(6)).To(Succeed())
    91  			Expect(queuedFrames).To(HaveLen(3))
    92  		})
    93  
    94  		It("issues more connection IDs if the server allows a higher limit on the resumed connection, when connection IDs were retired in between", func() {
    95  			Expect(g.SetMaxActiveConnIDs(3)).To(Succeed())
    96  			Expect(queuedFrames).To(HaveLen(2))
    97  			queuedFrames = nil
    98  			g.Retire(1, protocol.ConnectionID{})
    99  			Expect(queuedFrames).To(HaveLen(1))
   100  			queuedFrames = nil
   101  			Expect(g.SetMaxActiveConnIDs(6)).To(Succeed())
   102  			Expect(queuedFrames).To(HaveLen(3))
   103  		})
   104  	})
   105  
   106  	It("errors if the peers tries to retire a connection ID that wasn't yet issued", func() {
   107  		Expect(g.Retire(1, protocol.ConnectionID{})).To(MatchError(&qerr.TransportError{
   108  			ErrorCode:    qerr.ProtocolViolation,
   109  			ErrorMessage: "retired connection ID 1 (highest issued: 0)",
   110  		}))
   111  	})
   112  
   113  	It("errors if the peers tries to retire a connection ID in a packet with that connection ID", func() {
   114  		Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
   115  		Expect(queuedFrames).ToNot(BeEmpty())
   116  		Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{}))
   117  		f := queuedFrames[0].(*wire.NewConnectionIDFrame)
   118  		Expect(g.Retire(f.SequenceNumber, f.ConnectionID)).To(MatchError(&qerr.TransportError{
   119  			ErrorCode:    qerr.ProtocolViolation,
   120  			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", f.SequenceNumber, f.ConnectionID),
   121  		}))
   122  	})
   123  
   124  	It("issues new connection IDs, when old ones are retired", func() {
   125  		Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
   126  		queuedFrames = nil
   127  		Expect(retiredConnIDs).To(BeEmpty())
   128  		Expect(g.Retire(3, protocol.ConnectionID{})).To(Succeed())
   129  		Expect(queuedFrames).To(HaveLen(1))
   130  		Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{}))
   131  		nf := queuedFrames[0].(*wire.NewConnectionIDFrame)
   132  		Expect(nf.SequenceNumber).To(BeEquivalentTo(5))
   133  		Expect(nf.ConnectionID.Len()).To(Equal(7))
   134  	})
   135  
   136  	It("retires the initial connection ID", func() {
   137  		Expect(g.Retire(0, protocol.ConnectionID{})).To(Succeed())
   138  		Expect(removedConnIDs).To(BeEmpty())
   139  		Expect(retiredConnIDs).To(HaveLen(1))
   140  		Expect(retiredConnIDs[0]).To(Equal(initialConnID))
   141  		Expect(addedConnIDs).To(BeEmpty())
   142  	})
   143  
   144  	It("handles duplicate retirements", func() {
   145  		Expect(g.SetMaxActiveConnIDs(11)).To(Succeed())
   146  		queuedFrames = nil
   147  		Expect(retiredConnIDs).To(BeEmpty())
   148  		Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed())
   149  		Expect(retiredConnIDs).To(HaveLen(1))
   150  		Expect(queuedFrames).To(HaveLen(1))
   151  		Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed())
   152  		Expect(retiredConnIDs).To(HaveLen(1))
   153  		Expect(queuedFrames).To(HaveLen(1))
   154  	})
   155  
   156  	It("retires the client's initial destination connection ID when the handshake completes", func() {
   157  		g.SetHandshakeComplete()
   158  		Expect(retiredConnIDs).To(HaveLen(1))
   159  		Expect(retiredConnIDs[0]).To(Equal(initialClientDestConnID))
   160  	})
   161  
   162  	It("removes all connection IDs", func() {
   163  		Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
   164  		Expect(queuedFrames).To(HaveLen(4))
   165  		g.RemoveAll()
   166  		Expect(removedConnIDs).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
   167  		Expect(removedConnIDs).To(ContainElement(initialConnID))
   168  		Expect(removedConnIDs).To(ContainElement(initialClientDestConnID))
   169  		for _, f := range queuedFrames {
   170  			nf := f.(*wire.NewConnectionIDFrame)
   171  			Expect(removedConnIDs).To(ContainElement(nf.ConnectionID))
   172  		}
   173  	})
   174  
   175  	It("replaces with a closed connection for all connection IDs", func() {
   176  		Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
   177  		Expect(queuedFrames).To(HaveLen(4))
   178  		g.ReplaceWithClosed([]byte("foobar"))
   179  		Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
   180  		Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID))
   181  		Expect(replacedWithClosed).To(ContainElement(initialConnID))
   182  		for _, f := range queuedFrames {
   183  			nf := f.(*wire.NewConnectionIDFrame)
   184  			Expect(replacedWithClosed).To(ContainElement(nf.ConnectionID))
   185  		}
   186  	})
   187  })