github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/interop/client/main.go (about)

     1  package main
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"net/http"
    11  	"os"
    12  	"strings"
    13  	"time"
    14  
    15  	"golang.org/x/sync/errgroup"
    16  
    17  	"github.com/apernet/quic-go"
    18  	"github.com/apernet/quic-go/http3"
    19  	"github.com/apernet/quic-go/internal/handshake"
    20  	"github.com/apernet/quic-go/internal/protocol"
    21  	"github.com/apernet/quic-go/internal/qtls"
    22  	"github.com/apernet/quic-go/interop/http09"
    23  	"github.com/apernet/quic-go/interop/utils"
    24  )
    25  
    26  var errUnsupported = errors.New("unsupported test case")
    27  
    28  var tlsConf *tls.Config
    29  
    30  func main() {
    31  	logFile, err := os.Create("/logs/log.txt")
    32  	if err != nil {
    33  		fmt.Printf("Could not create log file: %s\n", err.Error())
    34  		os.Exit(1)
    35  	}
    36  	defer logFile.Close()
    37  	log.SetOutput(logFile)
    38  
    39  	keyLog, err := utils.GetSSLKeyLog()
    40  	if err != nil {
    41  		fmt.Printf("Could not create key log: %s\n", err.Error())
    42  		os.Exit(1)
    43  	}
    44  	if keyLog != nil {
    45  		defer keyLog.Close()
    46  	}
    47  
    48  	tlsConf = &tls.Config{
    49  		InsecureSkipVerify: true,
    50  		KeyLogWriter:       keyLog,
    51  	}
    52  	testcase := os.Getenv("TESTCASE")
    53  	if err := runTestcase(testcase); err != nil {
    54  		if err == errUnsupported {
    55  			fmt.Printf("unsupported test case: %s\n", testcase)
    56  			os.Exit(127)
    57  		}
    58  		fmt.Printf("Downloading files failed: %s\n", err.Error())
    59  		os.Exit(1)
    60  	}
    61  }
    62  
    63  func runTestcase(testcase string) error {
    64  	flag.Parse()
    65  	urls := flag.Args()
    66  
    67  	quicConf := &quic.Config{Tracer: utils.NewQLOGConnectionTracer}
    68  
    69  	if testcase == "http3" {
    70  		r := &http3.RoundTripper{
    71  			TLSClientConfig: tlsConf,
    72  			QUICConfig:      quicConf,
    73  		}
    74  		defer r.Close()
    75  		return downloadFiles(r, urls, false)
    76  	}
    77  
    78  	r := &http09.RoundTripper{
    79  		TLSClientConfig: tlsConf,
    80  		QuicConfig:      quicConf,
    81  	}
    82  	defer r.Close()
    83  
    84  	switch testcase {
    85  	case "handshake", "transfer", "retry":
    86  	case "keyupdate":
    87  		handshake.FirstKeyUpdateInterval = 100
    88  	case "chacha20":
    89  		reset := qtls.SetCipherSuite(tls.TLS_CHACHA20_POLY1305_SHA256)
    90  		defer reset()
    91  	case "multiconnect":
    92  		return runMultiConnectTest(r, urls)
    93  	case "versionnegotiation":
    94  		return runVersionNegotiationTest(r, urls)
    95  	case "resumption":
    96  		return runResumptionTest(r, urls, false)
    97  	case "zerortt":
    98  		return runResumptionTest(r, urls, true)
    99  	default:
   100  		return errUnsupported
   101  	}
   102  
   103  	return downloadFiles(r, urls, false)
   104  }
   105  
   106  func runVersionNegotiationTest(r *http09.RoundTripper, urls []string) error {
   107  	if len(urls) != 1 {
   108  		return errors.New("expected at least 2 URLs")
   109  	}
   110  	protocol.SupportedVersions = []protocol.Version{0x1a2a3a4a}
   111  	err := downloadFile(r, urls[0], false)
   112  	if err == nil {
   113  		return errors.New("expected version negotiation to fail")
   114  	}
   115  	if !strings.Contains(err.Error(), "No compatible QUIC version found") {
   116  		return fmt.Errorf("expect version negotiation error, got: %s", err.Error())
   117  	}
   118  	return nil
   119  }
   120  
   121  func runMultiConnectTest(r *http09.RoundTripper, urls []string) error {
   122  	for _, url := range urls {
   123  		if err := downloadFile(r, url, false); err != nil {
   124  			return err
   125  		}
   126  		if err := r.Close(); err != nil {
   127  			return err
   128  		}
   129  	}
   130  	return nil
   131  }
   132  
   133  type sessionCache struct {
   134  	tls.ClientSessionCache
   135  	put chan<- struct{}
   136  }
   137  
   138  func newSessionCache(c tls.ClientSessionCache) (tls.ClientSessionCache, <-chan struct{}) {
   139  	put := make(chan struct{}, 100)
   140  	return &sessionCache{ClientSessionCache: c, put: put}, put
   141  }
   142  
   143  func (c *sessionCache) Put(key string, cs *tls.ClientSessionState) {
   144  	c.ClientSessionCache.Put(key, cs)
   145  	c.put <- struct{}{}
   146  }
   147  
   148  func runResumptionTest(r *http09.RoundTripper, urls []string, use0RTT bool) error {
   149  	if len(urls) < 2 {
   150  		return errors.New("expected at least 2 URLs")
   151  	}
   152  
   153  	var put <-chan struct{}
   154  	tlsConf.ClientSessionCache, put = newSessionCache(tls.NewLRUClientSessionCache(1))
   155  
   156  	// do the first transfer
   157  	if err := downloadFiles(r, urls[:1], false); err != nil {
   158  		return err
   159  	}
   160  
   161  	// wait for the session ticket to arrive
   162  	select {
   163  	case <-time.NewTimer(10 * time.Second).C:
   164  		return errors.New("expected to receive a session ticket within 10 seconds")
   165  	case <-put:
   166  	}
   167  
   168  	if err := r.Close(); err != nil {
   169  		return err
   170  	}
   171  
   172  	// reestablish the connection, using the session ticket that the server (hopefully provided)
   173  	defer r.Close()
   174  	return downloadFiles(r, urls[1:], use0RTT)
   175  }
   176  
   177  func downloadFiles(cl http.RoundTripper, urls []string, use0RTT bool) error {
   178  	var g errgroup.Group
   179  	for _, u := range urls {
   180  		url := u
   181  		g.Go(func() error {
   182  			return downloadFile(cl, url, use0RTT)
   183  		})
   184  	}
   185  	return g.Wait()
   186  }
   187  
   188  func downloadFile(cl http.RoundTripper, url string, use0RTT bool) error {
   189  	method := http.MethodGet
   190  	if use0RTT {
   191  		method = http09.MethodGet0RTT
   192  	}
   193  	req, err := http.NewRequest(method, url, nil)
   194  	if err != nil {
   195  		return err
   196  	}
   197  	rsp, err := cl.RoundTrip(req)
   198  	if err != nil {
   199  		return err
   200  	}
   201  	defer rsp.Body.Close()
   202  
   203  	file, err := os.Create("/downloads" + req.URL.Path)
   204  	if err != nil {
   205  		return err
   206  	}
   207  	defer file.Close()
   208  	_, err = io.Copy(file, rsp.Body)
   209  	return err
   210  }