github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/integrationtests/self/self_suite_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"flag"
     9  	"fmt"
    10  	"log"
    11  	mrand "math/rand"
    12  	"os"
    13  	"runtime/pprof"
    14  	"strconv"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/mikelsr/quic-go"
    21  	"github.com/mikelsr/quic-go/integrationtests/tools"
    22  	"github.com/mikelsr/quic-go/internal/protocol"
    23  	"github.com/mikelsr/quic-go/internal/utils"
    24  	"github.com/mikelsr/quic-go/internal/wire"
    25  	"github.com/mikelsr/quic-go/logging"
    26  
    27  	. "github.com/onsi/ginkgo/v2"
    28  	. "github.com/onsi/gomega"
    29  )
    30  
// alpn is the ALPN protocol identifier advertised by both the server and the
// client TLS configs set up in init (taken from the shared tools package).
const alpn = tools.ALPN

// Sizes of the pseudo-random payloads below.
const (
	dataLen     = 500 * 1024       // 500 KiB
	dataLenLong = 50 * 1024 * 1024 // 50 MiB
)

// Payloads computed once at package initialization. GeneratePRData always
// starts from the same seed, so the contents are identical on every run.
var (
	// PRData contains dataLen bytes of pseudo-random data.
	PRData = GeneratePRData(dataLen)
	// PRDataLong contains dataLenLong bytes of pseudo-random data.
	PRDataLong = GeneratePRData(dataLenLong)
)
    44  
    45  // See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
    46  func GeneratePRData(l int) []byte {
    47  	res := make([]byte, l)
    48  	seed := uint64(1)
    49  	for i := 0; i < l; i++ {
    50  		seed = seed * 48271 % 2147483647
    51  		res[i] = byte(seed)
    52  	}
    53  	return res
    54  }
    55  
    56  const logBufSize = 100 * 1 << 20 // initial size of the log buffer: 100 MB
    57  
    58  type syncedBuffer struct {
    59  	mutex sync.Mutex
    60  
    61  	*bytes.Buffer
    62  }
    63  
    64  func (b *syncedBuffer) Write(p []byte) (int, error) {
    65  	b.mutex.Lock()
    66  	n, err := b.Buffer.Write(p)
    67  	b.mutex.Unlock()
    68  	return n, err
    69  }
    70  
    71  func (b *syncedBuffer) Bytes() []byte {
    72  	b.mutex.Lock()
    73  	p := b.Buffer.Bytes()
    74  	b.mutex.Unlock()
    75  	return p
    76  }
    77  
    78  func (b *syncedBuffer) Reset() {
    79  	b.mutex.Lock()
    80  	b.Buffer.Reset()
    81  	b.mutex.Unlock()
    82  }
    83  
