github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/streams_map.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/daeuniverse/quic-go/internal/flowcontrol"
    11  	"github.com/daeuniverse/quic-go/internal/protocol"
    12  	"github.com/daeuniverse/quic-go/internal/qerr"
    13  	"github.com/daeuniverse/quic-go/internal/wire"
    14  )
    15  
    16  type streamError struct {
    17  	message string
    18  	nums    []protocol.StreamNum
    19  }
    20  
    21  func (e streamError) Error() string {
    22  	return e.message
    23  }
    24  
    25  func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
    26  	strError, ok := err.(streamError)
    27  	if !ok {
    28  		return err
    29  	}
    30  	ids := make([]interface{}, len(strError.nums))
    31  	for i, num := range strError.nums {
    32  		ids[i] = num.StreamID(stype, pers)
    33  	}
    34  	return fmt.Errorf(strError.Error(), ids...)
    35  }
    36  
    37  type streamOpenErr struct{ error }
    38  
    39  var _ net.Error = &streamOpenErr{}
    40  
    41  func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams }
    42  func (streamOpenErr) Timeout() bool     { return false }
    43  
    44  // errTooManyOpenStreams is used internally by the outgoing streams maps.
    45  var errTooManyOpenStreams = errors.New("too many open streams")
    46  
    47  type streamsMap struct {
    48  	perspective protocol.Perspective
    49  
    50  	maxIncomingBidiStreams uint64
    51  	maxIncomingUniStreams  uint64
    52  
    53  	sender            streamSender
    54  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
    55  
    56  	mutex               sync.Mutex
    57  	outgoingBidiStreams *outgoingStreamsMap[streamI]
    58  	outgoingUniStreams  *outgoingStreamsMap[sendStreamI]
    59  	incomingBidiStreams *incomingStreamsMap[streamI]
    60  	incomingUniStreams  *incomingStreamsMap[receiveStreamI]
    61  	capabilityCallback  func(n int64)
    62  	reset               bool
    63  }
    64  
    65  var _ streamManager = &streamsMap{}
    66  
    67  func newStreamsMap(
    68  	sender streamSender,
    69  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
    70  	maxIncomingBidiStreams uint64,
    71  	maxIncomingUniStreams uint64,
    72  	perspective protocol.Perspective,
    73  	capabilityCallback func(n int64),
    74  ) streamManager {
    75  	m := &streamsMap{
    76  		perspective:            perspective,
    77  		newFlowController:      newFlowController,
    78  		maxIncomingBidiStreams: maxIncomingBidiStreams,
    79  		maxIncomingUniStreams:  maxIncomingUniStreams,
    80  		sender:                 sender,
    81  		capabilityCallback:     capabilityCallback,
    82  	}
    83  	m.initMaps()
    84  	return m
    85  }
    86  
    87  func (m *streamsMap) initMaps() {
    88  	m.outgoingBidiStreams = newOutgoingStreamsMap(
    89  		protocol.StreamTypeBidi,
    90  		func(num protocol.StreamNum) streamI {
    91  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
    92  			return newStream(id, m.sender, m.newFlowController(id))
    93  		},
    94  		m.sender.queueControlFrame,
    95  		m.capabilityCallback,
    96  	)
    97  	m.incomingBidiStreams = newIncomingStreamsMap(
    98  		protocol.StreamTypeBidi,
    99  		func(num protocol.StreamNum) streamI {
   100  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
   101  			return newStream(id, m.sender, m.newFlowController(id))
   102  		},
   103  		m.maxIncomingBidiStreams,
   104  		m.sender.queueControlFrame,
   105  	)
   106  	m.outgoingUniStreams = newOutgoingStreamsMap(
   107  		protocol.StreamTypeUni,
   108  		func(num protocol.StreamNum) sendStreamI {
   109  			id := num.StreamID(protocol.StreamTypeUni, m.perspective)
   110  			return newSendStream(id, m.sender, m.newFlowController(id))
   111  		},
   112  		m.sender.queueControlFrame,
   113  		m.capabilityCallback,
   114  	)
   115  	m.incomingUniStreams = newIncomingStreamsMap(
   116  		protocol.StreamTypeUni,
   117  		func(num protocol.StreamNum) receiveStreamI {
   118  			id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite())
   119  			return newReceiveStream(id, m.sender, m.newFlowController(id))
   120  		},
   121  		m.maxIncomingUniStreams,
   122  		m.sender.queueControlFrame,
   123  	)
   124  }
   125  
   126  func (m *streamsMap) OpenStream() (Stream, error) {
   127  	m.mutex.Lock()
   128  	reset := m.reset
   129  	mm := m.outgoingBidiStreams
   130  	m.mutex.Unlock()
   131  	if reset {
   132  		return nil, Err0RTTRejected
   133  	}
   134  	str, err := mm.OpenStream()
   135  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   136  }
   137  
   138  func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
   139  	m.mutex.Lock()
   140  	reset := m.reset
   141  	mm := m.outgoingBidiStreams
   142  	m.mutex.Unlock()
   143  	if reset {
   144  		return nil, Err0RTTRejected
   145  	}
   146  	str, err := mm.OpenStreamSync(ctx)
   147  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   148  }
   149  
   150  func (m *streamsMap) OpenUniStream() (SendStream, error) {
   151  	m.mutex.Lock()
   152  	reset := m.reset
   153  	mm := m.outgoingUniStreams
   154  	m.mutex.Unlock()
   155  	if reset {
   156  		return nil, Err0RTTRejected
   157  	}
   158  	str, err := mm.OpenStream()
   159  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   160  }
   161  
   162  func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
   163  	m.mutex.Lock()
   164  	reset := m.reset
   165  	mm := m.outgoingUniStreams
   166  	m.mutex.Unlock()
   167  	if reset {
   168  		return nil, Err0RTTRejected
   169  	}
   170  	str, err := mm.OpenStreamSync(ctx)
   171  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   172  }
   173  
   174  func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
   175  	m.mutex.Lock()
   176  	reset := m.reset
   177  	mm := m.incomingBidiStreams
   178  	m.mutex.Unlock()
   179  	if reset {
   180  		return nil, Err0RTTRejected
   181  	}
   182  	str, err := mm.AcceptStream(ctx)
   183  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
   184  }
   185  
   186  func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
   187  	m.mutex.Lock()
   188  	reset := m.reset
   189  	mm := m.incomingUniStreams
   190  	m.mutex.Unlock()
   191  	if reset {
   192  		return nil, Err0RTTRejected
   193  	}
   194  	str, err := mm.AcceptStream(ctx)
   195  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
   196  }
   197  
   198  func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
   199  	num := id.StreamNum()
   200  	switch id.Type() {
   201  	case protocol.StreamTypeUni:
   202  		if id.InitiatedBy() == m.perspective {
   203  			return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
   204  		}
   205  		return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
   206  	case protocol.StreamTypeBidi:
   207  		if id.InitiatedBy() == m.perspective {
   208  			return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
   209  		}
   210  		return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
   211  	}
   212  	panic("")
   213  }
   214  
   215  func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   216  	str, err := m.getOrOpenReceiveStream(id)
   217  	if err != nil {
   218  		return nil, &qerr.TransportError{
   219  			ErrorCode:    qerr.StreamStateError,
   220  			ErrorMessage: err.Error(),
   221  		}
   222  	}
   223  	return str, nil
   224  }
   225  
   226  func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   227  	num := id.StreamNum()
   228  	switch id.Type() {
   229  	case protocol.StreamTypeUni:
   230  		if id.InitiatedBy() == m.perspective {
   231  			// an outgoing unidirectional stream is a send stream, not a receive stream
   232  			return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
   233  		}
   234  		str, err := m.incomingUniStreams.GetOrOpenStream(num)
   235  		return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   236  	case protocol.StreamTypeBidi:
   237  		var str receiveStreamI
   238  		var err error
   239  		if id.InitiatedBy() == m.perspective {
   240  			str, err = m.outgoingBidiStreams.GetStream(num)
   241  		} else {
   242  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   243  		}
   244  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   245  	}
   246  	panic("")
   247  }
   248  
   249  func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   250  	str, err := m.getOrOpenSendStream(id)
   251  	if err != nil {
   252  		return nil, &qerr.TransportError{
   253  			ErrorCode:    qerr.StreamStateError,
   254  			ErrorMessage: err.Error(),
   255  		}
   256  	}
   257  	return str, nil
   258  }
   259  
   260  func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   261  	num := id.StreamNum()
   262  	switch id.Type() {
   263  	case protocol.StreamTypeUni:
   264  		if id.InitiatedBy() == m.perspective {
   265  			str, err := m.outgoingUniStreams.GetStream(num)
   266  			return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   267  		}
   268  		// an incoming unidirectional stream is a receive stream, not a send stream
   269  		return nil, fmt.Errorf("peer attempted to open send stream %d", id)
   270  	case protocol.StreamTypeBidi:
   271  		var str sendStreamI
   272  		var err error
   273  		if id.InitiatedBy() == m.perspective {
   274  			str, err = m.outgoingBidiStreams.GetStream(num)
   275  		} else {
   276  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   277  		}
   278  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   279  	}
   280  	panic("")
   281  }
   282  
   283  func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
   284  	switch f.Type {
   285  	case protocol.StreamTypeUni:
   286  		m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
   287  	case protocol.StreamTypeBidi:
   288  		m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
   289  	}
   290  }
   291  
   292  func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
   293  	m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
   294  	m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
   295  	m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
   296  	m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
   297  }
   298  
   299  func (m *streamsMap) CloseWithError(err error) {
   300  	m.outgoingBidiStreams.CloseWithError(err)
   301  	m.outgoingUniStreams.CloseWithError(err)
   302  	m.incomingBidiStreams.CloseWithError(err)
   303  	m.incomingUniStreams.CloseWithError(err)
   304  }
   305  
   306  // ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are
   307  // 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
   308  // 2. reset to their initial state, such that we can immediately process new incoming stream data.
   309  // Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
   310  // until UseResetMaps() has been called.
   311  func (m *streamsMap) ResetFor0RTT() {
   312  	m.mutex.Lock()
   313  	defer m.mutex.Unlock()
   314  	m.reset = true
   315  	m.CloseWithError(Err0RTTRejected)
   316  	m.initMaps()
   317  }
   318  
   319  func (m *streamsMap) UseResetMaps() {
   320  	m.mutex.Lock()
   321  	m.reset = false
   322  	m.mutex.Unlock()
   323  }