github.com/sagernet/quic-go@v0.43.1-beta.1/streams_map.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/sagernet/quic-go/internal/flowcontrol"
    11  	"github.com/sagernet/quic-go/internal/protocol"
    12  	"github.com/sagernet/quic-go/internal/qerr"
    13  	"github.com/sagernet/quic-go/internal/wire"
    14  )
    15  
// streamError is an error whose message is used as a fmt format string:
// convertStreamError converts each entry of nums into a stream ID and passes
// those IDs as the arguments for the verbs in message. Until that conversion
// happens, Error() returns the message with its verbs unsubstituted.
type streamError struct {
	// message is a format string; presumably it contains one verb per entry
	// in nums (see convertStreamError).
	message string
	// nums are stream numbers; they only become stream IDs once the stream
	// type and initiator are known.
	nums []protocol.StreamNum
}
    20  
    21  func (e streamError) Error() string {
    22  	return e.message
    23  }
    24  
    25  func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
    26  	strError, ok := err.(streamError)
    27  	if !ok {
    28  		return err
    29  	}
    30  	ids := make([]interface{}, len(strError.nums))
    31  	for i, num := range strError.nums {
    32  		ids[i] = num.StreamID(stype, pers)
    33  	}
    34  	return fmt.Errorf(strError.Error(), ids...)
    35  }
    36  
    37  type streamOpenErr struct{ error }
    38  
    39  var _ net.Error = &streamOpenErr{}
    40  
    41  func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams }
    42  func (streamOpenErr) Timeout() bool     { return false }
    43  
    44  // errTooManyOpenStreams is used internally by the outgoing streams maps.
    45  var errTooManyOpenStreams = errors.New("too many open streams")
    46  
// streamsMap keeps one stream map per stream type (bidirectional /
// unidirectional) and per direction: outgoing maps hold streams whose IDs are
// derived from our own perspective, incoming maps hold streams whose IDs are
// derived from the peer's perspective (m.perspective.Opposite()).
type streamsMap struct {
	ctx         context.Context // not used for cancellations, but carries the values associated with the connection
	perspective protocol.Perspective

	// Stream limits passed to the incoming maps when they are (re)created in initMaps.
	maxIncomingBidiStreams uint64
	maxIncomingUniStreams  uint64

	sender            streamSender
	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController

	// mutex guards reset and the four map pointers for the Open*/Accept*
	// methods, ResetFor0RTT and UseResetMaps. DeleteStream, GetOrOpen*,
	// HandleMaxStreamsFrame, UpdateLimits and CloseWithError read the map
	// pointers without it. NOTE(review): presumably those are only called
	// from the connection's own goroutine; confirm.
	mutex               sync.Mutex
	outgoingBidiStreams *outgoingStreamsMap[streamI]
	outgoingUniStreams  *outgoingStreamsMap[sendStreamI]
	incomingBidiStreams *incomingStreamsMap[streamI]
	incomingUniStreams  *incomingStreamsMap[receiveStreamI]
	// reset is set by ResetFor0RTT and cleared by UseResetMaps; while it is
	// set, the Open*/Accept* methods return Err0RTTRejected.
	reset bool
}

