github.com/tumi8/quic-go@v0.37.4-tum/config_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "reflect" 9 "time" 10 11 "github.com/tumi8/quic-go/logging" 12 "github.com/tumi8/quic-go/noninternal/protocol" 13 "github.com/tumi8/quic-go/quicvarint" 14 15 . "github.com/onsi/ginkgo/v2" 16 . "github.com/onsi/gomega" 17 ) 18 19 var _ = Describe("Config", func() { 20 Context("validating", func() { 21 It("validates a nil config", func() { 22 Expect(validateConfig(nil)).To(Succeed()) 23 }) 24 25 It("validates a config with normal values", func() { 26 conf := populateServerConfig(&Config{ 27 MaxIncomingStreams: 5, 28 MaxStreamReceiveWindow: 10, 29 }) 30 Expect(validateConfig(conf)).To(Succeed()) 31 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(5)) 32 Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(10)) 33 }) 34 35 It("clips too large values for the stream limits", func() { 36 conf := &Config{ 37 MaxIncomingStreams: 1<<60 + 1, 38 MaxIncomingUniStreams: 1<<60 + 2, 39 } 40 Expect(validateConfig(conf)).To(Succeed()) 41 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(int64(1 << 60))) 42 Expect(conf.MaxIncomingUniStreams).To(BeEquivalentTo(int64(1 << 60))) 43 }) 44 45 It("clips too large values for the flow control windows", func() { 46 conf := &Config{ 47 MaxStreamReceiveWindow: quicvarint.Max + 1, 48 MaxConnectionReceiveWindow: quicvarint.Max + 2, 49 } 50 Expect(validateConfig(conf)).To(Succeed()) 51 Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max))) 52 Expect(conf.MaxConnectionReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max))) 53 }) 54 }) 55 56 configWithNonZeroNonFunctionFields := func() *Config { 57 c := &Config{} 58 v := reflect.ValueOf(c).Elem() 59 60 typ := v.Type() 61 for i := 0; i < typ.NumField(); i++ { 62 f := v.Field(i) 63 if !f.CanSet() { 64 // unexported field; not cloned. 65 continue 66 } 67 68 switch fn := typ.Field(i).Name; fn { 69 case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": 70 // Can't compare functions. 71 case "Versions": 72 f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) 73 case "ConnectionIDLength": 74 f.Set(reflect.ValueOf(8)) 75 case "ConnectionIDGenerator": 76 f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength})) 77 case "HandshakeIdleTimeout": 78 f.Set(reflect.ValueOf(time.Second)) 79 case "MaxIdleTimeout": 80 f.Set(reflect.ValueOf(time.Hour)) 81 case "MaxTokenAge": 82 f.Set(reflect.ValueOf(2 * time.Hour)) 83 case "MaxRetryTokenAge": 84 f.Set(reflect.ValueOf(2 * time.Minute)) 85 case "TokenStore": 86 f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) 87 case "InitialStreamReceiveWindow": 88 f.Set(reflect.ValueOf(uint64(1234))) 89 case "MaxStreamReceiveWindow": 90 f.Set(reflect.ValueOf(uint64(9))) 91 case "InitialConnectionReceiveWindow": 92 f.Set(reflect.ValueOf(uint64(4321))) 93 case "MaxConnectionReceiveWindow": 94 f.Set(reflect.ValueOf(uint64(10))) 95 case "MaxIncomingStreams": 96 f.Set(reflect.ValueOf(int64(11))) 97 case "MaxIncomingUniStreams": 98 f.Set(reflect.ValueOf(int64(12))) 99 case "StatelessResetKey": 100 f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4})) 101 case "KeepAlivePeriod": 102 f.Set(reflect.ValueOf(time.Second)) 103 case "EnableDatagrams": 104 f.Set(reflect.ValueOf(true)) 105 case "DisableVersionNegotiationPackets": 106 f.Set(reflect.ValueOf(true)) 107 case "DisablePathMTUDiscovery": 108 f.Set(reflect.ValueOf(true)) 109 case "SCID": 110 case "DCID": 111 case "Allow0RTT": 112 f.Set(reflect.ValueOf(true)) 113 default: 114 Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) 115 } 116 } 117 return c 118 } 119 120 It("uses 10s handshake timeout for short handshake idle timeouts", func() { 121 c := &Config{HandshakeIdleTimeout: time.Second} 122 Expect(c.handshakeTimeout()).To(Equal(protocol.DefaultHandshakeTimeout)) 123 }) 124 125 It("uses twice the handshake idle timeouts for the handshake timeout, for long handshake idle timeouts", func() { 126 c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} 127 Expect(c.handshakeTimeout()).To(Equal(11 * time.Second)) 128 }) 129 130 Context("cloning", func() { 131 It("clones function fields", func() { 132 var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool 133 c1 := &Config{ 134 GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, 135 AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, 136 RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, 137 Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer { 138 calledTracer = true 139 return nil 140 }, 141 } 142 c2 := c1.Clone() 143 c2.RequireAddressValidation(&net.UDPAddr{}) 144 Expect(calledAddrValidation).To(BeTrue()) 145 c2.AllowConnectionWindowIncrease(nil, 1234) 146 Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) 147 _, err := c2.GetConfigForClient(&ClientHelloInfo{}) 148 Expect(err).To(MatchError("nope")) 149 c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{}) 150 Expect(calledTracer).To(BeTrue()) 151 }) 152 153 It("clones non-function fields", func() { 154 c := configWithNonZeroNonFunctionFields() 155 Expect(c.Clone()).To(Equal(c)) 156 }) 157 158 It("returns a copy", func() { 159 c1 := &Config{ 160 MaxIncomingStreams: 100, 161 RequireAddressValidation: func(net.Addr) bool { return true }, 162 } 163 c2 := c1.Clone() 164 c2.MaxIncomingStreams = 200 165 c2.RequireAddressValidation = func(net.Addr) bool { return false } 166 167 Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) 168 Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue()) 169 }) 170 }) 171 172 Context("populating", func() { 173 It("populates function fields", func() { 174 var calledAddrValidation bool 175 c1 := &Config{} 176 c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } 177 c2 := populateConfig(c1) 178 c2.RequireAddressValidation(&net.UDPAddr{}) 179 Expect(calledAddrValidation).To(BeTrue()) 180 }) 181 182 It("copies non-function fields", func() { 183 c := configWithNonZeroNonFunctionFields() 184 Expect(populateConfig(c)).To(Equal(c)) 185 }) 186 187 It("populates empty fields with default values", func() { 188 c := populateConfig(&Config{}) 189 Expect(c.Versions).To(Equal(protocol.SupportedVersions)) 190 Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 191 Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) 192 Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow)) 193 Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData)) 194 Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) 195 Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams)) 196 Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) 197 Expect(c.DisableVersionNegotiationPackets).To(BeFalse()) 198 Expect(c.DisablePathMTUDiscovery).To(BeFalse()) 199 Expect(c.GetConfigForClient).To(BeNil()) 200 }) 201 202 It("populates empty fields with default values, for the server", func() { 203 c := populateServerConfig(&Config{}) 204 Expect(c.RequireAddressValidation).ToNot(BeNil()) 205 }) 206 }) 207 })