github.com/MerlinKodo/quic-go@v0.39.2/conn_id_generator_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/MerlinKodo/quic-go/internal/protocol"
     7  	"github.com/MerlinKodo/quic-go/internal/qerr"
     8  	"github.com/MerlinKodo/quic-go/internal/wire"
     9  
    10  	. "github.com/onsi/ginkgo/v2"
    11  	. "github.com/onsi/gomega"
    12  )
    13  
    14  var _ = Describe("Connection ID Generator", func() {
    15  	var (
    16  		addedConnIDs       []protocol.ConnectionID
    17  		retiredConnIDs     []protocol.ConnectionID
    18  		removedConnIDs     []protocol.ConnectionID
    19  		replacedWithClosed []protocol.ConnectionID
    20  		queuedFrames       []wire.Frame
    21  		g                  *connIDGenerator
    22  	)
    23  	initialConnID := protocol.ParseConnectionID([]byte{1, 2, 3, 4, 5, 6, 7})
    24  	initialClientDestConnID := protocol.ParseConnectionID([]byte{0xa, 0xb, 0xc, 0xd, 0xe})
    25  
    26  	connIDToToken := func(c protocol.ConnectionID) protocol.StatelessResetToken {
    27  		b := c.Bytes()[0]
    28  		return protocol.StatelessResetToken{b, b, b, b, b, b, b, b, b, b, b, b, b, b, b, b}
    29  	}
    30  
    31  	BeforeEach(func() {
    32  		addedConnIDs = nil
    33  		retiredConnIDs = nil
    34  		removedConnIDs = nil
    35  		queuedFrames = nil
    36  		replacedWithClosed = nil
    37  		g = newConnIDGenerator(
    38  			initialConnID,
    39  			&initialClientDestConnID,
    40  			func(c protocol.ConnectionID) { addedConnIDs = append(addedConnIDs, c) },
    41  			connIDToToken,
    42  			func(c protocol.ConnectionID) { removedConnIDs = append(removedConnIDs, c) },
    43  			func(c protocol.ConnectionID) { retiredConnIDs = append(retiredConnIDs, c) },
    44  			func(cs []protocol.ConnectionID, _ protocol.Perspective, _ []byte) {
    45  				replacedWithClosed = append(replacedWithClosed, cs...)
    46  			},
    47  			func(f wire.Frame) { queuedFrames = append(queuedFrames, f) },
    48  			&protocol.DefaultConnectionIDGenerator{ConnLen: initialConnID.Len()},
    49  		)
    50  	})
    51  
    52  	It("issues new connection IDs", func() {
    53  		Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
    54  		Expect(retiredConnIDs).To(BeEmpty())
    55  		Expect(addedConnIDs).To(HaveLen(3))
    56  		for i := 0; i < len(addedConnIDs)-1; i++ {
    57  			Expect(addedConnIDs[i]).ToNot(Equal(addedConnIDs[i+1]))
    58  		}
    59  		Expect(queuedFrames).To(HaveLen(3))
    60  		for i := 0; i < 3; i++ {
    61  			f := queuedFrames[i]
    62  			Expect(f).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{}))
    63  			nf := f.(*wire.NewConnectionIDFrame)
    64  			Expect(nf.SequenceNumber).To(BeEquivalentTo(i + 1))
    65  			Expect(nf.ConnectionID.Len()).To(Equal(7))
    66  			Expect(nf.StatelessResetToken).To(Equal(connIDToToken(nf.ConnectionID)))
    67  		}
    68  	})
    69  
    70  	It("limits the number of connection IDs that it issues", func() {
    71  		Expect(g.SetMaxActiveConnIDs(9999999)).To(Succeed())
    72  		Expect(retiredConnIDs).To(BeEmpty())
    73  		Expect(addedConnIDs).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1))
    74  		Expect(queuedFrames).To(HaveLen(protocol.MaxIssuedConnectionIDs - 1))
    75  	})
    76  
    77  	// SetMaxActiveConnIDs is called twice when dialing a 0-RTT connection:
    78  	// once for the restored from the old connections, once when we receive the transport parameters
    79  	Context("dealing with 0-RTT", func() {
    80  		It("doesn't issue new connection IDs when SetMaxActiveConnIDs is called with the same value", func() {
    81  			Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
    82  			Expect(queuedFrames).To(HaveLen(3))
    83  			queuedFrames = nil
    84  			Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
    85  			Expect(queuedFrames).To(BeEmpty())
    86  		})
    87  
    88  		It("issues more connection IDs if the server allows a higher limit on the resumed connection", func() {
    89  			Expect(g.SetMaxActiveConnIDs(3)).To(Succeed())
    90  			Expect(queuedFrames).To(HaveLen(2))
    91  			queuedFrames = nil
    92  			Expect(g.SetMaxActiveConnIDs(6)).To(Succeed())
    93  			Expect(queuedFrames).To(HaveLen(3))
    94  		})
    95  
    96  		It("issues more connection IDs if the server allows a higher limit on the resumed connection, when connection IDs were retired in between", func() {
    97  			Expect(g.SetMaxActiveConnIDs(3)).To(Succeed())
    98  			Expect(queuedFrames).To(HaveLen(2))
    99  			queuedFrames = nil
   100  			g.Retire(1, protocol.ConnectionID{})
   101  			Expect(queuedFrames).To(HaveLen(1))
   102  			queuedFrames = nil
   103  			Expect(g.SetMaxActiveConnIDs(6)).To(Succeed())
   104  			Expect(queuedFrames).To(HaveLen(3))
   105  		})
   106  	})
   107  
   108  	It("errors if the peers tries to retire a connection ID that wasn't yet issued", func() {
   109  		Expect(g.Retire(1, protocol.ConnectionID{})).To(MatchError(&qerr.TransportError{
   110  			ErrorCode:    qerr.ProtocolViolation,
   111  			ErrorMessage: "retired connection ID 1 (highest issued: 0)",
   112  		}))
   113  	})
   114  
   115  	It("errors if the peers tries to retire a connection ID in a packet with that connection ID", func() {
   116  		Expect(g.SetMaxActiveConnIDs(4)).To(Succeed())
   117  		Expect(queuedFrames).ToNot(BeEmpty())
   118  		Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{}))
   119  		f := queuedFrames[0].(*wire.NewConnectionIDFrame)
   120  		Expect(g.Retire(f.SequenceNumber, f.ConnectionID)).To(MatchError(&qerr.TransportError{
   121  			ErrorCode:    qerr.ProtocolViolation,
   122  			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", f.SequenceNumber, f.ConnectionID),
   123  		}))
   124  	})
   125  
   126  	It("issues new connection IDs, when old ones are retired", func() {
   127  		Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
   128  		queuedFrames = nil
   129  		Expect(retiredConnIDs).To(BeEmpty())
   130  		Expect(g.Retire(3, protocol.ConnectionID{})).To(Succeed())
   131  		Expect(queuedFrames).To(HaveLen(1))
   132  		Expect(queuedFrames[0]).To(BeAssignableToTypeOf(&wire.NewConnectionIDFrame{}))
   133  		nf := queuedFrames[0].(*wire.NewConnectionIDFrame)
   134  		Expect(nf.SequenceNumber).To(BeEquivalentTo(5))
   135  		Expect(nf.ConnectionID.Len()).To(Equal(7))
   136  	})
   137  
   138  	It("retires the initial connection ID", func() {
   139  		Expect(g.Retire(0, protocol.ConnectionID{})).To(Succeed())
   140  		Expect(removedConnIDs).To(BeEmpty())
   141  		Expect(retiredConnIDs).To(HaveLen(1))
   142  		Expect(retiredConnIDs[0]).To(Equal(initialConnID))
   143  		Expect(addedConnIDs).To(BeEmpty())
   144  	})
   145  
   146  	It("handles duplicate retirements", func() {
   147  		Expect(g.SetMaxActiveConnIDs(11)).To(Succeed())
   148  		queuedFrames = nil
   149  		Expect(retiredConnIDs).To(BeEmpty())
   150  		Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed())
   151  		Expect(retiredConnIDs).To(HaveLen(1))
   152  		Expect(queuedFrames).To(HaveLen(1))
   153  		Expect(g.Retire(5, protocol.ConnectionID{})).To(Succeed())
   154  		Expect(retiredConnIDs).To(HaveLen(1))
   155  		Expect(queuedFrames).To(HaveLen(1))
   156  	})
   157  
   158  	It("retires the client's initial destination connection ID when the handshake completes", func() {
   159  		g.SetHandshakeComplete()
   160  		Expect(retiredConnIDs).To(HaveLen(1))
   161  		Expect(retiredConnIDs[0]).To(Equal(initialClientDestConnID))
   162  	})
   163  
   164  	It("removes all connection IDs", func() {
   165  		Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
   166  		Expect(queuedFrames).To(HaveLen(4))
   167  		g.RemoveAll()
   168  		Expect(removedConnIDs).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
   169  		Expect(removedConnIDs).To(ContainElement(initialConnID))
   170  		Expect(removedConnIDs).To(ContainElement(initialClientDestConnID))
   171  		for _, f := range queuedFrames {
   172  			nf := f.(*wire.NewConnectionIDFrame)
   173  			Expect(removedConnIDs).To(ContainElement(nf.ConnectionID))
   174  		}
   175  	})
   176  
   177  	It("replaces with a closed connection for all connection IDs", func() {
   178  		Expect(g.SetMaxActiveConnIDs(5)).To(Succeed())
   179  		Expect(queuedFrames).To(HaveLen(4))
   180  		g.ReplaceWithClosed(protocol.PerspectiveClient, []byte("foobar"))
   181  		Expect(replacedWithClosed).To(HaveLen(6)) // initial conn ID, initial client dest conn id, and newly issued ones
   182  		Expect(replacedWithClosed).To(ContainElement(initialClientDestConnID))
   183  		Expect(replacedWithClosed).To(ContainElement(initialConnID))
   184  		for _, f := range queuedFrames {
   185  			nf := f.(*wire.NewConnectionIDFrame)
   186  			Expect(replacedWithClosed).To(ContainElement(nf.ConnectionID))
   187  		}
   188  	})
   189  })