github.com/sagernet/quic-go@v0.43.1-beta.1/http3_ech/conn.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  	"sync/atomic"
    10  
    11  	"github.com/quic-go/qpack"
    12  	"github.com/sagernet/quic-go/ech"
    13  	"github.com/sagernet/quic-go/internal/protocol"
    14  	"github.com/sagernet/quic-go/quicvarint"
    15  	"golang.org/x/exp/slog"
    16  )
    17  
    18  // Connection is an HTTP/3 connection.
    19  // It has all methods from the quic.Connection expect for AcceptStream, AcceptUniStream,
    20  // SendDatagram and ReceiveDatagram.
    21  type Connection interface {
    22  	OpenStream() (quic.Stream, error)
    23  	OpenStreamSync(context.Context) (quic.Stream, error)
    24  	OpenUniStream() (quic.SendStream, error)
    25  	OpenUniStreamSync(context.Context) (quic.SendStream, error)
    26  	LocalAddr() net.Addr
    27  	RemoteAddr() net.Addr
    28  	CloseWithError(quic.ApplicationErrorCode, string) error
    29  	Context() context.Context
    30  	ConnectionState() quic.ConnectionState
    31  
    32  	// ReceivedSettings returns a channel that is closed once the client's SETTINGS frame was received.
    33  	ReceivedSettings() <-chan struct{}
    34  	// Settings returns the settings received on this connection.
    35  	Settings() *Settings
    36  }
    37  
    38  type connection struct {
    39  	quic.Connection
    40  
    41  	perspective protocol.Perspective
    42  	logger      *slog.Logger
    43  
    44  	enableDatagrams bool
    45  
    46  	decoder *qpack.Decoder
    47  
    48  	streamMx sync.Mutex
    49  	streams  map[protocol.StreamID]*datagrammer
    50  
    51  	settings         *Settings
    52  	receivedSettings chan struct{}
    53  }
    54  
    55  func newConnection(
    56  	quicConn quic.Connection,
    57  	enableDatagrams bool,
    58  	perspective protocol.Perspective,
    59  	logger *slog.Logger,
    60  ) *connection {
    61  	c := &connection{
    62  		Connection:       quicConn,
    63  		perspective:      perspective,
    64  		logger:           logger,
    65  		enableDatagrams:  enableDatagrams,
    66  		decoder:          qpack.NewDecoder(func(hf qpack.HeaderField) {}),
    67  		receivedSettings: make(chan struct{}),
    68  		streams:          make(map[protocol.StreamID]*datagrammer),
    69  	}
    70  	return c
    71  }
    72  
    73  func (c *connection) onStreamStateChange(id quic.StreamID, state streamState, e error) {
    74  	c.streamMx.Lock()
    75  	defer c.streamMx.Unlock()
    76  
    77  	d, ok := c.streams[id]
    78  	if !ok { // should never happen
    79  		return
    80  	}
    81  	var isDone bool
    82  	//nolint:exhaustive // These are all the cases we care about.
    83  	switch state {
    84  	case streamStateReceiveClosed:
    85  		isDone = d.SetReceiveError(e)
    86  	case streamStateSendClosed:
    87  		isDone = d.SetSendError(e)
    88  	default:
    89  		return
    90  	}
    91  	if isDone {
    92  		delete(c.streams, id)
    93  	}
    94  }
    95  
    96  func (c *connection) openRequestStream(
    97  	ctx context.Context,
    98  	requestWriter *requestWriter,
    99  	reqDone chan<- struct{},
   100  	disableCompression bool,
   101  	maxHeaderBytes uint64,
   102  ) (*requestStream, error) {
   103  	str, err := c.Connection.OpenStreamSync(ctx)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
   108  	c.streamMx.Lock()
   109  	c.streams[str.StreamID()] = datagrams
   110  	c.streamMx.Unlock()
   111  	qstr := newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
   112  	hstr := newStream(qstr, c, datagrams)
   113  	return newRequestStream(hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes), nil
   114  }
   115  
   116  func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagrammer, error) {
   117  	str, err := c.AcceptStream(ctx)
   118  	if err != nil {
   119  		return nil, nil, err
   120  	}
   121  	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
   122  	if c.perspective == protocol.PerspectiveServer {
   123  		c.streamMx.Lock()
   124  		c.streams[str.StreamID()] = datagrams
   125  		c.streamMx.Unlock()
   126  		str = newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
   127  	}
   128  	return str, datagrams, nil
   129  }
   130  
   131  func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) {
   132  	var (
   133  		rcvdControlStr      atomic.Bool
   134  		rcvdQPACKEncoderStr atomic.Bool
   135  		rcvdQPACKDecoderStr atomic.Bool
   136  	)
   137  
   138  	for {
   139  		str, err := c.Connection.AcceptUniStream(context.Background())
   140  		if err != nil {
   141  			if c.logger != nil {
   142  				c.logger.Debug("accepting unidirectional stream failed", "error", err)
   143  			}
   144  			return
   145  		}
   146  
   147  		go func(str quic.ReceiveStream) {
   148  			streamType, err := quicvarint.Read(quicvarint.NewReader(str))
   149  			if err != nil {
   150  				id := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
   151  				if hijack != nil && hijack(StreamType(streamType), id, str, err) {
   152  					return
   153  				}
   154  				if c.logger != nil {
   155  					c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err)
   156  				}
   157  				return
   158  			}
   159  			// We're only interested in the control stream here.
   160  			switch streamType {
   161  			case streamTypeControlStream:
   162  			case streamTypeQPACKEncoderStream:
   163  				if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst {
   164  					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream")
   165  				}
   166  				// Our QPACK implementation doesn't use the dynamic table yet.
   167  				return
   168  			case streamTypeQPACKDecoderStream:
   169  				if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst {
   170  					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream")
   171  				}
   172  				// Our QPACK implementation doesn't use the dynamic table yet.
   173  				return
   174  			case streamTypePushStream:
   175  				switch c.perspective {
   176  				case protocol.PerspectiveClient:
   177  					// we never increased the Push ID, so we don't expect any push streams
   178  					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
   179  				case protocol.PerspectiveServer:
   180  					// only the server can push
   181  					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "")
   182  				}
   183  				return
   184  			default:
   185  				if hijack != nil {
   186  					if hijack(
   187  						StreamType(streamType),
   188  						c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID),
   189  						str,
   190  						nil,
   191  					) {
   192  						return
   193  					}
   194  				}
   195  				str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   196  				return
   197  			}
   198  			// Only a single control stream is allowed.
   199  			if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr {
   200  				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
   201  				return
   202  			}
   203  			f, err := parseNextFrame(str, nil)
   204  			if err != nil {
   205  				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
   206  				return
   207  			}
   208  			sf, ok := f.(*settingsFrame)
   209  			if !ok {
   210  				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
   211  				return
   212  			}
   213  			c.settings = &Settings{
   214  				EnableDatagrams:       sf.Datagram,
   215  				EnableExtendedConnect: sf.ExtendedConnect,
   216  				Other:                 sf.Other,
   217  			}
   218  			close(c.receivedSettings)
   219  			if !sf.Datagram {
   220  				return
   221  			}
   222  			// If datagram support was enabled on our side as well as on the server side,
   223  			// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
   224  			// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
   225  			if c.enableDatagrams && !c.Connection.ConnectionState().SupportsDatagrams {
   226  				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
   227  				return
   228  			}
   229  			go func() {
   230  				if err := c.receiveDatagrams(); err != nil {
   231  					if c.logger != nil {
   232  						c.logger.Debug("receiving datagrams failed", "error", err)
   233  					}
   234  				}
   235  			}()
   236  		}(str)
   237  	}
   238  }
   239  
   240  func (c *connection) sendDatagram(streamID protocol.StreamID, b []byte) error {
   241  	// TODO: this creates a lot of garbage and an additional copy
   242  	data := make([]byte, 0, len(b)+8)
   243  	data = quicvarint.Append(data, uint64(streamID/4))
   244  	data = append(data, b...)
   245  	return c.Connection.SendDatagram(data)
   246  }
   247  
   248  func (c *connection) receiveDatagrams() error {
   249  	for {
   250  		b, err := c.Connection.ReceiveDatagram(context.Background())
   251  		if err != nil {
   252  			return err
   253  		}
   254  		// TODO: this is quite wasteful in terms of allocations
   255  		r := bytes.NewReader(b)
   256  		quarterStreamID, err := quicvarint.Read(r)
   257  		if err != nil {
   258  			c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
   259  			return fmt.Errorf("could not read quarter stream id: %w", err)
   260  		}
   261  		if quarterStreamID > maxQuarterStreamID {
   262  			c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
   263  			return fmt.Errorf("invalid quarter stream id: %w", err)
   264  		}
   265  		streamID := protocol.StreamID(4 * quarterStreamID)
   266  		c.streamMx.Lock()
   267  		dg, ok := c.streams[streamID]
   268  		if !ok {
   269  			c.streamMx.Unlock()
   270  			return nil
   271  		}
   272  		c.streamMx.Unlock()
   273  		dg.enqueue(b[len(b)-r.Len():])
   274  	}
   275  }
   276  
   277  // ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received.
   278  func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSettings }
   279  
   280  // Settings returns the settings received on this connection.
   281  // It is only valid to call this function after the channel returned by ReceivedSettings was closed.
   282  func (c *connection) Settings() *Settings { return c.settings }