github.com/MerlinKodo/quic-go@v0.39.2/config_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"reflect"
     9  	"time"
    10  
    11  	"github.com/MerlinKodo/quic-go/internal/protocol"
    12  	"github.com/MerlinKodo/quic-go/logging"
    13  	"github.com/MerlinKodo/quic-go/quicvarint"
    14  
    15  	. "github.com/onsi/ginkgo/v2"
    16  	. "github.com/onsi/gomega"
    17  )
    18  
    19  var _ = Describe("Config", func() {
    20  	Context("validating", func() {
    21  		It("validates a nil config", func() {
    22  			Expect(validateConfig(nil)).To(Succeed())
    23  		})
    24  
    25  		It("validates a config with normal values", func() {
    26  			conf := populateServerConfig(&Config{
    27  				MaxIncomingStreams:     5,
    28  				MaxStreamReceiveWindow: 10,
    29  			})
    30  			Expect(validateConfig(conf)).To(Succeed())
    31  			Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(5))
    32  			Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(10))
    33  		})
    34  
    35  		It("clips too large values for the stream limits", func() {
    36  			conf := &Config{
    37  				MaxIncomingStreams:    1<<60 + 1,
    38  				MaxIncomingUniStreams: 1<<60 + 2,
    39  			}
    40  			Expect(validateConfig(conf)).To(Succeed())
    41  			Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(int64(1 << 60)))
    42  			Expect(conf.MaxIncomingUniStreams).To(BeEquivalentTo(int64(1 << 60)))
    43  		})
    44  
    45  		It("clips too large values for the flow control windows", func() {
    46  			conf := &Config{
    47  				MaxStreamReceiveWindow:     quicvarint.Max + 1,
    48  				MaxConnectionReceiveWindow: quicvarint.Max + 2,
    49  			}
    50  			Expect(validateConfig(conf)).To(Succeed())
    51  			Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max)))
    52  			Expect(conf.MaxConnectionReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max)))
    53  		})
    54  	})
    55  
    56  	configWithNonZeroNonFunctionFields := func() *Config {
    57  		c := &Config{}
    58  		v := reflect.ValueOf(c).Elem()
    59  
    60  		typ := v.Type()
    61  		for i := 0; i < typ.NumField(); i++ {
    62  			f := v.Field(i)
    63  			if !f.CanSet() {
    64  				// unexported field; not cloned.
    65  				continue
    66  			}
    67  
    68  			switch fn := typ.Field(i).Name; fn {
    69  			case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer":
    70  				// Can't compare functions.
    71  			case "Versions":
    72  				f.Set(reflect.ValueOf([]VersionNumber{1, 2, 3}))
    73  			case "ConnectionIDLength":
    74  				f.Set(reflect.ValueOf(8))
    75  			case "ConnectionIDGenerator":
    76  				f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength}))
    77  			case "HandshakeIdleTimeout":
    78  				f.Set(reflect.ValueOf(time.Second))
    79  			case "MaxIdleTimeout":
    80  				f.Set(reflect.ValueOf(time.Hour))
    81  			case "TokenStore":
    82  				f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
    83  			case "InitialStreamReceiveWindow":
    84  				f.Set(reflect.ValueOf(uint64(1234)))
    85  			case "MaxStreamReceiveWindow":
    86  				f.Set(reflect.ValueOf(uint64(9)))
    87  			case "InitialConnectionReceiveWindow":
    88  				f.Set(reflect.ValueOf(uint64(4321)))
    89  			case "MaxConnectionReceiveWindow":
    90  				f.Set(reflect.ValueOf(uint64(10)))
    91  			case "MaxIncomingStreams":
    92  				f.Set(reflect.ValueOf(int64(11)))
    93  			case "MaxIncomingUniStreams":
    94  				f.Set(reflect.ValueOf(int64(12)))
    95  			case "StatelessResetKey":
    96  				f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4}))
    97  			case "KeepAlivePeriod":
    98  				f.Set(reflect.ValueOf(time.Second))
    99  			case "EnableDatagrams":
   100  				f.Set(reflect.ValueOf(true))
   101  			case "DisableVersionNegotiationPackets":
   102  				f.Set(reflect.ValueOf(true))
   103  			case "DisablePathMTUDiscovery":
   104  				f.Set(reflect.ValueOf(true))
   105  			case "Allow0RTT":
   106  				f.Set(reflect.ValueOf(true))
   107  			default:
   108  				Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
   109  			}
   110  		}
   111  		return c
   112  	}
   113  
   114  	It("uses twice the handshake idle timeouts for the handshake timeout", func() {
   115  		c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2}
   116  		Expect(c.handshakeTimeout()).To(Equal(11 * time.Second))
   117  	})
   118  
   119  	Context("cloning", func() {
   120  		It("clones function fields", func() {
   121  			var calledAddrValidation, calledAllowConnectionWindowIncrease, calledTracer bool
   122  			c1 := &Config{
   123  				GetConfigForClient:            func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
   124  				AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
   125  				RequireAddressValidation:      func(net.Addr) bool { calledAddrValidation = true; return true },
   126  				Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer {
   127  					calledTracer = true
   128  					return nil
   129  				},
   130  			}
   131  			c2 := c1.Clone()
   132  			c2.RequireAddressValidation(&net.UDPAddr{})
   133  			Expect(calledAddrValidation).To(BeTrue())
   134  			c2.AllowConnectionWindowIncrease(nil, 1234)
   135  			Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
   136  			_, err := c2.GetConfigForClient(&ClientHelloInfo{})
   137  			Expect(err).To(MatchError("nope"))
   138  			c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{})
   139  			Expect(calledTracer).To(BeTrue())
   140  		})
   141  
   142  		It("clones non-function fields", func() {
   143  			c := configWithNonZeroNonFunctionFields()
   144  			Expect(c.Clone()).To(Equal(c))
   145  		})
   146  
   147  		It("returns a copy", func() {
   148  			c1 := &Config{
   149  				MaxIncomingStreams:       100,
   150  				RequireAddressValidation: func(net.Addr) bool { return true },
   151  			}
   152  			c2 := c1.Clone()
   153  			c2.MaxIncomingStreams = 200
   154  			c2.RequireAddressValidation = func(net.Addr) bool { return false }
   155  
   156  			Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
   157  			Expect(c1.RequireAddressValidation(&net.UDPAddr{})).To(BeTrue())
   158  		})
   159  	})
   160  
   161  	Context("populating", func() {
   162  		It("populates function fields", func() {
   163  			var calledAddrValidation bool
   164  			c1 := &Config{}
   165  			c1.RequireAddressValidation = func(net.Addr) bool { calledAddrValidation = true; return true }
   166  			c2 := populateConfig(c1)
   167  			c2.RequireAddressValidation(&net.UDPAddr{})
   168  			Expect(calledAddrValidation).To(BeTrue())
   169  		})
   170  
   171  		It("copies non-function fields", func() {
   172  			c := configWithNonZeroNonFunctionFields()
   173  			Expect(populateConfig(c)).To(Equal(c))
   174  		})
   175  
   176  		It("populates empty fields with default values", func() {
   177  			c := populateConfig(&Config{})
   178  			Expect(c.Versions).To(Equal(protocol.SupportedVersions))
   179  			Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
   180  			Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
   181  			Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow))
   182  			Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData))
   183  			Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow))
   184  			Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams))
   185  			Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams))
   186  			Expect(c.DisablePathMTUDiscovery).To(BeFalse())
   187  			Expect(c.GetConfigForClient).To(BeNil())
   188  		})
   189  
   190  		It("populates empty fields with default values, for the server", func() {
   191  			c := populateServerConfig(&Config{})
   192  			Expect(c.RequireAddressValidation).ToNot(BeNil())
   193  		})
   194  	})
   195  })