github.com/sagernet/quic-go@v0.43.1-beta.1/http3/conn.go (about)

     1  package http3
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  	"sync/atomic"
    10  
    11  	"github.com/quic-go/qpack"
    12  	"github.com/sagernet/quic-go"
    13  	"github.com/sagernet/quic-go/internal/protocol"
    14  	"github.com/sagernet/quic-go/quicvarint"
    15  	"golang.org/x/exp/slog"
    16  )
    17  
    18  // Connection is an HTTP/3 connection.
    19  // It has all methods from the quic.Connection expect for AcceptStream, AcceptUniStream,
    20  // SendDatagram and ReceiveDatagram.
    21  type Connection interface {
    22  	OpenStream() (quic.Stream, error)
    23  	OpenStreamSync(context.Context) (quic.Stream, error)
    24  	OpenUniStream() (quic.SendStream, error)
    25  	OpenUniStreamSync(context.Context) (quic.SendStream, error)
    26  	LocalAddr() net.Addr
    27  	RemoteAddr() net.Addr
    28  	CloseWithError(quic.ApplicationErrorCode, string) error
    29  	Context() context.Context
    30  	ConnectionState() quic.ConnectionState
    31  
    32  	// ReceivedSettings returns a channel that is closed once the client's SETTINGS frame was received.
    33  	ReceivedSettings() <-chan struct{}
    34  	// Settings returns the settings received on this connection.
    35  	Settings() *Settings
    36  }
    37  
    38  type connection struct {
    39  	quic.Connection
    40  
    41  	perspective protocol.Perspective
    42  	logger      *slog.Logger
    43  
    44  	enableDatagrams bool
    45  
    46  	decoder *qpack.Decoder
    47  
    48  	streamMx sync.Mutex
    49  	streams  map[protocol.StreamID]*datagrammer
    50  
    51  	settings         *Settings
    52  	receivedSettings chan struct{}
    53  }
    54  
    55  func newConnection(
    56  	quicConn quic.Connection,
    57  	enableDatagrams bool,
    58  	perspective protocol.Perspective,
    59  	logger *slog.Logger,
    60  ) *connection {
    61  	c := &connection{
    62  		Connection:       quicConn,
    63  		perspective:      perspective,
    64  		logger:           logger,
    65  		enableDatagrams:  enableDatagrams,
    66  		decoder:          qpack.NewDecoder(func(hf qpack.HeaderField) {}),
    67  		receivedSettings: make(chan struct{}),
    68  		streams:          make(map[protocol.StreamID]*datagrammer),
    69  	}
    70  	return c
    71  }
    72  
    73  func (c *connection) onStreamStateChange(id quic.StreamID, state streamState, e error) {
    74  	c.streamMx.Lock()
    75  	defer c.streamMx.Unlock()
    76  
    77  	d, ok := c.streams[id]
    78  	if !ok { // should never happen
    79  		return
    80  	}
    81  	var isDone bool
    82  	//nolint:exhaustive // These are all the cases we care about.
    83  	switch state {
    84  	case streamStateReceiveClosed:
    85  		isDone = d.SetReceiveError(e)
    86  	case streamStateSendClosed:
    87  		isDone = d.SetSendError(e)
    88  	case streamStateSendAndReceiveClosed:
    89  		isDone = true
    90  	default:
    91  		return
    92  	}
    93  	if isDone {
    94  		delete(c.streams, id)
    95  	}
    96  }
    97  
    98  func (c *connection) openRequestStream(
    99  	ctx context.Context,
   100  	requestWriter *requestWriter,
   101  	reqDone chan<- struct{},
   102  	disableCompression bool,
   103  	maxHeaderBytes uint64,
   104  ) (*requestStream, error) {
   105  	str, err := c.Connection.OpenStreamSync(ctx)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
   110  	c.streamMx.Lock()
   111  	c.streams[str.StreamID()] = datagrams
   112  	c.streamMx.Unlock()
   113  	qstr := newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
   114  	hstr := newStream(qstr, c, datagrams)
   115  	return newRequestStream(hstr, requestWriter, reqDone, c.decoder, disableCompression, maxHeaderBytes), nil
   116  }
   117  
   118  func (c *connection) acceptStream(ctx context.Context) (quic.Stream, *datagrammer, error) {
   119  	str, err := c.AcceptStream(ctx)
   120  	if err != nil {
   121  		return nil, nil, err
   122  	}
   123  	datagrams := newDatagrammer(func(b []byte) error { return c.sendDatagram(str.StreamID(), b) })
   124  	if c.perspective == protocol.PerspectiveServer {
   125  		c.streamMx.Lock()
   126  		c.streams[str.StreamID()] = datagrams
   127  		c.streamMx.Unlock()
   128  		str = newStateTrackingStream(str, func(s streamState, e error) { c.onStreamStateChange(str.StreamID(), s, e) })
   129  	}
   130  	return str, datagrams, nil
   131  }
   132  
   133  func (c *connection) HandleUnidirectionalStreams(hijack func(StreamType, quic.ConnectionTracingID, quic.ReceiveStream, error) (hijacked bool)) {
   134  	var (
   135  		rcvdControlStr      atomic.Bool
   136  		rcvdQPACKEncoderStr atomic.Bool
   137  		rcvdQPACKDecoderStr atomic.Bool
   138  	)
   139  
   140  	for {
   141  		str, err := c.Connection.AcceptUniStream(context.Background())
   142  		if err != nil {
   143  			if c.logger != nil {
   144  				c.logger.Debug("accepting unidirectional stream failed", "error", err)
   145  			}
   146  			return
   147  		}
   148  
   149  		go func(str quic.ReceiveStream) {
   150  			streamType, err := quicvarint.Read(quicvarint.NewReader(str))
   151  			if err != nil {
   152  				id := c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID)
   153  				if hijack != nil && hijack(StreamType(streamType), id, str, err) {
   154  					return
   155  				}
   156  				if c.logger != nil {
   157  					c.logger.Debug("reading stream type on stream failed", "stream ID", str.StreamID(), "error", err)
   158  				}
   159  				return
   160  			}
   161  			// We're only interested in the control stream here.
   162  			switch streamType {
   163  			case streamTypeControlStream:
   164  			case streamTypeQPACKEncoderStream:
   165  				if isFirst := rcvdQPACKEncoderStr.CompareAndSwap(false, true); !isFirst {
   166  					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK encoder stream")
   167  				}
   168  				// Our QPACK implementation doesn't use the dynamic table yet.
   169  				return
   170  			case streamTypeQPACKDecoderStream:
   171  				if isFirst := rcvdQPACKDecoderStr.CompareAndSwap(false, true); !isFirst {
   172  					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate QPACK decoder stream")
   173  				}
   174  				// Our QPACK implementation doesn't use the dynamic table yet.
   175  				return
   176  			case streamTypePushStream:
   177  				switch c.perspective {
   178  				case protocol.PerspectiveClient:
   179  					// we never increased the Push ID, so we don't expect any push streams
   180  					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeIDError), "")
   181  				case protocol.PerspectiveServer:
   182  					// only the server can push
   183  					c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "")
   184  				}
   185  				return
   186  			default:
   187  				if hijack != nil {
   188  					if hijack(
   189  						StreamType(streamType),
   190  						c.Connection.Context().Value(quic.ConnectionTracingKey).(quic.ConnectionTracingID),
   191  						str,
   192  						nil,
   193  					) {
   194  						return
   195  					}
   196  				}
   197  				str.CancelRead(quic.StreamErrorCode(ErrCodeStreamCreationError))
   198  				return
   199  			}
   200  			// Only a single control stream is allowed.
   201  			if isFirstControlStr := rcvdControlStr.CompareAndSwap(false, true); !isFirstControlStr {
   202  				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeStreamCreationError), "duplicate control stream")
   203  				return
   204  			}
   205  			f, err := parseNextFrame(str, nil)
   206  			if err != nil {
   207  				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeFrameError), "")
   208  				return
   209  			}
   210  			sf, ok := f.(*settingsFrame)
   211  			if !ok {
   212  				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeMissingSettings), "")
   213  				return
   214  			}
   215  			c.settings = &Settings{
   216  				EnableDatagrams:       sf.Datagram,
   217  				EnableExtendedConnect: sf.ExtendedConnect,
   218  				Other:                 sf.Other,
   219  			}
   220  			close(c.receivedSettings)
   221  			if !sf.Datagram {
   222  				return
   223  			}
   224  			// If datagram support was enabled on our side as well as on the server side,
   225  			// we can expect it to have been negotiated both on the transport and on the HTTP/3 layer.
   226  			// Note: ConnectionState() will block until the handshake is complete (relevant when using 0-RTT).
   227  			if c.enableDatagrams && !c.Connection.ConnectionState().SupportsDatagrams {
   228  				c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeSettingsError), "missing QUIC Datagram support")
   229  				return
   230  			}
   231  			go func() {
   232  				if err := c.receiveDatagrams(); err != nil {
   233  					if c.logger != nil {
   234  						c.logger.Debug("receiving datagrams failed", "error", err)
   235  					}
   236  				}
   237  			}()
   238  		}(str)
   239  	}
   240  }
   241  
   242  func (c *connection) sendDatagram(streamID protocol.StreamID, b []byte) error {
   243  	// TODO: this creates a lot of garbage and an additional copy
   244  	data := make([]byte, 0, len(b)+8)
   245  	data = quicvarint.Append(data, uint64(streamID/4))
   246  	data = append(data, b...)
   247  	return c.Connection.SendDatagram(data)
   248  }
   249  
   250  func (c *connection) receiveDatagrams() error {
   251  	for {
   252  		b, err := c.Connection.ReceiveDatagram(context.Background())
   253  		if err != nil {
   254  			return err
   255  		}
   256  		// TODO: this is quite wasteful in terms of allocations
   257  		r := bytes.NewReader(b)
   258  		quarterStreamID, err := quicvarint.Read(r)
   259  		if err != nil {
   260  			c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
   261  			return fmt.Errorf("could not read quarter stream id: %w", err)
   262  		}
   263  		if quarterStreamID > maxQuarterStreamID {
   264  			c.Connection.CloseWithError(quic.ApplicationErrorCode(ErrCodeDatagramError), "")
   265  			return fmt.Errorf("invalid quarter stream id: %w", err)
   266  		}
   267  		streamID := protocol.StreamID(4 * quarterStreamID)
   268  		c.streamMx.Lock()
   269  		dg, ok := c.streams[streamID]
   270  		if !ok {
   271  			c.streamMx.Unlock()
   272  			return nil
   273  		}
   274  		c.streamMx.Unlock()
   275  		dg.enqueue(b[len(b)-r.Len():])
   276  	}
   277  }
   278  
   279  // ReceivedSettings returns a channel that is closed once the peer's SETTINGS frame was received.
   280  func (c *connection) ReceivedSettings() <-chan struct{} { return c.receivedSettings }
   281  
   282  // Settings returns the settings received on this connection.
   283  // It is only valid to call this function after the channel returned by ReceivedSettings was closed.
   284  func (c *connection) Settings() *Settings { return c.settings }