github.com/MerlinKodo/quic-go@v0.39.2/interop/http09/client.go (about)

     1  package http09
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"io"
     8  	"log"
     9  	"net"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  
    14  	"golang.org/x/net/idna"
    15  
    16  	"github.com/MerlinKodo/quic-go"
    17  )
    18  
    19  // MethodGet0RTT allows a GET request to be sent using 0-RTT.
    20  // Note that 0-RTT data doesn't provide replay protection.
    21  const MethodGet0RTT = "GET_0RTT"
    22  
    23  // RoundTripper performs HTTP/0.9 roundtrips over QUIC.
    24  type RoundTripper struct {
    25  	mutex sync.Mutex
    26  
    27  	TLSClientConfig *tls.Config
    28  	QuicConfig      *quic.Config
    29  
    30  	clients map[string]*client
    31  }
    32  
    33  var _ http.RoundTripper = &RoundTripper{}
    34  
    35  // RoundTrip performs a HTTP/0.9 request.
    36  // It only supports GET requests.
    37  func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
    38  	if req.Method != http.MethodGet && req.Method != MethodGet0RTT {
    39  		return nil, errors.New("only GET requests supported")
    40  	}
    41  
    42  	log.Printf("Requesting %s.\n", req.URL)
    43  
    44  	r.mutex.Lock()
    45  	hostname := authorityAddr("https", hostnameFromRequest(req))
    46  	if r.clients == nil {
    47  		r.clients = make(map[string]*client)
    48  	}
    49  	c, ok := r.clients[hostname]
    50  	if !ok {
    51  		tlsConf := &tls.Config{}
    52  		if r.TLSClientConfig != nil {
    53  			tlsConf = r.TLSClientConfig.Clone()
    54  		}
    55  		tlsConf.NextProtos = []string{h09alpn}
    56  		c = &client{
    57  			hostname: hostname,
    58  			tlsConf:  tlsConf,
    59  			quicConf: r.QuicConfig,
    60  		}
    61  		r.clients[hostname] = c
    62  	}
    63  	r.mutex.Unlock()
    64  	return c.RoundTrip(req)
    65  }
    66  
    67  // Close closes the roundtripper.
    68  func (r *RoundTripper) Close() error {
    69  	r.mutex.Lock()
    70  	defer r.mutex.Unlock()
    71  
    72  	for id, c := range r.clients {
    73  		if err := c.Close(); err != nil {
    74  			return err
    75  		}
    76  		delete(r.clients, id)
    77  	}
    78  	return nil
    79  }
    80  
    81  type client struct {
    82  	hostname string
    83  	tlsConf  *tls.Config
    84  	quicConf *quic.Config
    85  
    86  	once    sync.Once
    87  	conn    quic.EarlyConnection
    88  	dialErr error
    89  }
    90  
    91  func (c *client) RoundTrip(req *http.Request) (*http.Response, error) {
    92  	c.once.Do(func() {
    93  		c.conn, c.dialErr = quic.DialAddrEarly(context.Background(), c.hostname, c.tlsConf, c.quicConf)
    94  	})
    95  	if c.dialErr != nil {
    96  		return nil, c.dialErr
    97  	}
    98  	if req.Method != MethodGet0RTT {
    99  		<-c.conn.HandshakeComplete()
   100  	}
   101  	return c.doRequest(req)
   102  }
   103  
   104  func (c *client) doRequest(req *http.Request) (*http.Response, error) {
   105  	str, err := c.conn.OpenStreamSync(context.Background())
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	cmd := "GET " + req.URL.Path + "\r\n"
   110  	if _, err := str.Write([]byte(cmd)); err != nil {
   111  		return nil, err
   112  	}
   113  	if err := str.Close(); err != nil {
   114  		return nil, err
   115  	}
   116  	rsp := &http.Response{
   117  		Proto:      "HTTP/0.9",
   118  		ProtoMajor: 0,
   119  		ProtoMinor: 9,
   120  		Request:    req,
   121  		Body:       io.NopCloser(str),
   122  	}
   123  	return rsp, nil
   124  }
   125  
   126  func (c *client) Close() error {
   127  	if c.conn == nil {
   128  		return nil
   129  	}
   130  	return c.conn.CloseWithError(0, "")
   131  }
   132  
   133  func hostnameFromRequest(req *http.Request) string {
   134  	if req.URL != nil {
   135  		return req.URL.Host
   136  	}
   137  	return ""
   138  }
   139  
   140  // authorityAddr returns a given authority (a host/IP, or host:port / ip:port)
   141  // and returns a host:port. The port 443 is added if needed.
   142  func authorityAddr(scheme string, authority string) (addr string) {
   143  	host, port, err := net.SplitHostPort(authority)
   144  	if err != nil { // authority didn't have a port
   145  		port = "443"
   146  		if scheme == "http" {
   147  			port = "80"
   148  		}
   149  		host = authority
   150  	}
   151  	if a, err := idna.ToASCII(host); err == nil {
   152  		host = a
   153  	}
   154  	// IPv6 address literal, without a port:
   155  	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
   156  		return host + ":" + port
   157  	}
   158  	return net.JoinHostPort(host, port)
   159  }