github.com/tumi8/quic-go@v0.37.4-tum/config_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"reflect"
     9  	"time"
    10  
    11  	"github.com/tumi8/quic-go/logging"
    12  	"github.com/tumi8/quic-go/noninternal/protocol"
    13  	"github.com/tumi8/quic-go/quicvarint"
    14  
    15  	. "github.com/onsi/ginkgo/v2"
    16  	. "github.com/onsi/gomega"
    17  )
    18  
    19  var _ = Describe("Config", func() {
    20  	Context("validating", func() {
    21  		It("validates a nil config", func() {
    22  			Expect(validateConfig(nil)).To(Succeed())
    23  		})
    24  
    25  		It("validates a config with normal values", func() {
    26  			conf := populateServerConfig(&Config{
    27  				MaxIncomingStreams:     5,
    28  				MaxStreamReceiveWindow: 10,
    29  			})
    30  			Expect(validateConfig(conf)).To(Succeed())
    31  			Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(5))
    32  			Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(10))
    33  		})
    34  
    35  		It("clips too large values for the stream limits", func() {
    36  			conf := &Config{
    37  				MaxIncomingStreams:    1<<60 + 1,
    38  				MaxIncomingUniStreams: 1<<60 + 2,
    39  			}
    40  			Expect(validateConfig(conf)).To(Succeed())
    41  			Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(int64(1 << 60)))
    42  			Expect(conf.MaxIncomingUniStreams).To(BeEquivalentTo(int64(1 << 60)))
    43  		})
    44  
    45  		It("clips too large values for the flow control windows", func() {
    46  			conf := &Config{
    47  				MaxStreamReceiveWindow:     quicvarint.Max + 1,
    48  				MaxConnectionReceiveWindow: quicvarint.Max + 2,
    49  			}
    50  			Expect(validateConfig(conf)).To(Succeed())
    51  			Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max)))
    52  			Expect(conf.MaxConnectionReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max)))
    53  		})
    54  	})
    55  
    56  	configWithNonZeroNonFunctionFields := func() *Config {
    57  		c := &Config{}
    58  		v := reflect.ValueOf(c).Elem()
    59  
    60  		typ := v.Type()
    61  		for i := 0; i < typ.NumField(); i++ {
    62  			f := v.Field(i)
    63  			if !f.CanSet() {
    64  				// unexported field; not cloned.
    65  				continue
    66  			}
    67  
    68  			switch fn := typ.Field(i).Name; fn {
    69  			case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer":
    70  				// Can't compare functions.
    71  			case "Versions":
    72  				f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
    73  			case "ConnectionIDLength":
    74  				f.Set(reflect.ValueOf(8))
    75  			case "ConnectionIDGenerator":
    76  				f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength}))
    77  			case "HandshakeIdleTimeout":
    78  				f.Set(reflect.ValueOf(time.Second))
    79  			case "MaxIdleTimeout":
    80  				f.Set(reflect.ValueOf(time.Hour))
    81  			case "MaxTokenAge":
    82  				f.Set(reflect.ValueOf(2 * time.Hour))
    83  			case "MaxRetryTokenAge":
    84  				f.Set(reflect.ValueOf(2 * time.Minute))
    85  			case "TokenStore":
    86  				f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
    87  			case "InitialStreamReceiveWindow":
    88  				f.Set(reflect.ValueOf(uint64(1234)))
    89  			case "MaxStreamReceiveWindow":
    90  				f.Set(reflect.ValueOf(uint64(9)))
    91  			case "InitialConnectionReceiveWindow":
    92  				f.Set(reflect.ValueOf(uint64(4321)))
    93  			case "MaxConnectionReceiveWindow":
    94  				f.Set(reflect.ValueOf(uint64(10)))
    95  			case "MaxIncomingStreams":
    96  				f.Set(reflect.ValueOf(int64(11)))
    97  			case "MaxIncomingUniStreams":
    98  				f.Set(reflect.ValueOf(int64(12)))
    99  			case "StatelessResetKey":
   100  				f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4}))
   101  			case "KeepAlivePeriod":
   102  				f.Set(reflect.ValueOf(time.Second))
   103  			case "EnableDatagrams":
   104  				f.Set(reflect.ValueOf(true))
   105  			case "DisableVersionNegotiationPackets":
   106  				f.Set(reflect.ValueOf(true))
   107  			case "DisablePathMTUDiscovery":
   108  				f.Set(reflect.ValueOf(true))
   109  			case "SCID":
   110  			case "DCID":
   111  			case "Allow0RTT":
   112  				f.Set(reflect.ValueOf(true))
   113  			default:
   114  				Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
   115  			}
   116  		}
   117  		return c
   118  	}
   119  
   120  	It("uses 10s handshake timeout for short handshake idle timeouts", func() {
   121  		c := &Config{HandshakeIdleTimeout: time.Second}
   122  		Expect(c.handshakeTimeout()).To(Equal(protocol.DefaultHandshakeTimeout))
   123  	})
   124  
   125  	It("uses twice the handshake idle timeouts for the handshake timeout, for long handshake idle timeouts", func() {
   126  		c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2}
   127  		Expect(c.handshakeTimeout()).To(Equal(11 * time.Second))
   128  	})
   129  
   130  	Context("cloning", func() {
   131  		It("clones function fields", func() {
   132  			var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool
   133  			c1 := &Config{
   134  				GetConfigForClient:            func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
   135  				AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
   136  				RequireAddressValidation:      func(net.Addr) bool { calledAddrValidation = true; return true },
   137  				Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer {
   138  					calledTracer = true
   139  					return nil
   140  				},
   141  			}
   142  			c2 := c1.Clone()
   143  			c2.RequireAddressValidation(&net.UDPAddr{})
   144  			Expect(calledAddrValidation).To(BeTrue())
   145  			c2.AllowConnectionWindowIncrease(nil, 1234)
   146  			Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
   147  			_, err := c2.GetConfigForClient(&ClientHelloInfo{})
   148  			Expect(err).To(MatchError("nope"))
   149  			c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{})
   150  			Expect(calledTracer).To(BeTrue())
   151  		})
   152  
   153  		It("clones non-function fields", func() {
   154  			c := configWithNonZeroNonFunctionFields()
   155  			Expect(c.Clone()).To(Equal(c))
   156  		})
   157  
   158  		It("returns a copy", func() {
   159  			c1 := &Config{
   160  				MaxIncomingStreams:       100,
   161  				RequireAddressValidation: func(net.Addr) bool { return true },
   162  			}
   163  			c2 := c1.Clone()
   164  			c2.MaxIncomingStreams = 200
   165  			c2.RequireAddressValidation = func(net.Addr) bool { return false }
   166  
   167  			Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
   168  			Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
   169  		})
   170  	})
   171  
   172  	Context("populating", func() {
   173  		It("populates function fields", func() {
   174  			var calledAddrValidation bool
   175  			c1 := &Config{}
   176  			c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
   177  			c2 := populateConfig(c1)
   178  			c2.RequireAddressValidation(&net.UDPAddr{})
   179  			Expect(calledAddrValidation).To(BeTrue())
   180  		})
   181  
   182  		It("copies non-function fields", func() {
   183  			c := configWithNonZeroNonFunctionFields()
   184  			Expect(populateConfig(c)).To(Equal(c))
   185  		})
   186  
   187  		It("populates empty fields with default values", func() {
   188  			c := populateConfig(&Config{})
   189  			Expect(c.Versions).To(Equal(protocol.SupportedVersions))
   190  			Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
   191  			Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
   192  			Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow))
   193  			Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData))
   194  			Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow))
   195  			Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams))
   196  			Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams))
   197  			Expect(c.DisableVersionNegotiationPackets).To(BeFalse())
   198  			Expect(c.DisablePathMTUDiscovery).To(BeFalse())
   199  			Expect(c.GetConfigForClient).To(BeNil())
   200  		})
   201  
   202  		It("populates empty fields with default values, for the server", func() {
   203  			c := populateServerConfig(&Config{})
   204  			Expect(c.RequireAddressValidation).ToNot(BeNil())
   205  		})
   206  	})
   207  })