github.com/mikelsr/quic-go@v0.36.1-0.20230701132136-1d9415b66898/integrationtests/self/self_suite_test.go (about) 1 package self_test 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "crypto/x509" 8 "flag" 9 "fmt" 10 "log" 11 mrand "math/rand" 12 "os" 13 "runtime/pprof" 14 "strconv" 15 "strings" 16 "sync" 17 "testing" 18 "time" 19 20 "github.com/mikelsr/quic-go" 21 "github.com/mikelsr/quic-go/integrationtests/tools" 22 "github.com/mikelsr/quic-go/internal/protocol" 23 "github.com/mikelsr/quic-go/internal/utils" 24 "github.com/mikelsr/quic-go/internal/wire" 25 "github.com/mikelsr/quic-go/logging" 26 27 . "github.com/onsi/ginkgo/v2" 28 . "github.com/onsi/gomega" 29 ) 30 31 const alpn = tools.ALPN 32 33 const ( 34 dataLen = 500 * 1024 // 500 KB 35 dataLenLong = 50 * 1024 * 1024 // 50 MB 36 ) 37 38 var ( 39 // PRData contains dataLen bytes of pseudo-random data. 40 PRData = GeneratePRData(dataLen) 41 // PRDataLong contains dataLenLong bytes of pseudo-random data. 42 PRDataLong = GeneratePRData(dataLenLong) 43 ) 44 45 // See https://en.wikipedia.org/wiki/Lehmer_random_number_generator 46 func GeneratePRData(l int) []byte { 47 res := make([]byte, l) 48 seed := uint64(1) 49 for i := 0; i < l; i++ { 50 seed = seed * 48271 % 2147483647 51 res[i] = byte(seed) 52 } 53 return res 54 } 55 56 const logBufSize = 100 * 1 << 20 // initial size of the log buffer: 100 MB 57 58 type syncedBuffer struct { 59 mutex sync.Mutex 60 61 *bytes.Buffer 62 } 63 64 func (b *syncedBuffer) Write(p []byte) (int, error) { 65 b.mutex.Lock() 66 n, err := b.Buffer.Write(p) 67 b.mutex.Unlock() 68 return n, err 69 } 70 71 func (b *syncedBuffer) Bytes() []byte { 72 b.mutex.Lock() 73 p := b.Buffer.Bytes() 74 b.mutex.Unlock() 75 return p 76 } 77 78 func (b *syncedBuffer) Reset() { 79 b.mutex.Lock() 80 b.Buffer.Reset() 81 b.mutex.Unlock() 82 } 83 84 var ( 85 logFileName string // the log file set in the ginkgo flags 86 logBufOnce sync.Once 87 logBuf *syncedBuffer 88 versionParam string 89 90 qlogTracer func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer 91 enableQlog bool 92 93 version quic.VersionNumber 94 tlsConfig *tls.Config 95 tlsConfigLongChain *tls.Config 96 tlsClientConfig *tls.Config 97 tlsClientConfigWithoutServerName *tls.Config 98 ) 99 100 // read the logfile command line flag 101 // to set call ginkgo -- -logfile=log.txt 102 func init() { 103 flag.StringVar(&logFileName, "logfile", "", "log file") 104 flag.StringVar(&versionParam, "version", "1", "QUIC version") 105 flag.BoolVar(&enableQlog, "qlog", false, "enable qlog") 106 107 ca, caPrivateKey, err := tools.GenerateCA() 108 if err != nil { 109 panic(err) 110 } 111 leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey) 112 if err != nil { 113 panic(err) 114 } 115 tlsConfig = &tls.Config{ 116 Certificates: []tls.Certificate{{ 117 Certificate: [][]byte{leafCert.Raw}, 118 PrivateKey: leafPrivateKey, 119 }}, 120 NextProtos: []string{alpn}, 121 } 122 tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey) 123 if err != nil { 124 panic(err) 125 } 126 tlsConfigLongChain = tlsConfLongChain 127 128 root := x509.NewCertPool() 129 root.AddCert(ca) 130 tlsClientConfig = &tls.Config{ 131 ServerName: "localhost", 132 RootCAs: root, 133 NextProtos: []string{alpn}, 134 } 135 tlsClientConfigWithoutServerName = &tls.Config{ 136 RootCAs: root, 137 NextProtos: []string{alpn}, 138 } 139 } 140 141 var _ = BeforeSuite(func() { 142 mrand.Seed(GinkgoRandomSeed()) 143 144 if enableQlog { 145 qlogTracer = tools.NewQlogger(GinkgoWriter) 146 } 147 switch versionParam { 148 case "1": 149 version = quic.Version1 150 case "2": 151 version = quic.Version2 152 default: 153 Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam)) 154 } 155 fmt.Printf("Using QUIC version: %s\n", version) 156 protocol.SupportedVersions = []quic.VersionNumber{version} 157 }) 158 159 func getTLSConfig() *tls.Config { 160 return tlsConfig.Clone() 161 } 162 163 func getTLSConfigWithLongCertChain() *tls.Config { 164 return tlsConfigLongChain.Clone() 165 } 166 167 func getTLSClientConfig() *tls.Config { 168 return tlsClientConfig.Clone() 169 } 170 171 func getTLSClientConfigWithoutServerName() *tls.Config { 172 return tlsClientConfigWithoutServerName.Clone() 173 } 174 175 func getQuicConfig(conf *quic.Config) *quic.Config { 176 if conf == nil { 177 conf = &quic.Config{} 178 } else { 179 conf = conf.Clone() 180 } 181 if enableQlog { 182 if conf.Tracer == nil { 183 conf.Tracer = qlogTracer 184 } else if qlogTracer != nil { 185 origTracer := conf.Tracer 186 conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) logging.ConnectionTracer { 187 return logging.NewMultiplexedConnectionTracer( 188 qlogTracer(ctx, p, connID), 189 origTracer(ctx, p, connID), 190 ) 191 } 192 } 193 } 194 return conf 195 } 196 197 var _ = BeforeEach(func() { 198 log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds) 199 200 if debugLog() { 201 logBufOnce.Do(func() { 202 logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))} 203 }) 204 utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug) 205 log.SetOutput(logBuf) 206 } 207 }) 208 209 func areHandshakesRunning() bool { 210 var b bytes.Buffer 211 pprof.Lookup("goroutine").WriteTo(&b, 1) 212 return strings.Contains(b.String(), "RunHandshake") 213 } 214 215 func areTransportsRunning() bool { 216 var b bytes.Buffer 217 pprof.Lookup("goroutine").WriteTo(&b, 1) 218 return strings.Contains(b.String(), "quic-go.(*Transport).listen") 219 } 220 221 var _ = AfterEach(func() { 222 Expect(areHandshakesRunning()).To(BeFalse()) 223 Eventually(areTransportsRunning).Should(BeFalse()) 224 225 if debugLog() { 226 logFile, err := os.Create(logFileName) 227 Expect(err).ToNot(HaveOccurred()) 228 logFile.Write(logBuf.Bytes()) 229 logFile.Close() 230 logBuf.Reset() 231 } 232 }) 233 234 // Debug says if this test is being logged 235 func debugLog() bool { 236 return len(logFileName) > 0 237 } 238 239 func scaleDuration(d time.Duration) time.Duration { 240 scaleFactor := 1 241 if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set 242 scaleFactor = f 243 } 244 Expect(scaleFactor).ToNot(BeZero()) 245 return time.Duration(scaleFactor) * d 246 } 247 248 func newTracer(tracer logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { 249 return func(context.Context, logging.Perspective, quic.ConnectionID) logging.ConnectionTracer { return tracer } 250 } 251 252 type packet struct { 253 time time.Time 254 hdr *logging.ExtendedHeader 255 frames []logging.Frame 256 } 257 258 type shortHeaderPacket struct { 259 time time.Time 260 hdr *logging.ShortHeader 261 frames []logging.Frame 262 } 263 264 type packetTracer struct { 265 logging.NullConnectionTracer 266 closed chan struct{} 267 sentShortHdr, rcvdShortHdr []shortHeaderPacket 268 rcvdLongHdr []packet 269 } 270 271 func newPacketTracer() *packetTracer { 272 return &packetTracer{closed: make(chan struct{})} 273 } 274 275 func (t *packetTracer) ReceivedLongHeaderPacket(hdr *logging.ExtendedHeader, _ logging.ByteCount, frames []logging.Frame) { 276 t.rcvdLongHdr = append(t.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames}) 277 } 278 279 func (t *packetTracer) ReceivedShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, frames []logging.Frame) { 280 t.rcvdShortHdr = append(t.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) 281 } 282 283 func (t *packetTracer) SentShortHeaderPacket(hdr *logging.ShortHeader, _ logging.ByteCount, ack *wire.AckFrame, frames []logging.Frame) { 284 if ack != nil { 285 frames = append(frames, ack) 286 } 287 t.sentShortHdr = append(t.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) 288 } 289 290 func (t *packetTracer) Close() { close(t.closed) } 291 292 func (t *packetTracer) getSentShortHeaderPackets() []shortHeaderPacket { 293 <-t.closed 294 return t.sentShortHdr 295 } 296 297 func (t *packetTracer) getRcvdLongHeaderPackets() []packet { 298 <-t.closed 299 return t.rcvdLongHdr 300 } 301 302 func (t *packetTracer) getRcvdShortHeaderPackets() []shortHeaderPacket { 303 <-t.closed 304 return t.rcvdShortHdr 305 } 306 307 func TestSelf(t *testing.T) { 308 RegisterFailHandler(Fail) 309 RunSpecs(t, "Self integration tests") 310 }