github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/integrationtests/self/conn_id_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"io"
     8  	mrand "math/rand"
     9  	"net"
    10  
    11  	"github.com/daeuniverse/quic-go"
    12  	"github.com/daeuniverse/quic-go/internal/protocol"
    13  
    14  	. "github.com/onsi/ginkgo/v2"
    15  	. "github.com/onsi/gomega"
    16  )
    17  
    18  type connIDGenerator struct {
    19  	length int
    20  }
    21  
    22  func (c *connIDGenerator) GenerateConnectionID() (quic.ConnectionID, error) {
    23  	b := make([]byte, c.length)
    24  	if _, err := rand.Read(b); err != nil {
    25  		fmt.Fprintf(GinkgoWriter, "generating conn ID failed: %s", err)
    26  	}
    27  	return protocol.ParseConnectionID(b), nil
    28  }
    29  
    30  func (c *connIDGenerator) ConnectionIDLen() int {
    31  	return c.length
    32  }
    33  
    34  var _ = Describe("Connection ID lengths tests", func() {
    35  	randomConnIDLen := func() int { return 4 + int(mrand.Int31n(15)) }
    36  
    37  	// connIDLen is ignored when connIDGenerator is set
    38  	runServer := func(connIDLen int, connIDGenerator quic.ConnectionIDGenerator) (*quic.Listener, func()) {
    39  		if connIDGenerator != nil {
    40  			GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the server\n", connIDGenerator.ConnectionIDLen())))
    41  		} else {
    42  			GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the server\n", connIDLen)))
    43  		}
    44  		addr, err := net.ResolveUDPAddr("udp", "localhost:0")
    45  		Expect(err).ToNot(HaveOccurred())
    46  		conn, err := net.ListenUDP("udp", addr)
    47  		Expect(err).ToNot(HaveOccurred())
    48  		tr := &quic.Transport{
    49  			Conn:                  conn,
    50  			ConnectionIDLength:    connIDLen,
    51  			ConnectionIDGenerator: connIDGenerator,
    52  		}
    53  		addTracer(tr)
    54  		ln, err := tr.Listen(getTLSConfig(), getQuicConfig(nil))
    55  		Expect(err).ToNot(HaveOccurred())
    56  		go func() {
    57  			defer GinkgoRecover()
    58  			for {
    59  				conn, err := ln.Accept(context.Background())
    60  				if err != nil {
    61  					return
    62  				}
    63  				go func() {
    64  					defer GinkgoRecover()
    65  					str, err := conn.OpenStream()
    66  					Expect(err).ToNot(HaveOccurred())
    67  					defer str.Close()
    68  					_, err = str.Write(PRData)
    69  					Expect(err).ToNot(HaveOccurred())
    70  				}()
    71  			}
    72  		}()
    73  		return ln, func() {
    74  			ln.Close()
    75  			tr.Close()
    76  		}
    77  	}
    78  
    79  	// connIDLen is ignored when connIDGenerator is set
    80  	runClient := func(addr net.Addr, connIDLen int, connIDGenerator quic.ConnectionIDGenerator) {
    81  		if connIDGenerator != nil {
    82  			GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID generator for the client\n", connIDGenerator.ConnectionIDLen())))
    83  		} else {
    84  			GinkgoWriter.Write([]byte(fmt.Sprintf("Using %d byte connection ID for the client\n", connIDLen)))
    85  		}
    86  		laddr, err := net.ResolveUDPAddr("udp", "localhost:0")
    87  		Expect(err).ToNot(HaveOccurred())
    88  		conn, err := net.ListenUDP("udp", laddr)
    89  		Expect(err).ToNot(HaveOccurred())
    90  		defer conn.Close()
    91  		tr := &quic.Transport{
    92  			Conn:                  conn,
    93  			ConnectionIDLength:    connIDLen,
    94  			ConnectionIDGenerator: connIDGenerator,
    95  		}
    96  		addTracer(tr)
    97  		defer tr.Close()
    98  		cl, err := tr.Dial(
    99  			context.Background(),
   100  			&net.UDPAddr{IP: net.IPv4(127, 0, 0, 1), Port: addr.(*net.UDPAddr).Port},
   101  			getTLSClientConfig(),
   102  			getQuicConfig(nil),
   103  		)
   104  		Expect(err).ToNot(HaveOccurred())
   105  		defer cl.CloseWithError(0, "")
   106  		str, err := cl.AcceptStream(context.Background())
   107  		Expect(err).ToNot(HaveOccurred())
   108  		data, err := io.ReadAll(str)
   109  		Expect(err).ToNot(HaveOccurred())
   110  		Expect(data).To(Equal(PRData))
   111  	}
   112  
   113  	It("downloads a file using a 0-byte connection ID for the client", func() {
   114  		ln, closeFn := runServer(randomConnIDLen(), nil)
   115  		defer closeFn()
   116  		runClient(ln.Addr(), 0, nil)
   117  	})
   118  
   119  	It("downloads a file when both client and server use a random connection ID length", func() {
   120  		ln, closeFn := runServer(randomConnIDLen(), nil)
   121  		defer closeFn()
   122  		runClient(ln.Addr(), randomConnIDLen(), nil)
   123  	})
   124  
   125  	It("downloads a file when both client and server use a custom connection ID generator", func() {
   126  		ln, closeFn := runServer(0, &connIDGenerator{length: randomConnIDLen()})
   127  		defer closeFn()
   128  		runClient(ln.Addr(), 0, &connIDGenerator{length: randomConnIDLen()})
   129  	})
   130  })