github.com/tumi8/quic-go@v0.37.4-tum/integrationtests/self/self_suite_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"flag"
     9  	"fmt"
    10  	"log"
    11  	"os"
    12  	"runtime/pprof"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/tumi8/quic-go"
    20  	"github.com/tumi8/quic-go/integrationtests/tools"
    21  	"github.com/tumi8/quic-go/noninternal/protocol"
    22  	"github.com/tumi8/quic-go/noninternal/utils"
    23  	"github.com/tumi8/quic-go/noninternal/wire"
    24  	"github.com/tumi8/quic-go/logging"
    25  
    26  	. "github.com/onsi/ginkgo/v2"
    27  	. "github.com/onsi/gomega"
    28  )
    29  
// alpn is the ALPN protocol name placed in NextProtos of every TLS config
// built by this suite; it aliases tools.ALPN.
const alpn = tools.ALPN
    31  
// Sizes, in bytes, of the pseudo-random payloads PRData and PRDataLong.
const (
	dataLen     = 500 * 1024       // 500 KB
	dataLenLong = 50 * 1024 * 1024 // 50 MB
)
    36  
// Shared test payloads. Both are produced by GeneratePRData during package
// variable initialization, so their contents are the same on every run.
var (
	// PRData contains dataLen bytes of pseudo-random data.
	PRData = GeneratePRData(dataLen)
	// PRDataLong contains dataLenLong bytes of pseudo-random data.
	PRDataLong = GeneratePRData(dataLenLong)
)
    43  
    44  // See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
    45  func GeneratePRData(l int) []byte {
    46  	res := make([]byte, l)
    47  	seed := uint64(1)
    48  	for i := 0; i < l; i++ {
    49  		seed = seed * 48271 % 2147483647
    50  		res[i] = byte(seed)
    51  	}
    52  	return res
    53  }
    54  
    55  const logBufSize = 100 * 1 << 20 // initial size of the log buffer: 100 MB
    56  
    57  type syncedBuffer struct {
    58  	mutex sync.Mutex
    59  
    60  	*bytes.Buffer
    61  }
    62  
    63  func (b *syncedBuffer) Write(p []byte) (int, error) {
    64  	b.mutex.Lock()
    65  	n, err := b.Buffer.Write(p)
    66  	b.mutex.Unlock()
    67  	return n, err
    68  }
    69  
    70  func (b *syncedBuffer) Bytes() []byte {
    71  	b.mutex.Lock()
    72  	p := b.Buffer.Bytes()
    73  	b.mutex.Unlock()
    74  	return p
    75  }
    76  
    77  func (b *syncedBuffer) Reset() {
    78  	b.mutex.Lock()
    79  	b.Buffer.Reset()
    80  	b.mutex.Unlock()
    81  }
    82  
