github.com/apernet/quic-go@v0.43.1-0.20240515053213-5e9e635fd9f0/http3/conn.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"log/slog"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  
    12  	"github.com/apernet/quic-go"
    13  	"github.com/apernet/quic-go/internal/protocol"
    14  	"github.com/apernet/quic-go/quicvarint"
    15  
    16  	"github.com/quic-go/qpack"
    17  )
    18  
    19  // Connection is an HTTP/3 connection.
    20  // It has all methods from the quic.Connection expect for AcceptStream, AcceptUniStream,
    21  // SendDatagram and ReceiveDatagram.
    22  type Connection interface {
    23  	OpenStream() (quic.Stream, error)
    24  	OpenStreamSync(context.Context) (quic.Stream, error)
    25  	OpenUniStream() (quic.SendStream, error)
    26  	OpenUniStreamSync(context.Context) (quic.SendStream, error)
    27  	LocalAddr() net.Addr
    28  	RemoteAddr() net.Addr
    29  	CloseWithError(quic.ApplicationErrorCode, string) error
    30  	Context() context.Context
    31  	ConnectionState() quic.ConnectionState
    32  
    33  	// ReceivedSettings returns a channel that is closed once the client's SETTINGS frame was received.
    34  	ReceivedSettings() <-chan struct{}
    35  	// Settings returns the settings received on this connection.
    36  	Settings() *Settings
    37  }
    38  
// connection wraps a quic.Connection and adds the HTTP/3 connection-level
// state: the peer's SETTINGS, the QPACK decoder shared by request streams,
// and the per-stream datagrammers used to dispatch received datagrams.
type connection struct {
	// The embedded QUIC connection; its methods are promoted to connection.
	quic.Connection

	// perspective is this endpoint's role (client or server). logger may be
	// nil; the uses in this file check for nil before logging.
	perspective protocol.Perspective
	logger      *slog.Logger

	// enableDatagrams records whether HTTP/3 datagrams were enabled on this
	// endpoint.
	enableDatagrams bool

	// decoder is handed to every request stream created by openRequestStream.
	decoder *qpack.Decoder

	// streams maps a stream ID to the datagrammer for that stream, so that
	// received datagrams can be routed to it. It is guarded by streamMx.
	streamMx sync.Mutex
	streams  map[protocol.StreamID]*datagrammer

	// settings holds the peer's SETTINGS. It is written once, before
	// receivedSettings is closed, and must only be read after that.
	settings         *Settings
	receivedSettings chan struct{}
}
    55  
    56  func newConnection(
    57  	quicConn quic.Connection,
    58  	enableDatagrams bool,
    59  	perspective protocol.Perspective,
    60  	logger *slog.Logger,
    61  ) *connection {
    62  	c := &connection{
    63  		Connection:       quicConn,
    64  		perspective:      perspective,
    65  		logger:           logger,
    66  		enableDatagrams:  enableDatagrams,
    67  		decoder:          qpack.NewDecoder(func(hf qpack.HeaderField) {}),
    68  		receivedSettings: make(chan struct{}),
    69  		streams:          make(map[protocol.StreamID]*datagrammer),
    70  	}
    71  	return c
    72  }
    73  
    74  func (c *connection) onStreamStateChange(id quic.StreamID, state streamState, e error) {
    75  	c.streamMx.Lock()
    76  	defer c.streamMx.Unlock()
    77  
    78  	d, ok := c.streams[id]
    79  	if !ok { // should never happen
    80  		return
    81  	}
    82  	var isDone bool
    83  	//nolint:exhaustive // These are all the cases we care about.
    84  	switch state {
    85  	case streamStateReceiveClosed:
    86  		isDone = d.SetReceiveError(e)
    87  	case streamStateSendClosed:
    88  		isDone = d.SetSendError(e)
    89  	case streamStateSendAndReceiveClosed:
    90  		isDone = true
    91  	default:
    92  		return
    93  	}
    94  	if isDone {
    95  		delete(c.streams, id)
    96  	}
    97  }
    98  
    99  func (c *connection) openRequestStream(
   100  	ctx context.Context,
   101  	requestWriter *requestWriter,
   102  	reqDone chan<- struct{},
   103  	disableCompression bool,
   104  	maxHeaderBytes uint64,
   105  ) (*requestStream, error) {
   106  	str, err := c.Connection.OpenStreamSync(ctx)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
   111  	c.streamMx.Lock()
   112  	c.streams[str.StreamID()] = datagrams
   113  	c.streamMx.Unlock()
   114  	qstr := newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
   115  	hstr := newStream(qstr, c, datagrams)
   116  	return newRequestStream(hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes), nil
   117  }
   118  
   119  func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagrammer, error) {
   120  	str, err := c.AcceptStream(ctx)
   121  	if err != nil {
   122  		return nil, nil, err
   123  	}
   124  	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
   125  	if c.perspective == protocol.PerspectiveServer {
   126  		c.streamMx.Lock()
   127  		c.streams[str.StreamID()] = datagrams
   128  		c.streamMx.Unlock()
   129  		str = newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
   130  	}
   131  	return str, datagrams, nil
   132  }
   133  
