github.com/alejandroEsc/spdy@v0.0.0-20200317064415-01a02f0eb389/transport.go (about)

     1  // Copyright 2013 Jamie Hall. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package spdy
     6  
     7  import (
     8  	"crypto/tls"
     9  	"errors"
    10  	"fmt"
    11  	"net"
    12  	"net/http"
    13  	"net/http/httputil"
    14  	"net/url"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/SlyMarbo/spdy/common"
    20  )
    21  
// A Transport is an HTTP/SPDY http.RoundTripper. HTTPS connections that
// negotiate a SPDY protocol are used for SPDY requests; everything else is
// sent as HTTP/1.x.
//
// The zero value is usable: the connection maps, the per-host connection
// limits and a default TLSClientConfig are created on first use.
// NewTransport is a convenience constructor that also sets
// TLSClientConfig.InsecureSkipVerify.
type Transport struct {
	// m is held by process for its whole duration, guarding the maps
	// below while a connection is looked up or created.
	// NOTE(review): doHTTP and RoundTrip read tcpConns and connLimit
	// without holding m; confirm this cannot race with process adding
	// entries for a new host.
	m sync.Mutex

	// Proxy specifies a function to return a proxy for a given
	// Request. If the function returns a non-nil error, the
	// request is aborted with the provided error.
	// If Proxy is nil or returns a nil *URL, no proxy is used.
	// NOTE(review): not consulted anywhere in transport.go; confirm it
	// is honoured before relying on it.
	Proxy func(*http.Request) (*url.URL, error)

	// Dial specifies the dial function for creating TCP
	// connections.
	// If Dial is nil, net.Dial is used.
	Dial func(network, addr string) (net.Conn, error) // TODO: use

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client. If nil, the default configuration is used.
	// If its NextProtos is nil, it is filled in with the supported
	// NPN protocols when the first connection is dialled.
	TLSClientConfig *tls.Config

	// DisableKeepAlives, if true, prevents re-use of TCP connections
	// between different HTTP requests.
	// NOTE(review): not consulted anywhere in transport.go.
	DisableKeepAlives bool

	// DisableCompression, if true, prevents the Transport from
	// requesting compression with an "Accept-Encoding: gzip"
	// request header when the Request contains no existing
	// Accept-Encoding value. If the Transport requests gzip on
	// its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body. However, if the user
	// explicitly requested gzip it is not automatically
	// uncompressed.
	// NOTE(review): not consulted anywhere in transport.go.
	DisableCompression bool

	// MaxIdleConnsPerHost, if non-zero, controls the maximum number of
	// idle (keep-alive) connections to keep per host. If zero,
	// http.DefaultMaxIdleConnsPerHost is used. The same value sizes the
	// per-host connection-slot semaphore (connLimit), so it is also
	// intended to bound the number of connections opened to each host,
	// not only the idle ones.
	MaxIdleConnsPerHost int

	// ResponseHeaderTimeout, if non-zero, specifies the amount of
	// time to wait for a server's response headers after fully
	// writing the request (including its body, if any). This
	// time does not include the time to read the response body.
	// NOTE(review): not consulted anywhere in transport.go.
	ResponseHeaderTimeout time.Duration

	spdyConns map[string]common.Conn   // SPDY connections mapped to host:port.
	tcpConns  map[string]chan net.Conn // Idle non-SPDY connections mapped to host:port.
	connLimit map[string]chan struct{} // Per-host semaphore: each buffered token is a free connection slot.

	// Priority is used to determine the request priority of SPDY
	// requests. If nil, common.DefaultPriority is used.
	Priority func(*url.URL) common.Priority

	// Receiver is used to receive the server's response to SPDY
	// requests (it is passed to RequestResponse). If left nil, the
	// default Receiver will parse and create a normal Response.
	// Plain HTTP/1.x responses are read by httputil.ClientConn
	// and do not go through it.
	Receiver common.Receiver

	// PushReceiver is used to receive server pushes. If left nil,
	// pushes will be refused. The provided Request will be that
	// sent with the server push. See Receiver for more detail on
	// its methods.
	PushReceiver common.Receiver
}
    85  
    86  // NewTransport gives a simple initialised Transport.
    87  func NewTransport(insecureSkipVerify bool) *Transport {
    88  	return &Transport{
    89  		TLSClientConfig: &tls.Config{
    90  			InsecureSkipVerify: insecureSkipVerify,
    91  			NextProtos:         npn(),
    92  		},
    93  	}
    94  }
    95  
    96  // dial makes the connection to an endpoint.
    97  func (t *Transport) dial(u *url.URL) (conn net.Conn, err error) {
    98  
    99  	if t.TLSClientConfig == nil {
   100  		t.TLSClientConfig = &tls.Config{
   101  			NextProtos: npn(),
   102  		}
   103  	} else if t.TLSClientConfig.NextProtos == nil {
   104  		t.TLSClientConfig.NextProtos = npn()
   105  	}
   106  
   107  	// Wait for a connection slot to become available.
   108  	<-t.connLimit[u.Host]
   109  
   110  	switch u.Scheme {
   111  	case "http":
   112  		conn, err = net.Dial("tcp", u.Host)
   113  	case "https":
   114  		conn, err = tls.Dial("tcp", u.Host, t.TLSClientConfig)
   115  	default:
   116  		err = errors.New(fmt.Sprintf("Error: URL has invalid scheme %q.", u.Scheme))
   117  	}
   118  
   119  	if err != nil {
   120  		// The connection never happened, which frees up a slot.
   121  		t.connLimit[u.Host] <- struct{}{}
   122  	}
   123  
   124  	return conn, err
   125  }
   126  
   127  // doHTTP is used to process an HTTP(S) request, using the TCP connection pool.
   128  func (t *Transport) doHTTP(conn net.Conn, req *http.Request) (*http.Response, error) {
   129  	debug.Printf("Requesting %q over HTTP.\n", req.URL.String())
   130  
   131  	// Create the HTTP ClientConn, which handles the
   132  	// HTTP details.
   133  	httpConn := httputil.NewClientConn(conn, nil)
   134  	res, err := httpConn.Do(req)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  
   139  	if !res.Close {
   140  		t.tcpConns[req.URL.Host] <- conn
   141  	} else {
   142  		// This connection is closing, so another can be used.
   143  		t.connLimit[req.URL.Host] <- struct{}{}
   144  		err = httpConn.Close()
   145  		if err != nil {
   146  			return nil, err
   147  		}
   148  	}
   149  
   150  	return res, nil
   151  }
   152  
   153  // RoundTrip handles the actual request; ensuring a connection is
   154  // made, determining which protocol to use, and performing the
   155  // request.
   156  func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) {
   157  	u := req.URL
   158  
   159  	// Make sure the URL host contains the port.
   160  	if !strings.Contains(u.Host, ":") {
   161  		switch u.Scheme {
   162  		case "http":
   163  			u.Host += ":80"
   164  
   165  		case "https":
   166  			u.Host += ":443"
   167  		}
   168  	}
   169  
   170  	conn, tcpConn, err := t.process(req)
   171  	if err != nil {
   172  		return nil, err
   173  	}
   174  	if tcpConn != nil {
   175  		return t.doHTTP(tcpConn, req)
   176  	}
   177  
   178  	// The connection has now been established.
   179  
   180  	debug.Printf("Requesting %q over SPDY.\n", u.String())
   181  
   182  	// Determine the request priority.
   183  	var priority common.Priority
   184  	if t.Priority != nil {
   185  		priority = t.Priority(req.URL)
   186  	} else {
   187  		priority = common.DefaultPriority(req.URL)
   188  	}
   189  
   190  	res, err := conn.RequestResponse(req, t.Receiver, priority)
   191  	if conn.Closed() {
   192  		t.connLimit[u.Host] <- struct{}{}
   193  	}
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  
   198  	return res, nil
   199  }
   200  
   201  func (t *Transport) process(req *http.Request) (common.Conn, net.Conn, error) {
   202  	t.m.Lock()
   203  	defer t.m.Unlock()
   204  
   205  	u := req.URL
   206  
   207  	// Initialise structures if necessary.
   208  	if t.spdyConns == nil {
   209  		t.spdyConns = make(map[string]common.Conn)
   210  	}
   211  	if t.tcpConns == nil {
   212  		t.tcpConns = make(map[string]chan net.Conn)
   213  	}
   214  	if t.connLimit == nil {
   215  		t.connLimit = make(map[string]chan struct{})
   216  	}
   217  	if t.MaxIdleConnsPerHost == 0 {
   218  		t.MaxIdleConnsPerHost = http.DefaultMaxIdleConnsPerHost
   219  	}
   220  	if _, ok := t.connLimit[u.Host]; !ok {
   221  		limitChan := make(chan struct{}, t.MaxIdleConnsPerHost)
   222  		t.connLimit[u.Host] = limitChan
   223  		for i := 0; i < t.MaxIdleConnsPerHost; i++ {
   224  			limitChan <- struct{}{}
   225  		}
   226  	}
   227  
   228  	// Check the non-SPDY connection pool.
   229  	if connChan, ok := t.tcpConns[u.Host]; ok {
   230  		select {
   231  		case tcpConn := <-connChan:
   232  			// Use a connection from the pool.
   233  			return nil, tcpConn, nil
   234  		default:
   235  		}
   236  	} else {
   237  		t.tcpConns[u.Host] = make(chan net.Conn, t.MaxIdleConnsPerHost)
   238  	}
   239  
   240  	// Check the SPDY connection pool.
   241  	conn, ok := t.spdyConns[u.Host]
   242  	if !ok || u.Scheme == "http" || (conn != nil && conn.Closed()) {
   243  		tcpConn, err := t.dial(req.URL)
   244  		if err != nil {
   245  			return nil, nil, err
   246  		}
   247  
   248  		if tlsConn, ok := tcpConn.(*tls.Conn); !ok {
   249  			// Handle HTTP requests.
   250  			return nil, tcpConn, nil
   251  		} else {
   252  			// Handle HTTPS/SPDY requests.
   253  			state := tlsConn.ConnectionState()
   254  
   255  			// Complete handshake if necessary.
   256  			if !state.HandshakeComplete {
   257  				err = tlsConn.Handshake()
   258  				if err != nil {
   259  					return nil, nil, err
   260  				}
   261  			}
   262  
   263  			// Verify hostname, unless requested not to.
   264  			if !t.TLSClientConfig.InsecureSkipVerify {
   265  				err = tlsConn.VerifyHostname(req.URL.Host)
   266  				if err != nil {
   267  					// Also try verifying the hostname with/without a port number.
   268  					i := strings.Index(req.URL.Host, ":")
   269  					err = tlsConn.VerifyHostname(req.URL.Host[:i])
   270  					if err != nil {
   271  						return nil, nil, err
   272  					}
   273  				}
   274  			}
   275  
   276  			// If a protocol could not be negotiated, assume HTTPS.
   277  			if !state.NegotiatedProtocolIsMutual {
   278  				return nil, tcpConn, nil
   279  			}
   280  
   281  			// Scan the list of supported NPN strings.
   282  			supported := false
   283  			for _, proto := range npn() {
   284  				if state.NegotiatedProtocol == proto {
   285  					supported = true
   286  					break
   287  				}
   288  			}
   289  
   290  			// Ensure the negotiated protocol is supported.
   291  			if !supported && state.NegotiatedProtocol != "" {
   292  				msg := fmt.Sprintf("Error: Unsupported negotiated protocol %q.", state.NegotiatedProtocol)
   293  				return nil, nil, errors.New(msg)
   294  			}
   295  
   296  			// Handle the protocol.
   297  			switch state.NegotiatedProtocol {
   298  			case "http/1.1", "":
   299  				return nil, tcpConn, nil
   300  
   301  			case "spdy/3.1":
   302  				newConn, err := NewClientConn(tlsConn, t.PushReceiver, 3, 1)
   303  				if err != nil {
   304  					return nil, nil, err
   305  				}
   306  				go newConn.Run()
   307  				t.spdyConns[u.Host] = newConn
   308  				conn = newConn
   309  
   310  			case "spdy/3":
   311  				newConn, err := NewClientConn(tlsConn, t.PushReceiver, 3, 0)
   312  				if err != nil {
   313  					return nil, nil, err
   314  				}
   315  				go newConn.Run()
   316  				t.spdyConns[u.Host] = newConn
   317  				conn = newConn
   318  
   319  			case "spdy/2":
   320  				newConn, err := NewClientConn(tlsConn, t.PushReceiver, 2, 0)
   321  				if err != nil {
   322  					return nil, nil, err
   323  				}
   324  				go newConn.Run()
   325  				t.spdyConns[u.Host] = newConn
   326  				conn = newConn
   327  			}
   328  		}
   329  	}
   330  
   331  	return conn, nil, nil
   332  }