// Package-level suite state. The flag-backed values and the TLS configs are
// set up in init. version and qlogTracer are set in BeforeSuite. logBuf is
// allocated lazily (once) in BeforeEach when a log file is requested.
var (
	logFileName  string // the log file set in the ginkgo flags
	logBufOnce   sync.Once
	logBuf       *syncedBuffer
	// raw value of the -version flag ("1" or "2"), parsed in BeforeSuite
	versionParam string

	// qlog tracer factory; only non-nil when the -qlog flag is set
	qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
	enableQlog bool

	// QUIC version selected via -version
	version                          quic.VersionNumber
	// TLS configs generated once in init; hand out copies via the get* helpers
	tlsConfig                        *tls.Config
	tlsConfigLongChain               *tls.Config
	tlsClientConfig                  *tls.Config
	tlsClientConfigWithoutServerName *tls.Config
)
    99  
   100  // read the logfile command line flag
   101  // to set call ginkgo -- -logfile=log.txt
   102  func init() {
   103  	flag.StringVar(&logFileName, "logfile", "", "log file")
   104  	flag.StringVar(&versionParam, "version", "1", "QUIC version")
   105  	flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
   106  
   107  	ca, caPrivateKey, err := tools.GenerateCA()
   108  	if err != nil {
   109  		panic(err)
   110  	}
   111  	leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey)
   112  	if err != nil {
   113  		panic(err)
   114  	}
   115  	tlsConfig = &tls.Config{
   116  		Certificates: []tls.Certificate{{
   117  			Certificate: [][]byte{leafCert.Raw},
   118  			PrivateKey:  leafPrivateKey,
   119  		}},
   120  		NextProtos: []string{alpn},
   121  	}
   122  	tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey)
   123  	if err != nil {
   124  		panic(err)
   125  	}
   126  	tlsConfigLongChain = tlsConfLongChain
   127  
   128  	root := x509.NewCertPool()
   129  	root.AddCert(ca)
   130  	tlsClientConfig = &tls.Config{
   131  		ServerName: "localhost",
   132  		RootCAs:    root,
   133  		NextProtos: []string{alpn},
   134  	}
   135  	tlsClientConfigWithoutServerName = &tls.Config{
   136  		RootCAs:    root,
   137  		NextProtos: []string{alpn},
   138  	}
   139  }
   140  
   141  var _ = BeforeSuite(func() {
   142  	mrand.Seed(GinkgoRandomSeed())
   143  
   144  	if enableQlog {
   145  		qlogTracer = tools.NewQlogger(GinkgoWriter)
   146  	}
   147  	switch versionParam {
   148  	case "1":
   149  		version = quic.Version1
   150  	case "2":
   151  		version = quic.Version2
   152  	default:
   153  		Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam))
   154  	}
   155  	fmt.Printf("Using QUIC version: %s\n", version)
   156  	protocol.SupportedVersions = []quic.VersionNumber{version}
   157  })
   158  
   159  func getTLSConfig() *tls.Config {
   160  	return tlsConfig.Clone()
   161  }
   162  
   163  func getTLSConfigWithLongCertChain() *tls.Config {
   164  	return tlsConfigLongChain.Clone()
   165  }
   166  
   167  func getTLSClientConfig() *tls.Config {
   168  	return tlsClientConfig.Clone()
   169  }
   170  
   171  func getTLSClientConfigWithoutServerName() *tls.Config {
   172  	return tlsClientConfigWithoutServerName.Clone()
   173  }
   174  
   175  func getQuicConfig(conf *quic.Config) *quic.Config {
   176  	if conf == nil {
   177  		conf = &quic.Config{}
   178  	} else {
   179  		conf = conf.Clone()
   180  	}
   181  	if enableQlog {
   182  		if conf.Tracer == nil {
   183  			conf.Tracer = qlogTracer
   184  		} else if qlogTracer != nil {
   185  			origTracer := conf.Tracer
   186  			conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
   187  				return logging.NewMultiplexedConnectionTracer(
   188  					qlogTracer(ctx, p, connID),
   189  					origTracer(ctx, p, connID),
   190  				)
   191  			}
   192  		}
   193  	}
   194  	return conf
   195  }
   196  
   197  var _ = BeforeEach(func() {
   198  	log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
   199  
   200  	if debugLog() {
   201  		logBufOnce.Do(func() {
   202  			logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))}
   203  		})
   204  		utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
   205  		log.SetOutput(logBuf)
   206  	}
   207  })
   208  
   209  func areHandshakesRunning() bool {
   210  	var b bytes.Buffer
   211  	pprof.Lookup("goroutine").WriteTo(&b, 1)
   212  	return strings.Contains(b.String(), "RunHandshake")
   213  }
   214  
   215  func areTransportsRunning() bool {
   216  	var b bytes.Buffer
   217  	pprof.Lookup("goroutine").WriteTo(&b, 1)
   218  	return strings.Contains(b.String(), "quic-go.(*Transport).listen")
   219  }
   220  
   221  var _ = AfterEach(func() {
   222  	Expect(areHandshakesRunning()).To(BeFalse())
   223  	Eventually(areTransportsRunning).Should(BeFalse())
   224  
   225  	if debugLog() {
   226  		logFile, err := os.Create(logFileName)
   227  		Expect(err).ToNot(HaveOccurred())
   228  		logFile.Write(logBuf.Bytes())
   229  		logFile.Close()
   230  		logBuf.Reset()
   231  	}
   232  })
   233  
   234  // Debug says if this test is being logged
   235  func debugLog() bool {
   236  	return len(logFileName) > 0
   237  }
   238  
   239  func scaleDuration(d time.Duration) time.Duration {
   240  	scaleFactor := 1
   241  	if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
   242  		scaleFactor = f
   243  	}
   244  	Expect(scaleFactor).ToNot(BeZero())
   245  	return time.Duration(scaleFactor) * d
   246  }
   247  
   248  func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
   249  	return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer }
   250  }
   251  
// packet is one traced long header packet: when it was traced, its header,
// and the frames it carried. packetTracer uses it for received packets.
type packet struct {
	time   time.Time
	hdr    *logging.ExtendedHeader
	frames []logging.Frame
}

// shortHeaderPacket is one traced short header packet (sent or received):
// when it was traced, its header, and the frames it carried.
type shortHeaderPacket struct {
	time   time.Time
	hdr    *logging.ShortHeader
	frames []logging.Frame
}
   263  
   264  type packetTracer struct {
   265  	logging.NullConnectionTracer
   266  	closed                     chan struct{}
   267  	sentShortHdr, rcvdShortHdr []shortHeaderPacket
   268  	rcvdLongHdr                []packet
   269  }
   270  
   271  func newPacketTracer() *packetTracer {
   272  	return &packetTracer{closed: make(chan struct{})}
   273  }
   274  
   275  func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, frames []logging.Frame) {
   276  	t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
   277  }
   278  
   279  func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, frames []logging.Frame) {
   280  	t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
   281  }
   282  
   283  func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, ack *wire.AckFrame, frames []logging.Frame) {
   284  	if ack != nil {
   285  		frames = append(frames, ack)
   286  	}
   287  	t.sentShortHdr = append(t.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
   288  }
   289  
   290  func (t *packetTracer) Close() { close(t.closed) }
   291  
   292  func (t *packetTracer) getSentShortHeaderPackets() []shortHeaderPacket {
   293  	<-t.closed
   294  	return t.sentShortHdr
   295  }
   296  
   297  func (t *packetTracer) getRcvdLongHeaderPackets() []packet {
   298  	<-t.closed
   299  	return t.rcvdLongHdr
   300  }
   301  
   302  func (t *packetTracer) getRcvdShortHeaderPackets() []shortHeaderPacket {
   303  	<-t.closed
   304  	return t.rcvdShortHdr
   305  }
   306  
// TestSelf is the `go test` entry point for this Ginkgo suite. It routes
// Gomega assertion failures to Ginkgo's Fail and runs all registered specs.
func TestSelf(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Self integration tests")
}