// HandleUnidirectionalStreams accepts unidirectional streams opened by the peer
// until AcceptUniStream returns an error. Each stream is handled in its own
// goroutine, according to the stream type read from its first varint:
//
//   - control stream: at most one is allowed; a duplicate closes the connection
//     with ErrCodeStreamCreationError. Its first frame must be a SETTINGS frame,
//     whose values are stored in c.settings before c.receivedSettings is closed.
//     If the peer enabled datagrams and this endpoint did too, the QUIC
//     connection must support datagrams (otherwise it is closed with
//     ErrCodeSettingsError), and a goroutine running receiveDatagrams is started.
//   - QPACK encoder / decoder stream: a duplicate of either closes the
//     connection. The stream is otherwise not read, since the QPACK dynamic
//     table is not used.
//   - push stream: the connection is closed (ErrCodeIDError on the client,
//     ErrCodeStreamCreationError on the server).
//   - any other type: offered to hijack. If it is not hijacked, reading is
//     cancelled with ErrCodeStreamCreationError.
//
// hijack may be nil. For unknown stream types it is called with a nil error.
// If reading the stream type failed, it is called with that error instead.
// It receives the connection's tracing ID. When it returns true, the stream
// is left entirely to the callback.
func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) {
	// Which singleton streams have been seen so far. They are shared by the
	// per-stream goroutines below, hence the atomics.
	var (
		rcvdControlStr      atomic.Bool
		rcvdQPACKEncoderStr atomic.Bool
		rcvdQPACKDecoderStr atomic.Bool
	)

	for {
		str, err := c.Connection.AcceptUniStream(context.Background())
		if err != nil {
			if c.logger != nil {
				c.logger.Debug("accepting unidirectional stream failed", "error", err)
			}
			return
		}

		go func(str quic.ReceiveStream) {
			streamType, err := quicvarint.Read(quicvarint.NewReader(str))
			if err != nil {
				// The stream type could not be read, so streamType should not be relied on.
				// Give hijack a chance to take over the stream before giving up on it.
				id := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
				if hijack != nil && hijack(StreamType(streamType), id, str, err) {
					return
				}
				if c.logger != nil {
					c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err)
				}
				return
			}
			// Only the control stream falls through to the code after this switch.
			// Every other stream type is fully handled (or rejected) inside it.
			switch streamType {
			case streamTypeControlStream:
			case streamTypeQPACKEncoderStream:
				if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst {
					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream")
				}
				// Our QPACK implementation doesn't use the dynamic table yet.
				return
			case streamTypeQPACKDecoderStream:
				if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst {
					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream")
				}
				// Our QPACK implementation doesn't use the dynamic table yet.
				return
			case streamTypePushStream:
				switch c.perspective {
				case protocol.PerspectiveClient:
					// we never increased the Push ID, so we don't expect any push streams
					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
				case protocol.PerspectiveServer:
					// only the server can push
					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "")
				}
				return
			default:
				if hijack != nil {
					if hijack(
						StreamType(streamType),
						c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID),
						str,
						nil,
					) {
						return
					}
				}
				str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
				return
			}
			// Only a single control stream is allowed.
			if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr {
				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
				return
			}
			f, err := parseNextFrame(str, nil)
			if err != nil {
				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
				return
			}
			// The first frame on the control stream must be a SETTINGS frame.
			sf, ok := f.(*settingsFrame)
			if !ok {
				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
				return
			}
			// Store the settings before closing receivedSettings, so that callers
			// waiting on ReceivedSettings() can safely read Settings().
			c.settings = &Settings{
				EnableDatagrams:       sf.Datagram,
				EnableExtendedConnect: sf.ExtendedConnect,
				Other:                 sf.Other,
			}
			close(c.receivedSettings)
			if !sf.Datagram {
				return
			}
			// If datagram support was enabled on our side as well as on the peer's side,
			// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
			// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
			if c.enableDatagrams && !c.Connection.ConnectionState().SupportsDatagrams {
				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
				return
			}
			if c.enableDatagrams {
				go func() {
					if err := c.receiveDatagrams(); err != nil {
						if c.logger != nil {
							c.logger.Debug("receiving datagrams failed", "error", err)
						}
					}
				}()
			}
		}(str)
	}
}
   244  
   245  func (c *connection) sendDatagram(streamID protocol.StreamID, b []byte) error {
   246  	// TODO: this creates a lot of garbage and an additional copy
   247  	data := make([]byte, 0, len(b)+8)
   248  	data = quicvarint.Append(data, uint64(streamID/4))
   249  	data = append(data, b...)
   250  	return c.Connection.SendDatagram(data)
   251  }
   252  
   253  func (c *connection) receiveDatagrams() error {
   254  	for {
   255  		b, err := c.Connection.ReceiveDatagram(context.Background())
   256  		if err != nil {
   257  			return err
   258  		}
   259  		// TODO: this is quite wasteful in terms of allocations
   260  		r := bytes.NewReader(b)
   261  		quarterStreamID, err := quicvarint.Read(r)
   262  		if err != nil {
   263  			c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
   264  			return fmt.Errorf("could not read quarter stream id: %w", err)
   265  		}
   266  		if quarterStreamID > maxQuarterStreamID {
   267  			c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
   268  			return fmt.Errorf("invalid quarter stream id: %w", err)
   269  		}
   270  		streamID := protocol.StreamID(4 * quarterStreamID)
   271  		c.streamMx.Lock()
   272  		dg, ok := c.streams[streamID]
   273  		if !ok {
   274  			c.streamMx.Unlock()
   275  			return nil
   276  		}
   277  		c.streamMx.Unlock()
   278  		dg.enqueue(b[len(b)-r.Len():])
   279  	}
   280  }
   281  
   282  // ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received.
   283  func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSettings }
   284  
   285  // Settings returns the settings received on this connection.
   286  // It is only valid to call this function after the channel returned by ReceivedSettings was closed.
   287  func (c *connection) Settings() *Settings { return c.settings }