github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/integrationtests/self/self_suite_test.go (about) 1 package self_test 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "crypto/x509" 8 "flag" 9 "fmt" 10 "log" 11 "os" 12 "runtime/pprof" 13 "strconv" 14 "strings" 15 "sync" 16 "testing" 17 "time" 18 19 "github.com/apernet/quic-go" 20 "github.com/apernet/quic-go/integrationtests/tools" 21 "github.com/apernet/quic-go/internal/protocol" 22 "github.com/apernet/quic-go/internal/utils" 23 "github.com/apernet/quic-go/internal/wire" 24 "github.com/apernet/quic-go/logging" 25 26 . "github.com/onsi/ginkgo/v2" 27 . "github.com/onsi/gomega" 28 ) 29 30 const alpn = tools.ALPN 31 32 const ( 33 dataLen = 500 * 1024 // 500 KB 34 dataLenLong = 50 * 1024 * 1024 // 50 MB 35 ) 36 37 var ( 38 // PRData contains dataLen bytes of pseudo-random data. 39 PRData = GeneratePRData(dataLen) 40 // PRDataLong contains dataLenLong bytes of pseudo-random data. 41 PRDataLong = GeneratePRData(dataLenLong) 42 ) 43 44 // See https://en.wikipedia.org/wiki/Lehmer_random_number_generator 45 func GeneratePRData(l int) []byte { 46 res := make([]byte, l) 47 seed := uint64(1) 48 for i := 0; i < l; i++ { 49 seed = seed * 48271 % 2147483647 50 res[i] = byte(seed) 51 } 52 return res 53 } 54 55 const logBufSize = 100 * 1 << 20 // initial size of the log buffer: 100 MB 56 57 type syncedBuffer struct { 58 mutex sync.Mutex 59 60 *bytes.Buffer 61 } 62 63 func (b *syncedBuffer) Write(p []byte) (int, error) { 64 b.mutex.Lock() 65 n, err := b.Buffer.Write(p) 66 b.mutex.Unlock() 67 return n, err 68 } 69 70 func (b *syncedBuffer) Bytes() []byte { 71 b.mutex.Lock() 72 p := b.Buffer.Bytes() 73 b.mutex.Unlock() 74 return p 75 } 76 77 func (b *syncedBuffer) Reset() { 78 b.mutex.Lock() 79 b.Buffer.Reset() 80 b.mutex.Unlock() 81 } 82 83 var ( 84 logFileName string // the log file set in the ginkgo flags 85 logBufOnce sync.Once 86 logBuf *syncedBuffer 87 versionParam string 88 89 enableQlog bool 90 91 version quic.Version 92 tlsConfig *tls.Config 93 tlsConfigLongChain *tls.Config 94 tlsClientConfig *tls.Config 95 tlsClientConfigWithoutServerName *tls.Config 96 ) 97 98 // read the logfile command line flag 99 // to set call ginkgo -- -logfile=log.txt 100 func init() { 101 flag.StringVar(&logFileName, "logfile", "", "log file") 102 flag.StringVar(&versionParam, "version", "1", "QUIC version") 103 flag.BoolVar(&enableQlog, "qlog", false, "enable qlog") 104 105 ca, caPrivateKey, err := tools.GenerateCA() 106 if err != nil { 107 panic(err) 108 } 109 leafCert, leafPrivateKey, err := tools.GenerateLeafCert(ca, caPrivateKey) 110 if err != nil { 111 panic(err) 112 } 113 tlsConfig = &tls.Config{ 114 Certificates: []tls.Certificate{{ 115 Certificate: [][]byte{leafCert.Raw}, 116 PrivateKey: leafPrivateKey, 117 }}, 118 NextProtos: []string{alpn}, 119 } 120 tlsConfLongChain, err := tools.GenerateTLSConfigWithLongCertChain(ca, caPrivateKey) 121 if err != nil { 122 panic(err) 123 } 124 tlsConfigLongChain = tlsConfLongChain 125 126 root := x509.NewCertPool() 127 root.AddCert(ca) 128 tlsClientConfig = &tls.Config{ 129 ServerName: "localhost", 130 RootCAs: root, 131 NextProtos: []string{alpn}, 132 } 133 tlsClientConfigWithoutServerName = &tls.Config{ 134 RootCAs: root, 135 NextProtos: []string{alpn}, 136 } 137 } 138 139 var _ = BeforeSuite(func() { 140 switch versionParam { 141 case "1": 142 version = quic.Version1 143 case "2": 144 version = quic.Version2 145 default: 146 Fail(fmt.Sprintf("unknown QUIC version: %s", versionParam)) 147 } 148 fmt.Printf("Using QUIC version: %s\n", version) 149 protocol.SupportedVersions = []quic.Version{version} 150 }) 151 152 func getTLSConfig() *tls.Config { 153 return tlsConfig.Clone() 154 } 155 156 func getTLSConfigWithLongCertChain() *tls.Config { 157 return tlsConfigLongChain.Clone() 158 } 159 160 func getTLSClientConfig() *tls.Config { 161 return tlsClientConfig.Clone() 162 } 163 164 func getTLSClientConfigWithoutServerName() *tls.Config { 165 return tlsClientConfigWithoutServerName.Clone() 166 } 167 168 func getQuicConfig(conf *quic.Config) *quic.Config { 169 if conf == nil { 170 conf = &quic.Config{} 171 } else { 172 conf = conf.Clone() 173 } 174 if !enableQlog { 175 return conf 176 } 177 if conf.Tracer == nil { 178 conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { 179 return logging.NewMultiplexedConnectionTracer( 180 tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID), 181 // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere 182 &logging.ConnectionTracer{}, 183 ) 184 } 185 return conf 186 } 187 origTracer := conf.Tracer 188 conf.Tracer = func(ctx context.Context, p logging.Perspective, connID quic.ConnectionID) *logging.ConnectionTracer { 189 return logging.NewMultiplexedConnectionTracer( 190 tools.NewQlogConnectionTracer(GinkgoWriter)(ctx, p, connID), 191 origTracer(ctx, p, connID), 192 ) 193 } 194 return conf 195 } 196 197 func addTracer(tr *quic.Transport) { 198 if !enableQlog { 199 return 200 } 201 if tr.Tracer == nil { 202 tr.Tracer = logging.NewMultiplexedTracer( 203 tools.QlogTracer(GinkgoWriter), 204 // multiplex it with an empty tracer to check that we're correctly ignoring unset callbacks everywhere 205 &logging.Tracer{}, 206 ) 207 return 208 } 209 origTracer := tr.Tracer 210 tr.Tracer = logging.NewMultiplexedTracer( 211 tools.QlogTracer(GinkgoWriter), 212 origTracer, 213 ) 214 } 215 216 var _ = BeforeEach(func() { 217 log.SetFlags(log.Ldate | log.Ltime | log.Lmicroseconds) 218 219 if debugLog() { 220 logBufOnce.Do(func() { 221 logBuf = &syncedBuffer{Buffer: bytes.NewBuffer(make([]byte, 0, logBufSize))} 222 }) 223 utils.DefaultLogger.SetLogLevel(utils.LogLevelDebug) 224 log.SetOutput(logBuf) 225 } 226 }) 227 228 func areHandshakesRunning() bool { 229 var b bytes.Buffer 230 pprof.Lookup("goroutine").WriteTo(&b, 1) 231 return strings.Contains(b.String(), "RunHandshake") 232 } 233 234 func areTransportsRunning() bool { 235 var b bytes.Buffer 236 pprof.Lookup("goroutine").WriteTo(&b, 1) 237 return strings.Contains(b.String(), "quic-go.(*Transport).listen") 238 } 239 240 var _ = AfterEach(func() { 241 Expect(areHandshakesRunning()).To(BeFalse()) 242 Eventually(areTransportsRunning).Should(BeFalse()) 243 244 if debugLog() { 245 logFile, err := os.Create(logFileName) 246 Expect(err).ToNot(HaveOccurred()) 247 logFile.Write(logBuf.Bytes()) 248 logFile.Close() 249 logBuf.Reset() 250 } 251 }) 252 253 // Debug says if this test is being logged 254 func debugLog() bool { 255 return len(logFileName) > 0 256 } 257 258 func scaleDuration(d time.Duration) time.Duration { 259 scaleFactor := 1 260 if f, err := strconv.Atoi(os.Getenv("TIMESCALE_FACTOR")); err == nil { // parsing "" errors, so this works fine if the env is not set 261 scaleFactor = f 262 } 263 Expect(scaleFactor).ToNot(BeZero()) 264 return time.Duration(scaleFactor) * d 265 } 266 267 func newTracer(tracer *logging.ConnectionTracer) func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { 268 return func(context.Context, logging.Perspective, quic.ConnectionID) *logging.ConnectionTracer { return tracer } 269 } 270 271 type packet struct { 272 time time.Time 273 hdr *logging.ExtendedHeader 274 frames []logging.Frame 275 } 276 277 type shortHeaderPacket struct { 278 time time.Time 279 hdr *logging.ShortHeader 280 frames []logging.Frame 281 } 282 283 type packetCounter struct { 284 closed chan struct{} 285 sentShortHdr, rcvdShortHdr []shortHeaderPacket 286 rcvdLongHdr []packet 287 } 288 289 func (t *packetCounter) getSentShortHeaderPackets() []shortHeaderPacket { 290 <-t.closed 291 return t.sentShortHdr 292 } 293 294 func (t *packetCounter) getRcvdLongHeaderPackets() []packet { 295 <-t.closed 296 return t.rcvdLongHdr 297 } 298 299 func (t *packetCounter) getRcvdShortHeaderPackets() []shortHeaderPacket { 300 <-t.closed 301 return t.rcvdShortHdr 302 } 303 304 func newPacketTracer() (*packetCounter, *logging.ConnectionTracer) { 305 c := &packetCounter{closed: make(chan struct{})} 306 return c, &logging.ConnectionTracer{ 307 ReceivedLongHeaderPacket: func(hdr *logging.ExtendedHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { 308 c.rcvdLongHdr = append(c.rcvdLongHdr, packet{time: time.Now(), hdr: hdr, frames: frames}) 309 }, 310 ReceivedShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, frames []logging.Frame) { 311 c.rcvdShortHdr = append(c.rcvdShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) 312 }, 313 SentShortHeaderPacket: func(hdr *logging.ShortHeader, _ logging.ByteCount, _ logging.ECN, ack *wire.AckFrame, frames []logging.Frame) { 314 if ack != nil { 315 frames = append(frames, ack) 316 } 317 c.sentShortHdr = append(c.sentShortHdr, shortHeaderPacket{time: time.Now(), hdr: hdr, frames: frames}) 318 }, 319 Close: func() { close(c.closed) }, 320 } 321 } 322 323 func TestSelf(t *testing.T) { 324 RegisterFailHandler(Fail) 325 RunSpecs(t, "Self integration tests") 326 }