// Suite-wide state. The flag-backed values are registered in init, the QUIC
// version and qlog tracer are resolved in BeforeSuite, and the TLS configs
// are built once in init and handed out as clones by the getters below.
var (
	logFileName  string // the log file set in the ginkgo flags
	logBufOnce   sync.Once // guards the one-time allocation of logBuf
	logBuf       *syncedBuffer // receives log output when a log file is configured
	versionParam string // raw value of the -version flag

	// qlogTracer is set in BeforeSuite when enableQlog is true.
	qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer
	enableQlog bool // value of the -qlog flag

	version                          quic.VersionNumber // QUIC version selected from versionParam
	tlsConfig                        *tls.Config // server config with the generated leaf certificate
	tlsConfigLongChain               *tls.Config // server config from tools.GenerateTLSConfigWithLongCertChain
	tlsClientConfig                  *tls.Config // client config trusting the generated CA, ServerName "localhost"
	tlsClientConfigWithoutServerName *tls.Config // same as tlsClientConfig but with no ServerName set
)
    98  
    99  // read the logfile command line flag
   100  // to set call ginkgo -- -logfile=log.txt
   101  func init() {
   102  	flag.StringVar(&logFileName, "logfile", "", "log file")
   103  	flag.StringVar(&versionParam, "version", "1", "QUIC version")
   104  	flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
   105  
   106  	ca, caPrivateKey, err := tools.GenerateCA()
   107  	if err != nil {
   108  		panic(err)
   109  	}
   110  	leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey)
   111  	if err != nil {
   112  		panic(err)
   113  	}
   114  	tlsConfig = &tls.Config{
   115  		Certificates: []tls.Certificate{{
   116  			Certificate: [][]byte{leafCert.Raw},
   117  			PrivateKey:  leafPrivateKey,
   118  		}},
   119  		NextProtos: []string{alpn},
   120  	}
   121  	tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey)
   122  	if err != nil {
   123  		panic(err)
   124  	}
   125  	tlsConfigLongChain = tlsConfLongChain
   126  
   127  	root := x509.NewCertPool()
   128  	root.AddCert(ca)
   129  	tlsClientConfig = &tls.Config{
   130  		ServerName: "localhost",
   131  		RootCAs:    root,
   132  		NextProtos: []string{alpn},
   133  	}
   134  	tlsClientConfigWithoutServerName = &tls.Config{
   135  		RootCAs:    root,
   136  		NextProtos: []string{alpn},
   137  	}
   138  }
   139  
   140  var _ = BeforeSuite(func() {
   141  	if enableQlog {
   142  		qlogTracer = tools.NewQlogger(GinkgoWriter)
   143  	}
   144  	switch versionParam {
   145  	case "1":
   146  		version = quic.Version1
   147  	case "2":
   148  		version = quic.Version2
   149  	default:
   150  		Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam))
   151  	}
   152  	fmt.Printf("Using QUIC version: %s\n", version)
   153  	protocol.SupportedVersions = []quic.VersionNumber{version}
   154  })
   155  
// getTLSConfig returns a clone of the shared server TLS config, so callers
// can modify the result without affecting other tests.
func getTLSConfig() *tls.Config {
	return tlsConfig.Clone()
}
   159  
// getTLSConfigWithLongCertChain returns a clone of the shared server TLS
// config that was built with a long certificate chain.
func getTLSConfigWithLongCertChain() *tls.Config {
	return tlsConfigLongChain.Clone()
}
   163  
// getTLSClientConfig returns a clone of the shared client TLS config, which
// trusts the generated CA and sets ServerName to "localhost".
func getTLSClientConfig() *tls.Config {
	return tlsClientConfig.Clone()
}
   167  
// getTLSClientConfigWithoutServerName returns a clone of the shared client
// TLS config that trusts the generated CA but leaves ServerName unset.
func getTLSClientConfigWithoutServerName() *tls.Config {
	return tlsClientConfigWithoutServerName.Clone()
}
   171  
   172  func getQuicConfig(conf *quic.Config) *quic.Config {
   173  	if conf == nil {
   174  		conf = &quic.Config{}
   175  	} else {
   176  		conf = conf.Clone()
   177  	}
   178  	if enableQlog {
   179  		if conf.Tracer == nil {
   180  			conf.Tracer = qlogTracer
   181  		} else if qlogTracer != nil {
   182  			origTracer := conf.Tracer
   183  			conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer {
   184  				return logging.NewMultiplexedConnectionTracer(
   185  					qlogTracer(ctx, p, connID),
   186  					origTracer(ctx, p, connID),
   187  				)
   188  			}
   189  		}
   190  	}
   191  	return conf
   192  }
   193  
   194  var _ = BeforeEach(func() {
   195  	log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
   196  
   197  	if debugLog() {
   198  		logBufOnce.Do(func() {
   199  			logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))}
   200  		})
   201  		utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
   202  		log.SetOutput(logBuf)
   203  	}
   204  })
   205  
   206  func areHandshakesRunning() bool {
   207  	var b bytes.Buffer
   208  	pprof.Lookup("goroutine").WriteTo(&b, 1)
   209  	return strings.Contains(b.String(), "RunHandshake")
   210  }
   211  
   212  func areTransportsRunning() bool {
   213  	var b bytes.Buffer
   214  	pprof.Lookup("goroutine").WriteTo(&b, 1)
   215  	return strings.Contains(b.String(), "quic-go.(*Transport).listen")
   216  }
   217  
   218  var _ = AfterEach(func() {
   219  	Expect(areHandshakesRunning()).To(BeFalse())
   220  	Eventually(areTransportsRunning).Should(BeFalse())
   221  
   222  	if debugLog() {
   223  		logFile, err := os.Create(logFileName)
   224  		Expect(err).ToNot(HaveOccurred())
   225  		logFile.Write(logBuf.Bytes())
   226  		logFile.Close()
   227  		logBuf.Reset()
   228  	}
   229  })
   230  
   231  // Debug says if this test is being logged
   232  func debugLog() bool {
   233  	return len(logFileName) > 0
   234  }
   235  
   236  func scaleDuration(d time.Duration) time.Duration {
   237  	scaleFactor := 1
   238  	if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
   239  		scaleFactor = f
   240  	}
   241  	Expect(scaleFactor).ToNot(BeZero())
   242  	return time.Duration(scaleFactor) * d
   243  }
   244  
