github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/config_test.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"time"
     9  
    10  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    11  	"github.com/danielpfeifer02/quic-go-prio-packs/logging"
    12  	"github.com/danielpfeifer02/quic-go-prio-packs/quicvarint"
    13  
    14  	. "github.com/onsi/ginkgo/v2"
    15  	. "github.com/onsi/gomega"
    16  )
    17  
    18  var _ = Describe("Config", func() {
    19  	Context("validating", func() {
    20  		It("validates a nil config", func() {
    21  			Expect(validateConfig(nil)).To(Succeed())
    22  		})
    23  
    24  		It("validates a config with normal values", func() {
    25  			conf := populateConfig(&Config{
    26  				MaxIncomingStreams:     5,
    27  				MaxStreamReceiveWindow: 10,
    28  			})
    29  			Expect(validateConfig(conf)).To(Succeed())
    30  			Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(5))
    31  			Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(10))
    32  		})
    33  
    34  		It("clips too large values for the stream limits", func() {
    35  			conf := &Config{
    36  				MaxIncomingStreams:    1<<60 + 1,
    37  				MaxIncomingUniStreams: 1<<60 + 2,
    38  			}
    39  			Expect(validateConfig(conf)).To(Succeed())
    40  			Expect(conf.MaxIncomingStreams).To(BeEquivalentTo(int64(1 << 60)))
    41  			Expect(conf.MaxIncomingUniStreams).To(BeEquivalentTo(int64(1 << 60)))
    42  		})
    43  
    44  		It("clips too large values for the flow control windows", func() {
    45  			conf := &Config{
    46  				MaxStreamReceiveWindow:     quicvarint.Max + 1,
    47  				MaxConnectionReceiveWindow: quicvarint.Max + 2,
    48  			}
    49  			Expect(validateConfig(conf)).To(Succeed())
    50  			Expect(conf.MaxStreamReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max)))
    51  			Expect(conf.MaxConnectionReceiveWindow).To(BeEquivalentTo(uint64(quicvarint.Max)))
    52  		})
    53  	})
    54  
    55  	configWithNonZeroNonFunctionFields := func() *Config {
    56  		c := &Config{}
    57  		v := reflect.ValueOf(c).Elem()
    58  
    59  		typ := v.Type()
    60  		for i := 0; i < typ.NumField(); i++ {
    61  			f := v.Field(i)
    62  			if !f.CanSet() {
    63  				// unexported field; not cloned.
    64  				continue
    65  			}
    66  
    67  			switch fn := typ.Field(i).Name; fn {
    68  			case "GetConfigForClient", "RequireAddressValidation", "GetLogWriter", "AllowConnectionWindowIncrease", "Tracer":
    69  				// Can't compare functions.
    70  			case "Versions":
    71  				f.Set(reflect.ValueOf([]Version{1, 2, 3}))
    72  			case "ConnectionIDLength":
    73  				f.Set(reflect.ValueOf(8))
    74  			case "ConnectionIDGenerator":
    75  				f.Set(reflect.ValueOf(&protocol.DefaultConnectionIDGenerator{ConnLen: protocol.DefaultConnectionIDLength}))
    76  			case "HandshakeIdleTimeout":
    77  				f.Set(reflect.ValueOf(time.Second))
    78  			case "MaxIdleTimeout":
    79  				f.Set(reflect.ValueOf(time.Hour))
    80  			case "TokenStore":
    81  				f.Set(reflect.ValueOf(NewLRUTokenStore(2, 3)))
    82  			case "InitialStreamReceiveWindow":
    83  				f.Set(reflect.ValueOf(uint64(1234)))
    84  			case "MaxStreamReceiveWindow":
    85  				f.Set(reflect.ValueOf(uint64(9)))
    86  			case "InitialConnectionReceiveWindow":
    87  				f.Set(reflect.ValueOf(uint64(4321)))
    88  			case "MaxConnectionReceiveWindow":
    89  				f.Set(reflect.ValueOf(uint64(10)))
    90  			case "MaxIncomingStreams":
    91  				f.Set(reflect.ValueOf(int64(11)))
    92  			case "MaxIncomingUniStreams":
    93  				f.Set(reflect.ValueOf(int64(12)))
    94  			case "StatelessResetKey":
    95  				f.Set(reflect.ValueOf(&StatelessResetKey{1, 2, 3, 4}))
    96  			case "KeepAlivePeriod":
    97  				f.Set(reflect.ValueOf(time.Second))
    98  			case "EnableDatagrams":
    99  				f.Set(reflect.ValueOf(true))
   100  			case "DisableVersionNegotiationPackets":
   101  				f.Set(reflect.ValueOf(true))
   102  			case "DisablePathMTUDiscovery":
   103  				f.Set(reflect.ValueOf(true))
   104  			case "Allow0RTT":
   105  				f.Set(reflect.ValueOf(true))
   106  			default:
   107  				Fail(fmt.Sprintf("all fields must be accounted for, but saw unknown field %q", fn))
   108  			}
   109  		}
   110  		return c
   111  	}
   112  
   113  	It("uses twice the handshake idle timeouts for the handshake timeout", func() {
   114  		c := &Config{HandshakeIdleTimeout: time.Second * 11 / 2}
   115  		Expect(c.handshakeTimeout()).To(Equal(11 * time.Second))
   116  	})
   117  
   118  	Context("cloning", func() {
   119  		It("clones function fields", func() {
   120  			var calledAllowConnectionWindowIncrease, calledTracer bool
   121  			c1 := &Config{
   122  				GetConfigForClient:            func(info *ClientHelloInfo) (*Config, error) { return nil, errors.New("nope") },
   123  				AllowConnectionWindowIncrease: func(Connection, uint64) bool { calledAllowConnectionWindowIncrease = true; return true },
   124  				Tracer: func(context.Context, logging.Perspective, ConnectionID) *logging.ConnectionTracer {
   125  					calledTracer = true
   126  					return nil
   127  				},
   128  			}
   129  			c2 := c1.Clone()
   130  			c2.AllowConnectionWindowIncrease(nil, 1234)
   131  			Expect(calledAllowConnectionWindowIncrease).To(BeTrue())
   132  			_, err := c2.GetConfigForClient(&ClientHelloInfo{})
   133  			Expect(err).To(MatchError("nope"))
   134  			c2.Tracer(context.Background(), logging.PerspectiveClient, protocol.ConnectionID{})
   135  			Expect(calledTracer).To(BeTrue())
   136  		})
   137  
   138  		It("clones non-function fields", func() {
   139  			c := configWithNonZeroNonFunctionFields()
   140  			Expect(c.Clone()).To(Equal(c))
   141  		})
   142  
   143  		It("returns a copy", func() {
   144  			c1 := &Config{MaxIncomingStreams: 100}
   145  			c2 := c1.Clone()
   146  			c2.MaxIncomingStreams = 200
   147  
   148  			Expect(c1.MaxIncomingStreams).To(BeEquivalentTo(100))
   149  		})
   150  	})
   151  
   152  	Context("populating", func() {
   153  		It("copies non-function fields", func() {
   154  			c := configWithNonZeroNonFunctionFields()
   155  			Expect(populateConfig(c)).To(Equal(c))
   156  		})
   157  
   158  		It("populates empty fields with default values", func() {
   159  			c := populateConfig(&Config{})
   160  			Expect(c.Versions).To(Equal(protocol.SupportedVersions))
   161  			Expect(c.HandshakeIdleTimeout).To(Equal(protocol.DefaultHandshakeIdleTimeout))
   162  			Expect(c.InitialStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxStreamData))
   163  			Expect(c.MaxStreamReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveStreamFlowControlWindow))
   164  			Expect(c.InitialConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultInitialMaxData))
   165  			Expect(c.MaxConnectionReceiveWindow).To(BeEquivalentTo(protocol.DefaultMaxReceiveConnectionFlowControlWindow))
   166  			Expect(c.MaxIncomingStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingStreams))
   167  			Expect(c.MaxIncomingUniStreams).To(BeEquivalentTo(protocol.DefaultMaxIncomingUniStreams))
   168  			Expect(c.DisablePathMTUDiscovery).To(BeFalse())
   169  			Expect(c.GetConfigForClient).To(BeNil())
   170  		})
   171  	})
   172  })