github.com/TugasAkhir-QUIC/quic-go@v0.0.2-0.20240215011318-d20e25a9054c/integrationtests/self/self_suite_test.go (about)

     1  package self_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"flag"
     9  	"fmt"
    10  	"log"
    11  	"os"
    12  	"runtime/pprof"
    13  	"strconv"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"github.com/TugasAkhir-QUIC/quic-go"
    20  	"github.com/TugasAkhir-QUIC/quic-go/integrationtests/tools"
    21  	"github.com/TugasAkhir-QUIC/quic-go/internal/protocol"
    22  	"github.com/TugasAkhir-QUIC/quic-go/internal/utils"
    23  	"github.com/TugasAkhir-QUIC/quic-go/internal/wire"
    24  	"github.com/TugasAkhir-QUIC/quic-go/logging"
    25  
    26  	. "github.com/onsi/ginkgo/v2"
    27  	. "github.com/onsi/gomega"
    28  )
    29  
    30  const alpn = tools.ALPN
    31  
    32  const (
    33  	dataLen     = 500 * 1024       // 500 KB
    34  	dataLenLong = 50 * 1024 * 1024 // 50 MB
    35  )
    36  
    37  var (
    38  	// PRData contains dataLen bytes of pseudo-random data.
    39  	PRData = GeneratePRData(dataLen)
    40  	// PRDataLong contains dataLenLong bytes of pseudo-random data.
    41  	PRDataLong = GeneratePRData(dataLenLong)
    42  )
    43  
    44  // See https://en.wikipedia.org/wiki/Lehmer_random_number_generator
    45  func GeneratePRData(l int) []byte {
    46  	res := make([]byte, l)
    47  	seed := uint64(1)
    48  	for i := 0; i < l; i++ {
    49  		seed = seed * 48271 % 2147483647
    50  		res[i] = byte(seed)
    51  	}
    52  	return res
    53  }
    54  
    55  const logBufSize = 100 * 1 << 20 // initial size of the log buffer: 100 MB
    56  
    57  type syncedBuffer struct {
    58  	mutex sync.Mutex
    59  
    60  	*bytes.Buffer
    61  }
    62  
    63  func (b *syncedBuffer) Write(p []byte) (int, error) {
    64  	b.mutex.Lock()
    65  	n, err := b.Buffer.Write(p)
    66  	b.mutex.Unlock()
    67  	return n, err
    68  }
    69  
    70  func (b *syncedBuffer) Bytes() []byte {
    71  	b.mutex.Lock()
    72  	p := b.Buffer.Bytes()
    73  	b.mutex.Unlock()
    74  	return p
    75  }
    76  
    77  func (b *syncedBuffer) Reset() {
    78  	b.mutex.Lock()
    79  	b.Buffer.Reset()
    80  	b.mutex.Unlock()
    81  }
    82  
    83  var (
    84  	logFileName  string // the log file set in the ginkgo flags
    85  	logBufOnce   sync.Once
    86  	logBuf       *syncedBuffer
    87  	versionParam string
    88  
    89  	qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer
    90  	enableQlog bool
    91  
    92  	version                          quic.Version
    93  	tlsConfig                        *tls.Config
    94  	tlsConfigLongChain               *tls.Config
    95  	tlsClientConfig                  *tls.Config
    96  	tlsClientConfigWithoutServerName *tls.Config
    97  )
    98  
    99  // read the logfile command line flag
   100  // to set call ginkgo -- -logfile=log.txt
   101  func init() {
   102  	flag.StringVar(&logFileName, "logfile", "", "log file")
   103  	flag.StringVar(&versionParam, "version", "1", "QUIC version")
   104  	flag.BoolVar(&enableQlog, "qlog", false, "enable qlog")
   105  
   106  	ca, caPrivateKey, err := tools.GenerateCA()
   107  	if err != nil {
   108  		panic(err)
   109  	}
   110  	leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey)
   111  	if err != nil {
   112  		panic(err)
   113  	}
   114  	tlsConfig = &tls.Config{
   115  		Certificates: []tls.Certificate{{
   116  			Certificate: [][]byte{leafCert.Raw},
   117  			PrivateKey:  leafPrivateKey,
   118  		}},
   119  		NextProtos: []string{alpn},
   120  	}
   121  	tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey)
   122  	if err != nil {
   123  		panic(err)
   124  	}
   125  	tlsConfigLongChain = tlsConfLongChain
   126  
   127  	root := x509.NewCertPool()
   128  	root.AddCert(ca)
   129  	tlsClientConfig = &tls.Config{
   130  		ServerName: "localhost",
   131  		RootCAs:    root,
   132  		NextProtos: []string{alpn},
   133  	}
   134  	tlsClientConfigWithoutServerName = &tls.Config{
   135  		RootCAs:    root,
   136  		NextProtos: []string{alpn},
   137  	}
   138  }
   139  
   140  var _ = BeforeSuite(func() {
   141  	if enableQlog {
   142  		qlogTracer = tools.NewQlogger(GinkgoWriter)
   143  	}
   144  	switch versionParam {
   145  	case "1":
   146  		version = quic.Version1
   147  	case "2":
   148  		version = quic.Version2
   149  	default:
   150  		Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam))
   151  	}
   152  	fmt.Printf("Using QUIC version: %s\n", version)
   153  	protocol.SupportedVersions = []quic.Version{version}
   154  })
   155  
   156  func getTLSConfig() *tls.Config {
   157  	return tlsConfig.Clone()
   158  }
   159  
   160  func getTLSConfigWithLongCertChain() *tls.Config {
   161  	return tlsConfigLongChain.Clone()
   162  }
   163  
   164  func getTLSClientConfig() *tls.Config {
   165  	return tlsClientConfig.Clone()
   166  }
   167  
   168  func getTLSClientConfigWithoutServerName() *tls.Config {
   169  	return tlsClientConfigWithoutServerName.Clone()
   170  }
   171  
   172  func getQuicConfig(conf *quic.Config) *quic.Config {
   173  	if conf == nil {
   174  		conf = &quic.Config{}
   175  	} else {
   176  		conf = conf.Clone()
   177  	}
   178  	if enableQlog {
   179  		if conf.Tracer == nil {
   180  			conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
   181  				return logging.NewMultiplexedConnectionTracer(
   182  					qlogTracer(ctx, p, connID),
   183  					// multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere
   184  					&logging.ConnectionTracer{},
   185  				)
   186  			}
   187  		} else if qlogTracer != nil {
   188  			origTracer := conf.Tracer
   189  			conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer {
   190  				return logging.NewMultiplexedConnectionTracer(
   191  					qlogTracer(ctx, p, connID),
   192  					origTracer(ctx, p, connID),
   193  				)
   194  			}
   195  		}
   196  	}
   197  	return conf
   198  }
   199  
   200  var _ = BeforeEach(func() {
   201  	log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds)
   202  
   203  	if debugLog() {
   204  		logBufOnce.Do(func() {
   205  			logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))}
   206  		})
   207  		utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug)
   208  		log.SetOutput(logBuf)
   209  	}
   210  })
   211  
   212  func areHandshakesRunning() bool {
   213  	var b bytes.Buffer
   214  	pprof.Lookup("goroutine").WriteTo(&b, 1)
   215  	return strings.Contains(b.String(), "RunHandshake")
   216  }
   217  
   218  func areTransportsRunning() bool {
   219  	var b bytes.Buffer
   220  	pprof.Lookup("goroutine").WriteTo(&b, 1)
   221  	return strings.Contains(b.String(), "quic-go.(*Transport).listen")
   222  }
   223  
   224  var _ = AfterEach(func() {
   225  	Expect(areHandshakesRunning()).To(BeFalse())
   226  	Eventually(areTransportsRunning).Should(BeFalse())
   227  
   228  	if debugLog() {
   229  		logFile, err := os.Create(logFileName)
   230  		Expect(err).ToNot(HaveOccurred())
   231  		logFile.Write(logBuf.Bytes())
   232  		logFile.Close()
   233  		logBuf.Reset()
   234  	}
   235  })
   236  
   237  // Debug says if this test is being logged
   238  func debugLog() bool {
   239  	return len(logFileName) > 0
   240  }
   241  
   242  func scaleDuration(d time.Duration) time.Duration {
   243  	scaleFactor := 1
   244  	if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set
   245  		scaleFactor = f
   246  	}
   247  	Expect(scaleFactor).ToNot(BeZero())
   248  	return time.Duration(scaleFactor) * d
   249  }
   250  
   251  func newTracer(tracer *logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer {
   252  	return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return tracer }
   253  }
   254  
   255  type packet struct {
   256  	time   time.Time
   257  	hdr    *logging.ExtendedHeader
   258  	frames []logging.Frame
   259  }
   260  
   261  type shortHeaderPacket struct {
   262  	time   time.Time
   263  	hdr    *logging.ShortHeader
   264  	frames []logging.Frame
   265  }
   266  
   267  type packetCounter struct {
   268  	closed                     chan struct{}
   269  	sentShortHdr, rcvdShortHdr []shortHeaderPacket
   270  	rcvdLongHdr                []packet
   271  }
   272  
   273  func (t *packetCounter) getSentShortHeaderPackets() []shortHeaderPacket {
   274  	<-t.closed
   275  	return t.sentShortHdr
   276  }
   277  
   278  func (t *packetCounter) getRcvdLongHeaderPackets() []packet {
   279  	<-t.closed
   280  	return t.rcvdLongHdr
   281  }
   282  
   283  func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket {
   284  	<-t.closed
   285  	return t.rcvdShortHdr
   286  }
   287  
   288  func newPacketTracer() (*packetCounter, *logging.ConnectionTracer) {
   289  	c := &packetCounter{closed: make(chan struct{})}
   290  	return c, &logging.ConnectionTracer{
   291  		ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
   292  			c.rcvdLongHdr = append(c.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames})
   293  		},
   294  		ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) {
   295  			c.rcvdShortHdr = append(c.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
   296  		},
   297  		SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) {
   298  			if ack != nil {
   299  				frames = append(frames, ack)
   300  			}
   301  			c.sentShortHdr = append(c.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames})
   302  		},
   303  		Close: func() { close(c.closed) },
   304  	}
   305  }
   306  
   307  func TestSelf(t *testing.T) {
   308  	RegisterFailHandler(Fail)
   309  	RunSpecs(t, "Self integration tests")
   310  }