github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/h2quic/client.go (about) 1 package h2quic 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "fmt" 7 "io" 8 "net" 9 "net/http" 10 "strings" 11 "sync" 12 13 "golang.org/x/net/http2" 14 "golang.org/x/net/http2/hpack" 15 "golang.org/x/net/idna" 16 17 quic "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go" 18 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol" 19 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils" 20 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/qerr" 21 ) 22 23 type roundTripperOpts struct { 24 DisableCompression bool 25 } 26 27 var dialAddr = quic.DialAddr 28 29 // client is a HTTP2 client doing QUIC requests 30 type client struct { 31 mutex sync.RWMutex 32 33 tlsConf *tls.Config 34 config *quic.Config 35 opts *roundTripperOpts 36 37 hostname string 38 handshakeErr error 39 dialOnce sync.Once 40 dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error) 41 42 // [Psiphon] 43 // Fix close-while-dialing race condition by synchronizing access to 44 // client.session and adding a closed flag to indicate if the client was 45 // closed while a dial was in progress. 46 sessionMutex sync.Mutex 47 closed bool 48 session quic.Session 49 50 headerStream quic.Stream 51 headerErr *qerr.QuicError 52 headerErrored chan struct{} // this channel is closed if an error occurs on the header stream 53 requestWriter *requestWriter 54 55 responses map[protocol.StreamID]chan *http.Response 56 57 logger utils.Logger 58 } 59 60 var _ http.RoundTripper = &client{} 61 62 var defaultQuicConfig = &quic.Config{ 63 RequestConnectionIDOmission: true, 64 KeepAlive: true, 65 } 66 67 // newClient creates a new client 68 func newClient( 69 hostname string, 70 tlsConfig *tls.Config, 71 opts *roundTripperOpts, 72 quicConfig *quic.Config, 73 dialer func(network, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.Session, error), 74 ) *client { 75 config := defaultQuicConfig 76 if quicConfig != nil { 77 config = quicConfig 78 } 79 return &client{ 80 hostname: authorityAddr("https", hostname), 81 responses: make(map[protocol.StreamID]chan *http.Response), 82 tlsConf: tlsConfig, 83 config: config, 84 opts: opts, 85 headerErrored: make(chan struct{}), 86 dialer: dialer, 87 logger: utils.DefaultLogger.WithPrefix("client"), 88 } 89 } 90 91 // dial dials the connection 92 func (c *client) dial() error { 93 var err error 94 95 // [Psiphon] 96 var session quic.Session 97 98 if c.dialer != nil { 99 session, err = c.dialer("udp", c.hostname, c.tlsConf, c.config) 100 } else { 101 session, err = dialAddr(c.hostname, c.tlsConf, c.config) 102 } 103 if err != nil { 104 return err 105 } 106 107 // [Psiphon] 108 // Only this write and the Close reads of c.session require synchronization. 109 // After this point, it's safe to concurrently read c.session as it is not 110 // rewritten. 111 c.sessionMutex.Lock() 112 closed := c.closed 113 if !closed { 114 c.session = session 115 } 116 c.sessionMutex.Unlock() 117 if closed { 118 session.Close() 119 return errors.New("closed while dialing") 120 } 121 // [Psiphon] 122 123 // once the version has been negotiated, open the header stream 124 c.headerStream, err = c.session.OpenStream() 125 if err != nil { 126 return err 127 } 128 c.requestWriter = newRequestWriter(c.headerStream, c.logger) 129 go c.handleHeaderStream() 130 return nil 131 } 132 133 func (c *client) handleHeaderStream() { 134 decoder := hpack.NewDecoder(4096, func(hf hpack.HeaderField) {}) 135 h2framer := http2.NewFramer(nil, c.headerStream) 136 137 var err error 138 for err == nil { 139 err = c.readResponse(h2framer, decoder) 140 } 141 if quicErr, ok := err.(*qerr.QuicError); !ok || quicErr.ErrorCode != qerr.PeerGoingAway { 142 c.logger.Debugf("Error handling header stream: %s", err) 143 } 144 c.headerErr = qerr.Error(qerr.InvalidHeadersStreamData, err.Error()) 145 // stop all running request 146 close(c.headerErrored) 147 } 148 149 func (c *client) readResponse(h2framer *http2.Framer, decoder *hpack.Decoder) error { 150 frame, err := h2framer.ReadFrame() 151 if err != nil { 152 return err 153 } 154 hframe, ok := frame.(*http2.HeadersFrame) 155 if !ok { 156 return errors.New("not a headers frame") 157 } 158 mhframe := &http2.MetaHeadersFrame{HeadersFrame: hframe} 159 mhframe.Fields, err = decoder.DecodeFull(hframe.HeaderBlockFragment()) 160 if err != nil { 161 return fmt.Errorf("cannot read header fields: %s", err.Error()) 162 } 163 164 c.mutex.RLock() 165 responseChan, ok := c.responses[protocol.StreamID(hframe.StreamID)] 166 c.mutex.RUnlock() 167 if !ok { 168 return fmt.Errorf("response channel for stream %d not found", hframe.StreamID) 169 } 170 171 rsp, err := responseFromHeaders(mhframe) 172 if err != nil { 173 return err 174 } 175 responseChan <- rsp 176 return nil 177 } 178 179 // Roundtrip executes a request and returns a response 180 func (c *client) RoundTrip(req *http.Request) (*http.Response, error) { 181 // TODO: add port to address, if it doesn't have one 182 if req.URL.Scheme != "https" { 183 return nil, errors.New("quic http2: unsupported scheme") 184 } 185 if authorityAddr("https", hostnameFromRequest(req)) != c.hostname { 186 return nil, fmt.Errorf("h2quic Client BUG: RoundTrip called for the wrong client (expected %s, got %s)", c.hostname, req.Host) 187 } 188 189 c.dialOnce.Do(func() { 190 c.handshakeErr = c.dial() 191 }) 192 193 if c.handshakeErr != nil { 194 return nil, c.handshakeErr 195 } 196 197 hasBody := (req.Body != nil) 198 199 responseChan := make(chan *http.Response) 200 dataStream, err := c.session.OpenStreamSync() 201 if err != nil { 202 _ = c.closeWithError(err) 203 return nil, err 204 } 205 c.mutex.Lock() 206 c.responses[dataStream.StreamID()] = responseChan 207 c.mutex.Unlock() 208 209 var requestedGzip bool 210 if !c.opts.DisableCompression && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { 211 requestedGzip = true 212 } 213 // TODO: add support for trailers 214 endStream := !hasBody 215 err = c.requestWriter.WriteRequest(req, dataStream.StreamID(), endStream, requestedGzip) 216 if err != nil { 217 _ = c.closeWithError(err) 218 return nil, err 219 } 220 221 resc := make(chan error, 1) 222 if hasBody { 223 go func() { 224 resc <- c.writeRequestBody(dataStream, req.Body) 225 }() 226 } 227 228 var res *http.Response 229 230 var receivedResponse bool 231 var bodySent bool 232 233 if !hasBody { 234 bodySent = true 235 } 236 237 ctx := req.Context() 238 for !(bodySent && receivedResponse) { 239 select { 240 case res = <-responseChan: 241 receivedResponse = true 242 c.mutex.Lock() 243 delete(c.responses, dataStream.StreamID()) 244 c.mutex.Unlock() 245 case err := <-resc: 246 bodySent = true 247 if err != nil { 248 return nil, err 249 } 250 case <-ctx.Done(): 251 // error code 6 signals that stream was canceled 252 dataStream.CancelRead(6) 253 dataStream.CancelWrite(6) 254 c.mutex.Lock() 255 delete(c.responses, dataStream.StreamID()) 256 c.mutex.Unlock() 257 return nil, ctx.Err() 258 case <-c.headerErrored: 259 // an error occurred on the header stream 260 _ = c.closeWithError(c.headerErr) 261 return nil, c.headerErr 262 } 263 } 264 265 // TODO: correctly set this variable 266 var streamEnded bool 267 isHead := (req.Method == "HEAD") 268 269 res = setLength(res, isHead, streamEnded) 270 271 if streamEnded || isHead { 272 res.Body = noBody 273 } else { 274 res.Body = dataStream 275 if requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { 276 res.Header.Del("Content-Encoding") 277 res.Header.Del("Content-Length") 278 res.ContentLength = -1 279 res.Body = &gzipReader{body: res.Body} 280 res.Uncompressed = true 281 } 282 } 283 284 res.Request = req 285 return res, nil 286 } 287 288 func (c *client) writeRequestBody(dataStream quic.Stream, body io.ReadCloser) (err error) { 289 defer func() { 290 cerr := body.Close() 291 if err == nil { 292 // TODO: what to do with dataStream here? Maybe reset it? 293 err = cerr 294 } 295 }() 296 297 _, err = io.Copy(dataStream, body) 298 if err != nil { 299 // TODO: what to do with dataStream here? Maybe reset it? 300 return err 301 } 302 return dataStream.Close() 303 } 304 305 func (c *client) closeWithError(e error) error { 306 307 // [Psiphon] 308 c.sessionMutex.Lock() 309 session := c.session 310 c.closed = true 311 c.sessionMutex.Unlock() 312 // [Psiphon] 313 314 if session == nil { 315 return nil 316 } 317 return session.CloseWithError(quic.ErrorCode(qerr.InternalError), e) 318 } 319 320 // Close closes the client 321 func (c *client) Close() error { 322 323 // [Psiphon] 324 c.sessionMutex.Lock() 325 session := c.session 326 c.closed = true 327 c.sessionMutex.Unlock() 328 // [Psiphon] 329 330 if session == nil { 331 return nil 332 } 333 return session.Close() 334 } 335 336 // copied from net/transport.go 337 338 // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) 339 // and returns a host:port. The port 443 is added if needed. 340 func authorityAddr(scheme string, authority string) (addr string) { 341 host, port, err := net.SplitHostPort(authority) 342 if err != nil { // authority didn't have a port 343 port = "443" 344 if scheme == "http" { 345 port = "80" 346 } 347 host = authority 348 } 349 if a, err := idna.ToASCII(host); err == nil { 350 host = a 351 } 352 // IPv6 address literal, without a port: 353 if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { 354 return host + ":" + port 355 } 356 return net.JoinHostPort(host, port) 357 }