github.com/MerlinKodo/quic-go@v0.39.2/config_test.go (about) 1 package quic 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "reflect" 9 "time" 10 11 "github.com/MerlinKodo/quic-go/internal/protocol" 12 "github.com/MerlinKodo/quic-go/logging" 13 "github.com/MerlinKodo/quic-go/quicvarint" 14 15 . "github.com/onsi/ginkgo/v2" 16 . "github.com/onsi/gomega" 17 ) 18 19 var _ = Describe("Config", func() { 20 Context("validating", func() { 21 It("validates a nil config", func() { 22 Expect(validateConfig(nil)).To(Succeed()) 23 }) 24 25 It("validates a config with normal values", func() { 26 conf := populateServerConfig(&Config{ 27 MaxIncomingStreams: 5, 28 MaxStreamReceiveWindow: 10, 29 }) 30 Expect(validateConfig(conf)).To(Succeed()) 31 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(5)) 32 Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(10)) 33 }) 34 35 It("clips too large values for the stream limits", func() { 36 conf := &Config{ 37 MaxIncomingStreams: 1<<60 + 1, 38 MaxIncomingUniStreams: 1<<60 + 2, 39 } 40 Expect(validateConfig(conf)).To(Succeed()) 41 Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(int64(1 << 60))) 42 Expect(conf.MaxIncomingUniStreams).To(BeEquivalentTo(int64(1 << 60))) 43 }) 44 45 It("clips too large values for the flow control windows", func() { 46 conf := &Config{ 47 MaxStreamReceiveWindow: quicvarint.Max + 1, 48 MaxConnectionReceiveWindow: quicvarint.Max + 2, 49 } 50 Expect(validateConfig(conf)).To(Succeed()) 51 Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max))) 52 Expect(conf.MaxConnectionReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max))) 53 }) 54 }) 55 56 configWithNonZeroNonFunctionFields := func() *Config { 57 c := &Config{} 58 v := reflect.ValueOf(c).Elem() 59 60 typ := v.Type() 61 for i := 0; i < typ.NumField(); i++ { 62 f := v.Field(i) 63 if !f.CanSet() { 64 // unexported field; not cloned. 65 continue 66 } 67 68 switch fn := typ.Field(i).Name; fn { 69 case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer": 70 // Can't compare functions. 71 case "Versions": 72 f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3})) 73 case "ConnectionIDLength": 74 f.Set(reflect.ValueOf(8)) 75 case "ConnectionIDGenerator": 76 f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength})) 77 case "HandshakeIdleTimeout": 78 f.Set(reflect.ValueOf(time.Second)) 79 case "MaxIdleTimeout": 80 f.Set(reflect.ValueOf(time.Hour)) 81 case "TokenStore": 82 f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3))) 83 case "InitialStreamReceiveWindow": 84 f.Set(reflect.ValueOf(uint64(1234))) 85 case "MaxStreamReceiveWindow": 86 f.Set(reflect.ValueOf(uint64(9))) 87 case "InitialConnectionReceiveWindow": 88 f.Set(reflect.ValueOf(uint64(4321))) 89 case "MaxConnectionReceiveWindow": 90 f.Set(reflect.ValueOf(uint64(10))) 91 case "MaxIncomingStreams": 92 f.Set(reflect.ValueOf(int64(11))) 93 case "MaxIncomingUniStreams": 94 f.Set(reflect.ValueOf(int64(12))) 95 case "StatelessResetKey": 96 f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4})) 97 case "KeepAlivePeriod": 98 f.Set(reflect.ValueOf(time.Second)) 99 case "EnableDatagrams": 100 f.Set(reflect.ValueOf(true)) 101 case "DisableVersionNegotiationPackets": 102 f.Set(reflect.ValueOf(true)) 103 case "DisablePathMTUDiscovery": 104 f.Set(reflect.ValueOf(true)) 105 case "Allow0RTT": 106 f.Set(reflect.ValueOf(true)) 107 default: 108 Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn)) 109 } 110 } 111 return c 112 } 113 114 It("uses twice the handshake idle timeouts for the handshake timeout", func() { 115 c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2} 116 Expect(c.handshakeTimeout()).To(Equal(11 * time.Second)) 117 }) 118 119 Context("cloning", func() { 120 It("clones function fields", func() { 121 var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool 122 c1 := &Config{ 123 GetConfigForClient: func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") }, 124 AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true }, 125 RequireAddressValidation: func(net.Addr) bool { calledAddrValidation = true; return true }, 126 Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer { 127 calledTracer = true 128 return nil 129 }, 130 } 131 c2 := c1.Clone() 132 c2.RequireAddressValidation(&net.UDPAddr{}) 133 Expect(calledAddrValidation).To(BeTrue()) 134 c2.AllowConnectionWindowIncrease(nil, 1234) 135 Expect(calledAllowConnectionWindowIncrease).To(BeTrue()) 136 _, err := c2.GetConfigForClient(&ClientHelloInfo{}) 137 Expect(err).To(MatchError("nope")) 138 c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{}) 139 Expect(calledTracer).To(BeTrue()) 140 }) 141 142 It("clones non-function fields", func() { 143 c := configWithNonZeroNonFunctionFields() 144 Expect(c.Clone()).To(Equal(c)) 145 }) 146 147 It("returns a copy", func() { 148 c1 := &Config{ 149 MaxIncomingStreams: 100, 150 RequireAddressValidation: func(net.Addr) bool { return true }, 151 } 152 c2 := c1.Clone() 153 c2.MaxIncomingStreams = 200 154 c2.RequireAddressValidation = func(net.Addr) bool { return false } 155 156 Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100)) 157 Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue()) 158 }) 159 }) 160 161 Context("populating", func() { 162 It("populates function fields", func() { 163 var calledAddrValidation bool 164 c1 := &Config{} 165 c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true } 166 c2 := populateConfig(c1) 167 c2.RequireAddressValidation(&net.UDPAddr{}) 168 Expect(calledAddrValidation).To(BeTrue()) 169 }) 170 171 It("copies non-function fields", func() { 172 c := configWithNonZeroNonFunctionFields() 173 Expect(populateConfig(c)).To(Equal(c)) 174 }) 175 176 It("populates empty fields with default values", func() { 177 c := populateConfig(&Config{}) 178 Expect(c.Versions).To(Equal(protocol.SupportedVersions)) 179 Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout)) 180 Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData)) 181 Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow)) 182 Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData)) 183 Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow)) 184 Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams)) 185 Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams)) 186 Expect(c.DisablePathMTUDiscovery).To(BeFalse()) 187 Expect(c.GetConfigForClient).To(BeNil()) 188 }) 189 190 It("populates empty fields with default values, for the server", func() { 191 c := populateServerConfig(&Config{}) 192 Expect(c.RequireAddressValidation).ToNot(BeNil()) 193 }) 194 }) 195 })