github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/config_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "reflect" 9 "time" 10 11 "github.com/mikelsr/quic-go/internal/protocol" 12 "github.com/mikelsr/quic-go/logging" 13 "github.com/mikelsr/quic-go/quicvarint" 14 15 . "github.com/onsi/ginkgo/v2" 16 . "github.com/onsi/gomega" 17 ) 18 19 var _ = Describe("Config", func() { 20 Context("validating", func() { 21 It("validates a nil config", func() { 22 Expect(validateConfig(nil)).To(Succeed()) 23 }) 24 25 It("validates a config with normal values", func() { 26 conf := populateServerConfig(&Config{ 27 MaxIncomingStreams: 5, 28 MaxStreamReceiveWindow: 10, 29 }) 30 Expect(validateConfig(conf)).To(Succeed()) 31 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(5)) 32 Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(10)) 33 }) 34 35 It("clips too large values for the stream limits", func() { 36 conf := &Config{ 37 MaxIncomingStreams: 1<<60 + 1, 38 MaxIncomingUniStreams: 1<<60 + 2, 39 } 40 Expect(validateConfig(conf)).To(Succeed()) 41 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(int64(1 << 60))) 42 Expect(conf.MaxIncomingUniStreams).To(BeEquivalentTo(int64(1 << 60))) 43 }) 44 45 It("clips too large values for the flow control windows", func() { 46 conf := &Config{ 47 MaxStreamReceiveWindow: quicvarint.Max + 1, 48 MaxConnectionReceiveWindow: quicvarint.Max + 2, 49 } 50 Expect(validateConfig(conf)).To(Succeed()) 51 Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max))) 52 Expect(conf.MaxConnectionReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max))) 53 }) 54 }) 55 56 configWithNonZeroNonFunctionFields := func() *Config { 57 c := &Config{} 58 v := reflect.ValueOf(c).Elem() 59 60 typ := v.Type() 61 for i := 0; i < typ.NumField(); i++ { 62 f := v.Field(i) 63 if !f.CanSet() { 64 // unexported field; not cloned. 65 continue 66 } 67 68 switch fn := typ.Field(i).Name; fn { 69 case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": 70 // Can't compare functions. 71 case "Versions": 72 f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) 73 case "ConnectionIDLength": 74 f.Set(reflect.ValueOf(8)) 75 case "ConnectionIDGenerator": 76 f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength})) 77 case "HandshakeIdleTimeout": 78 f.Set(reflect.ValueOf(time.Second)) 79 case "MaxIdleTimeout": 80 f.Set(reflect.ValueOf(time.Hour)) 81 case "MaxTokenAge": 82 f.Set(reflect.ValueOf(2 * time.Hour)) 83 case "MaxRetryTokenAge": 84 f.Set(reflect.ValueOf(2 * time.Minute)) 85 case "TokenStore": 86 f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) 87 case "InitialStreamReceiveWindow": 88 f.Set(reflect.ValueOf(uint64(1234))) 89 case "MaxStreamReceiveWindow": 90 f.Set(reflect.ValueOf(uint64(9))) 91 case "InitialConnectionReceiveWindow": 92 f.Set(reflect.ValueOf(uint64(4321))) 93 case "MaxConnectionReceiveWindow": 94 f.Set(reflect.ValueOf(uint64(10))) 95 case "MaxIncomingStreams": 96 f.Set(reflect.ValueOf(int64(11))) 97 case "MaxIncomingUniStreams": 98 f.Set(reflect.ValueOf(int64(12))) 99 case "StatelessResetKey": 100 f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4})) 101 case "KeepAlivePeriod": 102 f.Set(reflect.ValueOf(time.Second)) 103 case "EnableDatagrams": 104 f.Set(reflect.ValueOf(true)) 105 case "DisableVersionNegotiationPackets": 106 f.Set(reflect.ValueOf(true)) 107 case "DisablePathMTUDiscovery": 108 f.Set(reflect.ValueOf(true)) 109 case "Allow0RTT": 110 f.Set(reflect.ValueOf(true)) 111 default: 112 Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) 113 } 114 } 115 return c 116 } 117 118 It("uses 10s handshake timeout for short handshake idle timeouts", func() { 119 c := &Config{HandshakeIdleTimeout: time.Second} 120 Expect(c.handshakeTimeout()).To(Equal(protocol.DefaultHandshakeTimeout)) 121 }) 122 123 It("uses twice the handshake idle timeouts for the handshake timeout, for long handshake idle timeouts", func() { 124 c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} 125 Expect(c.handshakeTimeout()).To(Equal(11 * time.Second)) 126 }) 127 128 Context("cloning", func() { 129 It("clones function fields", func() { 130 var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool 131 c1 := &Config{ 132 GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, 133 AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, 134 RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, 135 Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer { 136 calledTracer = true 137 return nil 138 }, 139 } 140 c2 := c1.Clone() 141 c2.RequireAddressValidation(&net.UDPAddr{}) 142 Expect(calledAddrValidation).To(BeTrue()) 143 c2.AllowConnectionWindowIncrease(nil, 1234) 144 Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) 145 _, err := c2.GetConfigForClient(&ClientHelloInfo{}) 146 Expect(err).To(MatchError("nope")) 147 c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{}) 148 Expect(calledTracer).To(BeTrue()) 149 }) 150 151 It("clones non-function fields", func() { 152 c := configWithNonZeroNonFunctionFields() 153 Expect(c.Clone()).To(Equal(c)) 154 }) 155 156 It("returns a copy", func() { 157 c1 := &Config{ 158 MaxIncomingStreams: 100, 159 RequireAddressValidation: func(net.Addr) bool { return true }, 160 } 161 c2 := c1.Clone() 162 c2.MaxIncomingStreams = 200 163 c2.RequireAddressValidation = func(net.Addr) bool { return false } 164 165 Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) 166 Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue()) 167 }) 168 }) 169 170 Context("populating", func() { 171 It("populates function fields", func() { 172 var calledAddrValidation bool 173 c1 := &Config{} 174 c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } 175 c2 := populateConfig(c1) 176 c2.RequireAddressValidation(&net.UDPAddr{}) 177 Expect(calledAddrValidation).To(BeTrue()) 178 }) 179 180 It("copies non-function fields", func() { 181 c := configWithNonZeroNonFunctionFields() 182 Expect(populateConfig(c)).To(Equal(c)) 183 }) 184 185 It("populates empty fields with default values", func() { 186 c := populateConfig(&Config{}) 187 Expect(c.Versions).To(Equal(protocol.SupportedVersions)) 188 Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 189 Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) 190 Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow)) 191 Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData)) 192 Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) 193 Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams)) 194 Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) 195 Expect(c.DisableVersionNegotiationPackets).To(BeFalse()) 196 Expect(c.DisablePathMTUDiscovery).To(BeFalse()) 197 Expect(c.GetConfigForClient).To(BeNil()) 198 }) 199 200 It("populates empty fields with default values, for the server", func() { 201 c := populateServerConfig(&Config{}) 202 Expect(c.RequireAddressValidation).ToNot(BeNil()) 203 }) 204 }) 205 })