github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/config_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"reflect"
     9  	"time"
    10  
    11  	"github.com/mikelsr/quic-go/internal/protocol"
    12  	"github.com/mikelsr/quic-go/logging"
    13  	"github.com/mikelsr/quic-go/quicvarint"
    14  
    15  	. "github.com/onsi/ginkgo/v2"
    16  	. "github.com/onsi/gomega"
    17  )
    18  
    19  var _ = Describe("Config", func() {
    20  	Context("validating", func() {
    21  		It("validates a nil config", func() {
    22  			Expect(validateConfig(nil)).To(Succeed())
    23  		})
    24  
    25  		It("validates a config with normal values", func() {
    26  			conf := populateServerConfig(&Config{
    27  				MaxIncomingStreams:     5,
    28  				MaxStreamReceiveWindow: 10,
    29  			})
    30  			Expect(validateConfig(conf)).To(Succeed())
    31  			Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(5))
    32  			Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(10))
    33  		})
    34  
    35  		It("clips too large values for the stream limits", func() {
    36  			conf := &Config{
    37  				MaxIncomingStreams:    1<<60 + 1,
    38  				MaxIncomingUniStreams: 1<<60 + 2,
    39  			}
    40  			Expect(validateConfig(conf)).To(Succeed())
    41  			Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(int64(1 << 60)))
    42  			Expect(conf.MaxIncomingUniStreams).To(BeEquivalentTo(int64(1 << 60)))
    43  		})
    44  
    45  		It("clips too large values for the flow control windows", func() {
    46  			conf := &Config{
    47  				MaxStreamReceiveWindow:     quicvarint.Max + 1,
    48  				MaxConnectionReceiveWindow: quicvarint.Max + 2,
    49  			}
    50  			Expect(validateConfig(conf)).To(Succeed())
    51  			Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max)))
    52  			Expect(conf.MaxConnectionReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max)))
    53  		})
    54  	})
    55  
    56  	configWithNonZeroNonFunctionFields := func() *Config {
    57  		c := &Config{}
    58  		v := reflect.ValueOf(c).Elem()
    59  
    60  		typ := v.Type()
    61  		for i := 0; i < typ.NumField(); i++ {
    62  			f := v.Field(i)
    63  			if !f.CanSet() {
    64  				// unexported field; not cloned.
    65  				continue
    66  			}
    67  
    68  			switch fn := typ.Field(i).Name; fn {
    69  			case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer":
    70  				// Can't compare functions.
    71  			case "Versions":
    72  				f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
    73  			case "ConnectionIDLength":
    74  				f.Set(reflect.ValueOf(8))
    75  			case "ConnectionIDGenerator":
    76  				f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength}))
    77  			case "HandshakeIdleTimeout":
    78  				f.Set(reflect.ValueOf(time.Second))
    79  			case "MaxIdleTimeout":
    80  				f.Set(reflect.ValueOf(time.Hour))
    81  			case "MaxTokenAge":
    82  				f.Set(reflect.ValueOf(2 * time.Hour))
    83  			case "MaxRetryTokenAge":
    84  				f.Set(reflect.ValueOf(2 * time.Minute))
    85  			case "TokenStore":
    86  				f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
    87  			case "InitialStreamReceiveWindow":
    88  				f.Set(reflect.ValueOf(uint64(1234)))
    89  			case "MaxStreamReceiveWindow":
    90  				f.Set(reflect.ValueOf(uint64(9)))
    91  			case "InitialConnectionReceiveWindow":
    92  				f.Set(reflect.ValueOf(uint64(4321)))
    93  			case "MaxConnectionReceiveWindow":
    94  				f.Set(reflect.ValueOf(uint64(10)))
    95  			case "MaxIncomingStreams":
    96  				f.Set(reflect.ValueOf(int64(11)))
    97  			case "MaxIncomingUniStreams":
    98  				f.Set(reflect.ValueOf(int64(12)))
    99  			case "StatelessResetKey":
   100  				f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4}))
   101  			case "KeepAlivePeriod":
   102  				f.Set(reflect.ValueOf(time.Second))
   103  			case "EnableDatagrams":
   104  				f.Set(reflect.ValueOf(true))
   105  			case "DisableVersionNegotiationPackets":
   106  				f.Set(reflect.ValueOf(true))
   107  			case "DisablePathMTUDiscovery":
   108  				f.Set(reflect.ValueOf(true))
   109  			case "Allow0RTT":
   110  				f.Set(reflect.ValueOf(true))
   111  			default:
   112  				Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
   113  			}
   114  		}
   115  		return c
   116  	}
   117  
   118  	It("uses 10s handshake timeout for short handshake idle timeouts", func() {
   119  		c := &Config{HandshakeIdleTimeout: time.Second}
   120  		Expect(c.handshakeTimeout()).To(Equal(protocol.DefaultHandshakeTimeout))
   121  	})
   122  
   123  	It("uses twice the handshake idle timeouts for the handshake timeout, for long handshake idle timeouts", func() {
   124  		c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2}
   125  		Expect(c.handshakeTimeout()).To(Equal(11 * time.Second))
   126  	})
   127  
   128  	Context("cloning", func() {
   129  		It("clones function fields", func() {
   130  			var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool
   131  			c1 := &Config{
   132  				GetConfigForClient:            func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
   133  				AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
   134  				RequireAddressValidation:      func(net.Addr) bool { calledAddrValidation = true; return true },
   135  				Tracer: func(context.Context, logging.Perspective, ConnectionID) logging.ConnectionTracer {
   136  					calledTracer = true
   137  					return nil
   138  				},
   139  			}
   140  			c2 := c1.Clone()
   141  			c2.RequireAddressValidation(&net.UDPAddr{})
   142  			Expect(calledAddrValidation).To(BeTrue())
   143  			c2.AllowConnectionWindowIncrease(nil, 1234)
   144  			Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
   145  			_, err := c2.GetConfigForClient(&ClientHelloInfo{})
   146  			Expect(err).To(MatchError("nope"))
   147  			c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{})
   148  			Expect(calledTracer).To(BeTrue())
   149  		})
   150  
   151  		It("clones non-function fields", func() {
   152  			c := configWithNonZeroNonFunctionFields()
   153  			Expect(c.Clone()).To(Equal(c))
   154  		})
   155  
   156  		It("returns a copy", func() {
   157  			c1 := &Config{
   158  				MaxIncomingStreams:       100,
   159  				RequireAddressValidation: func(net.Addr) bool { return true },
   160  			}
   161  			c2 := c1.Clone()
   162  			c2.MaxIncomingStreams = 200
   163  			c2.RequireAddressValidation = func(net.Addr) bool { return false }
   164  
   165  			Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
   166  			Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
   167  		})
   168  	})
   169  
   170  	Context("populating", func() {
   171  		It("populates function fields", func() {
   172  			var calledAddrValidation bool
   173  			c1 := &Config{}
   174  			c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
   175  			c2 := populateConfig(c1)
   176  			c2.RequireAddressValidation(&net.UDPAddr{})
   177  			Expect(calledAddrValidation).To(BeTrue())
   178  		})
   179  
   180  		It("copies non-function fields", func() {
   181  			c := configWithNonZeroNonFunctionFields()
   182  			Expect(populateConfig(c)).To(Equal(c))
   183  		})
   184  
   185  		It("populates empty fields with default values", func() {
   186  			c := populateConfig(&Config{})
   187  			Expect(c.Versions).To(Equal(protocol.SupportedVersions))
   188  			Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
   189  			Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
   190  			Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow))
   191  			Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData))
   192  			Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow))
   193  			Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams))
   194  			Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams))
   195  			Expect(c.DisableVersionNegotiationPackets).To(BeFalse())
   196  			Expect(c.DisablePathMTUDiscovery).To(BeFalse())
   197  			Expect(c.GetConfigForClient).To(BeNil())
   198  		})
   199  
   200  		It("populates empty fields with default values, for the server", func() {
   201  			c := populateServerConfig(&Config{})
   202  			Expect(c.RequireAddressValidation).ToNot(BeNil())
   203  		})
   204  	})
   205  })