github.com/alejandroesc/spdy@v0.0.0-20200317064415-01a02f0eb389/transport.go (about) 1 // Copyright 2013 Jamie Hall. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package spdy 6 7 import ( 8 "crypto/tls" 9 "errors" 10 "fmt" 11 "net" 12 "net/http" 13 "net/http/httputil" 14 "net/url" 15 "strings" 16 "sync" 17 "time" 18 19 "github.com/SlyMarbo/spdy/common" 20 ) 21 22 // A Transport is an HTTP/SPDY http.RoundTripper. 23 type Transport struct { 24 m sync.Mutex 25 26 // Proxy specifies a function to return a proxy for a given 27 // Request. If the function returns a non-nil error, the 28 // request is aborted with the provided error. 29 // If Proxy is nil or returns a nil *URL, no proxy is used. 30 Proxy func(*http.Request) (*url.URL, error) 31 32 // Dial specifies the dial function for creating TCP 33 // connections. 34 // If Dial is nil, net.Dial is used. 35 Dial func(network, addr string) (net.Conn, error) // TODO: use 36 37 // TLSClientConfig specifies the TLS configuration to use with 38 // tls.Client. If nil, the default configuration is used. 39 TLSClientConfig *tls.Config 40 41 // DisableKeepAlives, if true, prevents re-use of TCP connections 42 // between different HTTP requests. 43 DisableKeepAlives bool 44 45 // DisableCompression, if true, prevents the Transport from 46 // requesting compression with an "Accept-Encoding: gzip" 47 // request header when the Request contains no existing 48 // Accept-Encoding value. If the Transport requests gzip on 49 // its own and gets a gzipped response, it's transparently 50 // decoded in the Response.Body. However, if the user 51 // explicitly requested gzip it is not automatically 52 // uncompressed. 53 DisableCompression bool 54 55 // MaxIdleConnsPerHost, if non-zero, controls the maximum idle 56 // (keep-alive) to keep per-host. If zero, 57 // DefaultMaxIdleConnsPerHost is used. 58 MaxIdleConnsPerHost int 59 60 // ResponseHeaderTimeout, if non-zero, specifies the amount of 61 // time to wait for a server's response headers after fully 62 // writing the request (including its body, if any). This 63 // time does not include the time to read the response body. 64 ResponseHeaderTimeout time.Duration 65 66 spdyConns map[string]common.Conn // SPDY connections mapped to host:port. 67 tcpConns map[string]chan net.Conn // Non-SPDY connections mapped to host:port. 68 connLimit map[string]chan struct{} // Used to enforce the TCP conn limit. 69 70 // Priority is used to determine the request priority of SPDY 71 // requests. If nil, spdy.DefaultPriority is used. 72 Priority func(*url.URL) common.Priority 73 74 // Receiver is used to receive the server's response. If left 75 // nil, the default Receiver will parse and create a normal 76 // Response. 77 Receiver common.Receiver 78 79 // PushReceiver is used to receive server pushes. If left nil, 80 // pushes will be refused. The provided Request will be that 81 // sent with the server push. See Receiver for more detail on 82 // its methods. 83 PushReceiver common.Receiver 84 } 85 86 // NewTransport gives a simple initialised Transport. 87 func NewTransport(insecureSkipVerify bool) *Transport { 88 return &Transport{ 89 TLSClientConfig: &tls.Config{ 90 InsecureSkipVerify: insecureSkipVerify, 91 NextProtos: npn(), 92 }, 93 } 94 } 95 96 // dial makes the connection to an endpoint. 97 func (t *Transport) dial(u *url.URL) (conn net.Conn, err error) { 98 99 if t.TLSClientConfig == nil { 100 t.TLSClientConfig = &tls.Config{ 101 NextProtos: npn(), 102 } 103 } else if t.TLSClientConfig.NextProtos == nil { 104 t.TLSClientConfig.NextProtos = npn() 105 } 106 107 // Wait for a connection slot to become available. 108 <-t.connLimit[u.Host] 109 110 switch u.Scheme { 111 case "http": 112 conn, err = net.Dial("tcp", u.Host) 113 case "https": 114 conn, err = tls.Dial("tcp", u.Host, t.TLSClientConfig) 115 default: 116 err = errors.New(fmt.Sprintf("Error: URL has invalid scheme %q.", u.Scheme)) 117 } 118 119 if err != nil { 120 // The connection never happened, which frees up a slot. 121 t.connLimit[u.Host] <- struct{}{} 122 } 123 124 return conn, err 125 } 126 127 // doHTTP is used to process an HTTP(S) request, using the TCP connection pool. 128 func (t *Transport) doHTTP(conn net.Conn, req *http.Request) (*http.Response, error) { 129 debug.Printf("Requesting %q over HTTP.\n", req.URL.String()) 130 131 // Create the HTTP ClientConn, which handles the 132 // HTTP details. 133 httpConn := httputil.NewClientConn(conn, nil) 134 res, err := httpConn.Do(req) 135 if err != nil { 136 return nil, err 137 } 138 139 if !res.Close { 140 t.tcpConns[req.URL.Host] <- conn 141 } else { 142 // This connection is closing, so another can be used. 143 t.connLimit[req.URL.Host] <- struct{}{} 144 err = httpConn.Close() 145 if err != nil { 146 return nil, err 147 } 148 } 149 150 return res, nil 151 } 152 153 // RoundTrip handles the actual request; ensuring a connection is 154 // made, determining which protocol to use, and performing the 155 // request. 156 func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { 157 u := req.URL 158 159 // Make sure the URL host contains the port. 160 if !strings.Contains(u.Host, ":") { 161 switch u.Scheme { 162 case "http": 163 u.Host += ":80" 164 165 case "https": 166 u.Host += ":443" 167 } 168 } 169 170 conn, tcpConn, err := t.process(req) 171 if err != nil { 172 return nil, err 173 } 174 if tcpConn != nil { 175 return t.doHTTP(tcpConn, req) 176 } 177 178 // The connection has now been established. 179 180 debug.Printf("Requesting %q over SPDY.\n", u.String()) 181 182 // Determine the request priority. 183 var priority common.Priority 184 if t.Priority != nil { 185 priority = t.Priority(req.URL) 186 } else { 187 priority = common.DefaultPriority(req.URL) 188 } 189 190 res, err := conn.RequestResponse(req, t.Receiver, priority) 191 if conn.Closed() { 192 t.connLimit[u.Host] <- struct{}{} 193 } 194 if err != nil { 195 return nil, err 196 } 197 198 return res, nil 199 } 200 201 func (t *Transport) process(req *http.Request) (common.Conn, net.Conn, error) { 202 t.m.Lock() 203 defer t.m.Unlock() 204 205 u := req.URL 206 207 // Initialise structures if necessary. 208 if t.spdyConns == nil { 209 t.spdyConns = make(map[string]common.Conn) 210 } 211 if t.tcpConns == nil { 212 t.tcpConns = make(map[string]chan net.Conn) 213 } 214 if t.connLimit == nil { 215 t.connLimit = make(map[string]chan struct{}) 216 } 217 if t.MaxIdleConnsPerHost == 0 { 218 t.MaxIdleConnsPerHost = http.DefaultMaxIdleConnsPerHost 219 } 220 if _, ok := t.connLimit[u.Host]; !ok { 221 limitChan := make(chan struct{}, t.MaxIdleConnsPerHost) 222 t.connLimit[u.Host] = limitChan 223 for i := 0; i < t.MaxIdleConnsPerHost; i++ { 224 limitChan <- struct{}{} 225 } 226 } 227 228 // Check the non-SPDY connection pool. 229 if connChan, ok := t.tcpConns[u.Host]; ok { 230 select { 231 case tcpConn := <-connChan: 232 // Use a connection from the pool. 233 return nil, tcpConn, nil 234 default: 235 } 236 } else { 237 t.tcpConns[u.Host] = make(chan net.Conn, t.MaxIdleConnsPerHost) 238 } 239 240 // Check the SPDY connection pool. 241 conn, ok := t.spdyConns[u.Host] 242 if !ok || u.Scheme == "http" || (conn != nil && conn.Closed()) { 243 tcpConn, err := t.dial(req.URL) 244 if err != nil { 245 return nil, nil, err 246 } 247 248 if tlsConn, ok := tcpConn.(*tls.Conn); !ok { 249 // Handle HTTP requests. 250 return nil, tcpConn, nil 251 } else { 252 // Handle HTTPS/SPDY requests. 253 state := tlsConn.ConnectionState() 254 255 // Complete handshake if necessary. 256 if !state.HandshakeComplete { 257 err = tlsConn.Handshake() 258 if err != nil { 259 return nil, nil, err 260 } 261 } 262 263 // Verify hostname, unless requested not to. 264 if !t.TLSClientConfig.InsecureSkipVerify { 265 err = tlsConn.VerifyHostname(req.URL.Host) 266 if err != nil { 267 // Also try verifying the hostname with/without a port number. 268 i := strings.Index(req.URL.Host, ":") 269 err = tlsConn.VerifyHostname(req.URL.Host[:i]) 270 if err != nil { 271 return nil, nil, err 272 } 273 } 274 } 275 276 // If a protocol could not be negotiated, assume HTTPS. 277 if !state.NegotiatedProtocolIsMutual { 278 return nil, tcpConn, nil 279 } 280 281 // Scan the list of supported NPN strings. 282 supported := false 283 for _, proto := range npn() { 284 if state.NegotiatedProtocol == proto { 285 supported = true 286 break 287 } 288 } 289 290 // Ensure the negotiated protocol is supported. 291 if !supported && state.NegotiatedProtocol != "" { 292 msg := fmt.Sprintf("Error: Unsupported negotiated protocol %q.", state.NegotiatedProtocol) 293 return nil, nil, errors.New(msg) 294 } 295 296 // Handle the protocol. 297 switch state.NegotiatedProtocol { 298 case "http/1.1", "": 299 return nil, tcpConn, nil 300 301 case "spdy/3.1": 302 newConn, err := NewClientConn(tlsConn, t.PushReceiver, 3, 1) 303 if err != nil { 304 return nil, nil, err 305 } 306 go newConn.Run() 307 t.spdyConns[u.Host] = newConn 308 conn = newConn 309 310 case "spdy/3": 311 newConn, err := NewClientConn(tlsConn, t.PushReceiver, 3, 0) 312 if err != nil { 313 return nil, nil, err 314 } 315 go newConn.Run() 316 t.spdyConns[u.Host] = newConn 317 conn = newConn 318 319 case "spdy/2": 320 newConn, err := NewClientConn(tlsConn, t.PushReceiver, 2, 0) 321 if err != nil { 322 return nil, nil, err 323 } 324 go newConn.Run() 325 t.spdyConns[u.Host] = newConn 326 conn = newConn 327 } 328 } 329 } 330 331 return conn, nil, nil 332 }