// newTracer adapts a single ConnectionTracer to the tracer-factory signature
// used for quic.Config.Tracer in this file. The returned function ignores its
// arguments and always returns the same tracer.
func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer {
	return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer }
}
   248  
// packet is a long header packet as recorded by packetTracer.
type packet struct {
	time   time.Time // when the tracer callback ran
	hdr    *logging.ExtendedHeader
	frames []logging.Frame
}
   254  
// shortHeaderPacket is a short header packet as recorded by packetTracer.
type shortHeaderPacket struct {
	time   time.Time // when the tracer callback ran
	hdr    *logging.ShortHeader
	frames []logging.Frame
}
   260  
// packetTracer is a ConnectionTracer that records sent short header packets
// and received long and short header packets. All other tracer callbacks fall
// through to the embedded NullConnectionTracer.
//
// The recorded slices are read through the get* methods, which block until
// Close has closed the closed channel.
type packetTracer struct {
	logging.NullConnectionTracer
	closed                     chan struct{} // closed by Close
	sentShortHdr, rcvdShortHdr []shortHeaderPacket
	rcvdLongHdr                []packet
}
   267  
// newPacketTracer returns an empty packetTracer with an open closed channel.
func newPacketTracer() *packetTracer {
	return &packetTracer{closed: make(chan struct{})}
}
   271  
   272  func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, frames []logging.Frame) {
   273  	t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
   274  }
   275  
   276  func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, frames []logging.Frame) {
   277  	t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
   278  }
   279  
   280  func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, ack *wire.AckFrame, frames []logging.Frame) {
   281  	if ack != nil {
   282  		frames = append(frames, ack)
   283  	}
   284  	t.sentShortHdr = append(t.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
   285  }
   286  
// Close closes the closed channel, which unblocks the get* methods. The
// channel is closed unconditionally, so a second call would panic.
func (t *packetTracer) Close() { close(t.closed) }
   288  
// getSentShortHeaderPackets blocks until Close has been called, then returns
// the recorded sent short header packets.
func (t *packetTracer) getSentShortHeaderPackets() []shortHeaderPacket {
	<-t.closed
	return t.sentShortHdr
}
   293  
// getRcvdLongHeaderPackets blocks until Close has been called, then returns
// the recorded received long header packets.
func (t *packetTracer) getRcvdLongHeaderPackets() []packet {
	<-t.closed
	return t.rcvdLongHdr
}
   298  
// getRcvdShortHeaderPackets blocks until Close has been called, then returns
// the recorded received short header packets.
func (t *packetTracer) getRcvdShortHeaderPackets() []shortHeaderPacket {
	<-t.closed
	return t.rcvdShortHdr
}
   303  
// TestSelf is the go test entry point for this suite. It registers Fail as
// the assertion failure handler and runs all specs under the name
// "Self integration tests".
func TestSelf(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "Self integration tests")
}