// Compile-time check that *streamsMap implements streamManager.
var _ streamManager = &streamsMap{}
    66  
    67  func newStreamsMap(
    68  	ctx context.Context,
    69  	sender streamSender,
    70  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
    71  	maxIncomingBidiStreams uint64,
    72  	maxIncomingUniStreams uint64,
    73  	perspective protocol.Perspective,
    74  ) streamManager {
    75  	m := &streamsMap{
    76  		ctx:                    ctx,
    77  		perspective:            perspective,
    78  		newFlowController:      newFlowController,
    79  		maxIncomingBidiStreams: maxIncomingBidiStreams,
    80  		maxIncomingUniStreams:  maxIncomingUniStreams,
    81  		sender:                 sender,
    82  	}
    83  	m.initMaps()
    84  	return m
    85  }
    86  
    87  func (m *streamsMap) initMaps() {
    88  	m.outgoingBidiStreams = newOutgoingStreamsMap(
    89  		protocol.StreamTypeBidi,
    90  		func(num protocol.StreamNum) streamI {
    91  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
    92  			return newStream(m.ctx, id, m.sender, m.newFlowController(id))
    93  		},
    94  		m.sender.queueControlFrame,
    95  	)
    96  	m.incomingBidiStreams = newIncomingStreamsMap(
    97  		protocol.StreamTypeBidi,
    98  		func(num protocol.StreamNum) streamI {
    99  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
   100  			return newStream(m.ctx, id, m.sender, m.newFlowController(id))
   101  		},
   102  		m.maxIncomingBidiStreams,
   103  		m.sender.queueControlFrame,
   104  	)
   105  	m.outgoingUniStreams = newOutgoingStreamsMap(
   106  		protocol.StreamTypeUni,
   107  		func(num protocol.StreamNum) sendStreamI {
   108  			id := num.StreamID(protocol.StreamTypeUni, m.perspective)
   109  			return newSendStream(m.ctx, id, m.sender, m.newFlowController(id))
   110  		},
   111  		m.sender.queueControlFrame,
   112  	)
   113  	m.incomingUniStreams = newIncomingStreamsMap(
   114  		protocol.StreamTypeUni,
   115  		func(num protocol.StreamNum) receiveStreamI {
   116  			id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite())
   117  			return newReceiveStream(id, m.sender, m.newFlowController(id))
   118  		},
   119  		m.maxIncomingUniStreams,
   120  		m.sender.queueControlFrame,
   121  	)
   122  }
   123  
   124  func (m *streamsMap) OpenStream() (Stream, error) {
   125  	m.mutex.Lock()
   126  	reset := m.reset
   127  	mm := m.outgoingBidiStreams
   128  	m.mutex.Unlock()
   129  	if reset {
   130  		return nil, Err0RTTRejected
   131  	}
   132  	str, err := mm.OpenStream()
   133  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   134  }
   135  
   136  func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
   137  	m.mutex.Lock()
   138  	reset := m.reset
   139  	mm := m.outgoingBidiStreams
   140  	m.mutex.Unlock()
   141  	if reset {
   142  		return nil, Err0RTTRejected
   143  	}
   144  	str, err := mm.OpenStreamSync(ctx)
   145  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   146  }
   147  
   148  func (m *streamsMap) OpenUniStream() (SendStream, error) {
   149  	m.mutex.Lock()
   150  	reset := m.reset
   151  	mm := m.outgoingUniStreams
   152  	m.mutex.Unlock()
   153  	if reset {
   154  		return nil, Err0RTTRejected
   155  	}
   156  	str, err := mm.OpenStream()
   157  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   158  }
   159  
   160  func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
   161  	m.mutex.Lock()
   162  	reset := m.reset
   163  	mm := m.outgoingUniStreams
   164  	m.mutex.Unlock()
   165  	if reset {
   166  		return nil, Err0RTTRejected
   167  	}
   168  	str, err := mm.OpenStreamSync(ctx)
   169  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   170  }
   171  
   172  func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
   173  	m.mutex.Lock()
   174  	reset := m.reset
   175  	mm := m.incomingBidiStreams
   176  	m.mutex.Unlock()
   177  	if reset {
   178  		return nil, Err0RTTRejected
   179  	}
   180  	str, err := mm.AcceptStream(ctx)
   181  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
   182  }
   183  
   184  func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
   185  	m.mutex.Lock()
   186  	reset := m.reset
   187  	mm := m.incomingUniStreams
   188  	m.mutex.Unlock()
   189  	if reset {
   190  		return nil, Err0RTTRejected
   191  	}
   192  	str, err := mm.AcceptStream(ctx)
   193  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
   194  }
   195  
   196  func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
   197  	num := id.StreamNum()
   198  	switch id.Type() {
   199  	case protocol.StreamTypeUni:
   200  		if id.InitiatedBy() == m.perspective {
   201  			return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
   202  		}
   203  		return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
   204  	case protocol.StreamTypeBidi:
   205  		if id.InitiatedBy() == m.perspective {
   206  			return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
   207  		}
   208  		return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
   209  	}
   210  	panic("")
   211  }
   212  
   213  func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   214  	str, err := m.getOrOpenReceiveStream(id)
   215  	if err != nil {
   216  		return nil, &qerr.TransportError{
   217  			ErrorCode:    qerr.StreamStateError,
   218  			ErrorMessage: err.Error(),
   219  		}
   220  	}
   221  	return str, nil
   222  }
   223  
   224  func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   225  	num := id.StreamNum()
   226  	switch id.Type() {
   227  	case protocol.StreamTypeUni:
   228  		if id.InitiatedBy() == m.perspective {
   229  			// an outgoing unidirectional stream is a send stream, not a receive stream
   230  			return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
   231  		}
   232  		str, err := m.incomingUniStreams.GetOrOpenStream(num)
   233  		return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   234  	case protocol.StreamTypeBidi:
   235  		var str receiveStreamI
   236  		var err error
   237  		if id.InitiatedBy() == m.perspective {
   238  			str, err = m.outgoingBidiStreams.GetStream(num)
   239  		} else {
   240  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   241  		}
   242  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   243  	}
   244  	panic("")
   245  }
   246  
   247  func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   248  	str, err := m.getOrOpenSendStream(id)
   249  	if err != nil {
   250  		return nil, &qerr.TransportError{
   251  			ErrorCode:    qerr.StreamStateError,
   252  			ErrorMessage: err.Error(),
   253  		}
   254  	}
   255  	return str, nil
   256  }
   257  
   258  func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   259  	num := id.StreamNum()
   260  	switch id.Type() {
   261  	case protocol.StreamTypeUni:
   262  		if id.InitiatedBy() == m.perspective {
   263  			str, err := m.outgoingUniStreams.GetStream(num)
   264  			return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   265  		}
   266  		// an incoming unidirectional stream is a receive stream, not a send stream
   267  		return nil, fmt.Errorf("peer attempted to open send stream %d", id)
   268  	case protocol.StreamTypeBidi:
   269  		var str sendStreamI
   270  		var err error
   271  		if id.InitiatedBy() == m.perspective {
   272  			str, err = m.outgoingBidiStreams.GetStream(num)
   273  		} else {
   274  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   275  		}
   276  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   277  	}
   278  	panic("")
   279  }
   280  
   281  func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
   282  	switch f.Type {
   283  	case protocol.StreamTypeUni:
   284  		m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
   285  	case protocol.StreamTypeBidi:
   286  		m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
   287  	}
   288  }
   289  
   290  func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
   291  	m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
   292  	m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
   293  	m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
   294  	m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
   295  }
   296  
   297  func (m *streamsMap) CloseWithError(err error) {
   298  	m.outgoingBidiStreams.CloseWithError(err)
   299  	m.outgoingUniStreams.CloseWithError(err)
   300  	m.incomingBidiStreams.CloseWithError(err)
   301  	m.incomingUniStreams.CloseWithError(err)
   302  }
   303  
// ResetFor0RTT is used when 0-RTT is rejected. In that case, the streams maps are
// 1. closed with Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
// 2. reset to their initial state, such that we can immediately process new incoming stream data.
// Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
// until UseResetMaps() has been called.
//
// Open*/Accept* read reset and the map pointer together under m.mutex. A call
// racing with this method therefore either gets an old map, which is closed
// here, or sees reset == true.
func (m *streamsMap) ResetFor0RTT() {
	m.mutex.Lock()
	defer m.mutex.Unlock()
	m.reset = true
	// Close the old maps before initMaps replaces them with fresh ones.
	m.CloseWithError(Err0RTTRejected)
	m.initMaps()
}
   316  
   317  func (m *streamsMap) UseResetMaps() {
   318  	m.mutex.Lock()
   319  	m.reset = false
   320  	m.mutex.Unlock()
   321  }