github.com/metacubex/quic-go@v0.44.1-0.20240520163451-20b689a59136/integrationtests/self/self_suite_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"flag"
     9  	"fmt"
    10  	"log"
    11  	"os"
    12  	"runtime/pprof"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/metacubex/quic-go"
    20  	"github.com/metacubex/quic-go/integrationtests/tools"
    21  	"github.com/metacubex/quic-go/internal/protocol"
    22  	"github.com/metacubex/quic-go/internal/utils"
    23  	"github.com/metacubex/quic-go/internal/wire"
    24  	"github.com/metacubex/quic-go/logging"
    25  
    26  	. "github.com/onsi/ginkgo/v2"
    27  	. "github.com/onsi/gomega"
    28  )
    29  
    30  const alpn = tools.ALPN
    31  
    32  const (
    33  	dataLen     = 500 * 1024       // 500 KB
    34  	dataLenLong = 50 * 1024 * 1024 // 50 MB
    35  )
    36  
    37  var (
    38  	// PRData contains dataLen bytes of pseudo-random data.
    39  	PRData = GeneratePRData(dataLen)
    40  	// PRDataLong contains dataLenLong bytes of pseudo-random data.
    41  	PRDataLong = GeneratePRData(dataLenLong)
    42  )
    43  
    44  // See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
    45  func GeneratePRData(l int) []byte {
    46  	res := make([]byte, l)
    47  	seed := uint64(1)
    48  	for i := 0; i < l; i++ {
    49  		seed = seed * 48271 % 2147483647
    50  		res[i] = byte(seed)
    51  	}
    52  	return res
    53  }
    54  
    55  const logBufSize = 100 * 1 << 20 // initial size of the log buffer: 100 MB
    56  
    57  type syncedBuffer struct {
    58  	mutex sync.Mutex
    59  
    60  	*bytes.Buffer
    61  }
    62  
    63  func (b *syncedBuffer) Write(p []byte) (int, error) {
    64  	b.mutex.Lock()
    65  	n, err := b.Buffer.Write(p)
    66  	b.mutex.Unlock()
    67  	return n, err
    68  }
    69  
    70  func (b *syncedBuffer) Bytes() []byte {
    71  	b.mutex.Lock()
    72  	p := b.Buffer.Bytes()
    73  	b.mutex.Unlock()
    74  	return p
    75  }
    76  
    77  func (b *syncedBuffer) Reset() {
    78  	b.mutex.Lock()
    79  	b.Buffer.Reset()
    80  	b.mutex.Unlock()
    81  }
    82  
    83  var (
    84  	logFileName  string // the log file set in the ginkgo flags
    85  	logBufOnce   sync.Once
    86  	logBuf       *syncedBuffer
    87  	versionParam string
    88  
    89  	enableQlog bool
    90  
    91  	version                          quic.Version
    92  	tlsConfig                        *tls.Config
    93  	tlsConfigLongChain               *tls.Config
    94  	tlsClientConfig                  *tls.Config
    95  	tlsClientConfigWithoutServerName *tls.Config
    96  )
    97  
    98  // read the logfile command line flag
    99  // to set call ginkgo -- -logfile=log.txt
   100  func init() {
   101  	flag.StringVar(&logFileName, "logfile", "", "log file")
   102  	flag.StringVar(&versionParam, "version", "1", "QUIC version")
   103  	flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
   104  
   105  	ca, caPrivateKey, err := tools.GenerateCA()
   106  	if err != nil {
   107  		panic(err)
   108  	}
   109  	leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey)
   110  	if err != nil {
   111  		panic(err)
   112  	}
   113  	tlsConfig = &tls.Config{
   114  		Certificates: []tls.Certificate{{
   115  			Certificate: [][]byte{leafCert.Raw},
   116  			PrivateKey:  leafPrivateKey,
   117  		}},
   118  		NextProtos: []string{alpn},
   119  	}
   120  	tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey)
   121  	if err != nil {
   122  		panic(err)
   123  	}
   124  	tlsConfigLongChain = tlsConfLongChain
   125  
   126  	root := x509.NewCertPool()
   127  	root.AddCert(ca)
   128  	tlsClientConfig = &tls.Config{
   129  		ServerName: "localhost",
   130  		RootCAs:    root,
   131  		NextProtos: []string{alpn},
   132  	}
   133  	tlsClientConfigWithoutServerName = &tls.Config{
   134  		RootCAs:    root,
   135  		NextProtos: []string{alpn},
   136  	}
   137  }
   138  
   139  var _ = BeforeSuite(func() {
   140  	switch versionParam {
   141  	case "1":
   142  		version = quic.Version1
   143  	case "2":
   144  		version = quic.Version2
   145  	default:
   146  		Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam))
   147  	}
   148  	fmt.Printf("Using QUIC version: %s\n", version)
   149  	protocol.SupportedVersions = []quic.Version{version}
   150  })
   151  
   152  func getTLSConfig() *tls.Config {
   153  	return tlsConfig.Clone()
   154  }
   155  
   156  func getTLSConfigWithLongCertChain() *tls.Config {
   157  	return tlsConfigLongChain.Clone()
   158  }
   159  
   160  func getTLSClientConfig() *tls.Config {
   161  	return tlsClientConfig.Clone()
   162  }
   163  
   164  func getTLSClientConfigWithoutServerName() *tls.Config {
   165  	return tlsClientConfigWithoutServerName.Clone()
   166  }
   167  
   168  func getQuicConfig(conf *quic.Config) *quic.Config {
   169  	if conf == nil {
   170  		conf = &quic.Config{}
   171  	} else {
   172  		conf = conf.Clone()
   173  	}
   174  	if !enableQlog {
   175  		return conf
   176  	}
   177  	if conf.Tracer == nil {
   178  		conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
   179  			return logging.NewMultiplexedConnectionTracer(
   180  				tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
   181  				// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
   182  				&logging.ConnectionTracer{},
   183  			)
   184  		}
   185  		return conf
   186  	}
   187  	origTracer := conf.Tracer
   188  	conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
   189  		return logging.NewMultiplexedConnectionTracer(
   190  			tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID),
   191  			origTracer(ctx, p, connID),
   192  		)
   193  	}
   194  	return conf
   195  }
   196  
   197  func addTracer(tr *quic.Transport) {
   198  	if !enableQlog {
   199  		return
   200  	}
   201  	if tr.Tracer == nil {
   202  		tr.Tracer = logging.NewMultiplexedTracer(
   203  			tools.QlogTracer(GinkgoWriter),
   204  			// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
   205  			&logging.Tracer{},
   206  		)
   207  		return
   208  	}
   209  	origTracer := tr.Tracer
   210  	tr.Tracer = logging.NewMultiplexedTracer(
   211  		tools.QlogTracer(GinkgoWriter),
   212  		origTracer,
   213  	)
   214  }
   215  
   216  var _ = BeforeEach(func() {
   217  	log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
   218  
   219  	if debugLog() {
   220  		logBufOnce.Do(func() {
   221  			logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))}
   222  		})
   223  		utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
   224  		log.SetOutput(logBuf)
   225  	}
   226  })
   227  
   228  func areHandshakesRunning() bool {
   229  	var b bytes.Buffer
   230  	pprof.Lookup("goroutine").WriteTo(&b, 1)
   231  	return strings.Contains(b.String(), "RunHandshake")
   232  }
   233  
   234  func areTransportsRunning() bool {
   235  	var b bytes.Buffer
   236  	pprof.Lookup("goroutine").WriteTo(&b, 1)
   237  	return strings.Contains(b.String(), "quic-go.(*Transport).listen")
   238  }
   239  
   240  var _ = AfterEach(func() {
   241  	Expect(areHandshakesRunning()).To(BeFalse())
   242  	Eventually(areTransportsRunning).Should(BeFalse())
   243  
   244  	if debugLog() {
   245  		logFile, err := os.Create(logFileName)
   246  		Expect(err).ToNot(HaveOccurred())
   247  		logFile.Write(logBuf.Bytes())
   248  		logFile.Close()
   249  		logBuf.Reset()
   250  	}
   251  })
   252  
   253  // Debug says if this test is being logged
   254  func debugLog() bool {
   255  	return len(logFileName) > 0
   256  }
   257  
   258  func scaleDuration(d time.Duration) time.Duration {
   259  	scaleFactor := 1
   260  	if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
   261  		scaleFactor = f
   262  	}
   263  	Expect(scaleFactor).ToNot(BeZero())
   264  	return time.Duration(scaleFactor) * d
   265  }
   266  
   267  func newTracer(tracer *logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
   268  	return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return tracer }
   269  }
   270  
   271  type packet struct {
   272  	time   time.Time
   273  	hdr    *logging.ExtendedHeader
   274  	frames []logging.Frame
   275  }
   276  
   277  type shortHeaderPacket struct {
   278  	time   time.Time
   279  	hdr    *logging.ShortHeader
   280  	frames []logging.Frame
   281  }
   282  
   283  type packetCounter struct {
   284  	closed                     chan struct{}
   285  	sentShortHdr, rcvdShortHdr []shortHeaderPacket
   286  	rcvdLongHdr                []packet
   287  }
   288  
   289  func (t *packetCounter) getSentShortHeaderPackets() []shortHeaderPacket {
   290  	<-t.closed
   291  	return t.sentShortHdr
   292  }
   293  
   294  func (t *packetCounter) getRcvdLongHeaderPackets() []packet {
   295  	<-t.closed
   296  	return t.rcvdLongHdr
   297  }
   298  
   299  func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket {
   300  	<-t.closed
   301  	return t.rcvdShortHdr
   302  }
   303  
   304  func newPacketTracer() (*packetCounter, *logging.ConnectionTracer) {
   305  	c := &packetCounter{closed: make(chan struct{})}
   306  	return c, &logging.ConnectionTracer{
   307  		ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
   308  			c.rcvdLongHdr = append(c.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
   309  		},
   310  		ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
   311  			c.rcvdShortHdr = append(c.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
   312  		},
   313  		SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) {
   314  			if ack != nil {
   315  				frames = append(frames, ack)
   316  			}
   317  			c.sentShortHdr = append(c.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
   318  		},
   319  		Close: func() { close(c.closed) },
   320  	}
   321  }
   322  
   323  func TestSelf(t *testing.T) {
   324  	RegisterFailHandler(Fail)
   325  	RunSpecs(t, "Self integration tests")
   326  }