golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/internal/quic/cmd/interop/main.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  // The interop command is the client and server used by QUIC interoperability tests.
     8  //
     9  // https://github.com/marten-seemann/quic-interop-runner
    10  package main
    11  
    12  import (
    13  	"bytes"
    14  	"context"
    15  	"crypto/tls"
    16  	"errors"
    17  	"flag"
    18  	"fmt"
    19  	"io"
    20  	"log"
    21  	"log/slog"
    22  	"net"
    23  	"net/url"
    24  	"os"
    25  	"path/filepath"
    26  	"sync"
    27  
    28  	"golang.org/x/net/quic"
    29  	"golang.org/x/net/quic/qlog"
    30  )
    31  
    32  var (
    33  	listen  = flag.String("listen", "", "listen address")
    34  	cert    = flag.String("cert", "", "certificate")
    35  	pkey    = flag.String("key", "", "private key")
    36  	root    = flag.String("root", "", "serve files from this root")
    37  	output  = flag.String("output", "", "directory to write files to")
    38  	qlogdir = flag.String("qlog", "", "directory to write qlog output to")
    39  )
    40  
    41  func main() {
    42  	ctx := context.Background()
    43  	flag.Parse()
    44  	urls := flag.Args()
    45  
    46  	config := &quic.Config{
    47  		TLSConfig: &tls.Config{
    48  			InsecureSkipVerify: true,
    49  			MinVersion:         tls.VersionTLS13,
    50  			NextProtos:         []string{"hq-interop"},
    51  		},
    52  		MaxBidiRemoteStreams: -1,
    53  		MaxUniRemoteStreams:  -1,
    54  		QLogLogger: slog.New(qlog.NewJSONHandler(qlog.HandlerOptions{
    55  			Level: quic.QLogLevelFrame,
    56  			Dir:   *qlogdir,
    57  		})),
    58  	}
    59  	if *cert != "" {
    60  		c, err := tls.LoadX509KeyPair(*cert, *pkey)
    61  		if err != nil {
    62  			log.Fatal(err)
    63  		}
    64  		config.TLSConfig.Certificates = []tls.Certificate{c}
    65  	}
    66  	if *root != "" {
    67  		config.MaxBidiRemoteStreams = 100
    68  	}
    69  	if keylog := os.Getenv("SSLKEYLOGFILE"); keylog != "" {
    70  		f, err := os.Create(keylog)
    71  		if err != nil {
    72  			log.Fatal(err)
    73  		}
    74  		defer f.Close()
    75  		config.TLSConfig.KeyLogWriter = f
    76  	}
    77  
    78  	testcase := os.Getenv("TESTCASE")
    79  	switch testcase {
    80  	case "handshake", "keyupdate":
    81  		basicTest(ctx, config, urls)
    82  		return
    83  	case "chacha20":
    84  		// "[...] offer only ChaCha20 as a ciphersuite."
    85  		//
    86  		// crypto/tls does not support configuring TLS 1.3 ciphersuites,
    87  		// so we can't support this test.
    88  	case "transfer":
    89  		// "The client should use small initial flow control windows
    90  		// for both stream- and connection-level flow control
    91  		// such that the during the transfer of files on the order of 1 MB
    92  		// the flow control window needs to be increased."
    93  		config.MaxStreamReadBufferSize = 64 << 10
    94  		config.MaxConnReadBufferSize = 64 << 10
    95  		basicTest(ctx, config, urls)
    96  		return
    97  	case "http3":
    98  		// TODO
    99  	case "multiconnect":
   100  		// TODO
   101  	case "resumption":
   102  		// TODO
   103  	case "retry":
   104  		// TODO
   105  	case "versionnegotiation":
   106  		// "The client should start a connection using
   107  		// an unsupported version number [...]"
   108  		//
   109  		// We don't support setting the client's version,
   110  		// so only run this test as a server.
   111  		if *listen != "" && len(urls) == 0 {
   112  			basicTest(ctx, config, urls)
   113  			return
   114  		}
   115  	case "v2":
   116  		// We do not support QUIC v2.
   117  	case "zerortt":
   118  		// TODO
   119  	}
   120  	fmt.Printf("unsupported test case %q\n", testcase)
   121  	os.Exit(127)
   122  }
   123  
   124  // basicTest runs the standard test setup.
   125  //
   126  // As a server, it serves the contents of the -root directory.
   127  // As a client, it downloads all the provided URLs in parallel,
   128  // making one connection to each destination server.
   129  func basicTest(ctx context.Context, config *quic.Config, urls []string) {
   130  	l, err := quic.Listen("udp", *listen, config)
   131  	if err != nil {
   132  		log.Fatal(err)
   133  	}
   134  	log.Printf("listening on %v", l.LocalAddr())
   135  
   136  	byAuthority := map[string][]*url.URL{}
   137  	for _, s := range urls {
   138  		u, addr, err := parseURL(s)
   139  		if err != nil {
   140  			log.Fatal(err)
   141  		}
   142  		byAuthority[addr] = append(byAuthority[addr], u)
   143  	}
   144  	var g sync.WaitGroup
   145  	defer g.Wait()
   146  	for addr, u := range byAuthority {
   147  		addr, u := addr, u
   148  		g.Add(1)
   149  		go func() {
   150  			defer g.Done()
   151  			fetchFrom(ctx, config, l, addr, u)
   152  		}()
   153  	}
   154  
   155  	if config.MaxBidiRemoteStreams >= 0 {
   156  		serve(ctx, l)
   157  	}
   158  }
   159  
   160  func serve(ctx context.Context, l *quic.Endpoint) error {
   161  	for {
   162  		c, err := l.Accept(ctx)
   163  		if err != nil {
   164  			return err
   165  		}
   166  		go serveConn(ctx, c)
   167  	}
   168  }
   169  
   170  func serveConn(ctx context.Context, c *quic.Conn) {
   171  	for {
   172  		s, err := c.AcceptStream(ctx)
   173  		if err != nil {
   174  			return
   175  		}
   176  		go func() {
   177  			if err := serveReq(ctx, s); err != nil {
   178  				log.Print("serveReq:", err)
   179  			}
   180  		}()
   181  	}
   182  }
   183  
   184  func serveReq(ctx context.Context, s *quic.Stream) error {
   185  	defer s.Close()
   186  	req, err := io.ReadAll(s)
   187  	if err != nil {
   188  		return err
   189  	}
   190  	if !bytes.HasSuffix(req, []byte("\r\n")) {
   191  		return errors.New("invalid request")
   192  	}
   193  	req = bytes.TrimSuffix(req, []byte("\r\n"))
   194  	if !bytes.HasPrefix(req, []byte("GET /")) {
   195  		return errors.New("invalid request")
   196  	}
   197  	req = bytes.TrimPrefix(req, []byte("GET /"))
   198  	if !filepath.IsLocal(string(req)) {
   199  		return errors.New("invalid request")
   200  	}
   201  	f, err := os.Open(filepath.Join(*root, string(req)))
   202  	if err != nil {
   203  		return err
   204  	}
   205  	defer f.Close()
   206  	_, err = io.Copy(s, f)
   207  	return err
   208  }
   209  
   210  func parseURL(s string) (u *url.URL, authority string, err error) {
   211  	u, err = url.Parse(s)
   212  	if err != nil {
   213  		return nil, "", err
   214  	}
   215  	host := u.Hostname()
   216  	port := u.Port()
   217  	if port == "" {
   218  		port = "443"
   219  	}
   220  	authority = net.JoinHostPort(host, port)
   221  	return u, authority, nil
   222  }
   223  
   224  func fetchFrom(ctx context.Context, config *quic.Config, l *quic.Endpoint, addr string, urls []*url.URL) {
   225  	conn, err := l.Dial(ctx, "udp", addr, config)
   226  	if err != nil {
   227  		log.Printf("%v: %v", addr, err)
   228  		return
   229  	}
   230  	log.Printf("connected to %v", addr)
   231  	defer conn.Close()
   232  	var g sync.WaitGroup
   233  	for _, u := range urls {
   234  		u := u
   235  		g.Add(1)
   236  		go func() {
   237  			defer g.Done()
   238  			if err := fetchOne(ctx, conn, u); err != nil {
   239  				log.Printf("fetch %v: %v", u, err)
   240  			} else {
   241  				log.Printf("fetched %v", u)
   242  			}
   243  		}()
   244  	}
   245  	g.Wait()
   246  }
   247  
   248  func fetchOne(ctx context.Context, conn *quic.Conn, u *url.URL) error {
   249  	if len(u.Path) == 0 || u.Path[0] != '/' || !filepath.IsLocal(u.Path[1:]) {
   250  		return errors.New("invalid path")
   251  	}
   252  	file, err := os.Create(filepath.Join(*output, u.Path[1:]))
   253  	if err != nil {
   254  		return err
   255  	}
   256  	s, err := conn.NewStream(ctx)
   257  	if err != nil {
   258  		return err
   259  	}
   260  	defer s.Close()
   261  	if _, err := s.Write([]byte("GET " + u.Path + "\r\n")); err != nil {
   262  		return err
   263  	}
   264  	s.CloseWrite()
   265  	if _, err := io.Copy(file, s); err != nil {
   266  		return err
   267  	}
   268  	return nil
   269  }