github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/streams_map.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/flowcontrol"
    11  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
    12  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
    13  	"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
    14  	"github.com/danielpfeifer02/quic-go-prio-packs/priority_setting"
    15  )
    16  
    17  type streamError struct {
    18  	message string
    19  	nums    []protocol.StreamNum
    20  }
    21  
    22  func (e streamError) Error() string {
    23  	return e.message
    24  }
    25  
    26  func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
    27  	strError, ok := err.(streamError)
    28  	if !ok {
    29  		return err
    30  	}
    31  	ids := make([]interface{}, len(strError.nums))
    32  	for i, num := range strError.nums {
    33  		ids[i] = num.StreamID(stype, pers)
    34  	}
    35  	return fmt.Errorf(strError.Error(), ids...)
    36  }
    37  
    38  type streamOpenErr struct{ error }
    39  
    40  var _ net.Error = &streamOpenErr{}
    41  
    42  func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams }
    43  func (streamOpenErr) Timeout() bool     { return false }
    44  
    45  // errTooManyOpenStreams is used internally by the outgoing streams maps.
    46  var errTooManyOpenStreams = errors.New("too many open streams")
    47  
    48  type streamsMap struct {
    49  	perspective protocol.Perspective
    50  
    51  	maxIncomingBidiStreams uint64
    52  	maxIncomingUniStreams  uint64
    53  
    54  	sender            streamSender
    55  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
    56  
    57  	mutex               sync.Mutex
    58  	outgoingBidiStreams *outgoingStreamsMap[streamI]
    59  	outgoingUniStreams  *outgoingStreamsMap[sendStreamI]
    60  	incomingBidiStreams *incomingStreamsMap[streamI]
    61  	incomingUniStreams  *incomingStreamsMap[receiveStreamI]
    62  	reset               bool
    63  }
    64  
    65  var _ streamManager = &streamsMap{}
    66  
    67  func newStreamsMap(
    68  	sender streamSender,
    69  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
    70  	maxIncomingBidiStreams uint64,
    71  	maxIncomingUniStreams uint64,
    72  	perspective protocol.Perspective,
    73  ) streamManager {
    74  	m := &streamsMap{
    75  		perspective:            perspective,
    76  		newFlowController:      newFlowController,
    77  		maxIncomingBidiStreams: maxIncomingBidiStreams,
    78  		maxIncomingUniStreams:  maxIncomingUniStreams,
    79  		sender:                 sender,
    80  	}
    81  	m.initMaps()
    82  	return m
    83  }
    84  
    85  func (m *streamsMap) initMaps() {
    86  	m.outgoingBidiStreams = newOutgoingStreamsMap(
    87  		protocol.StreamTypeBidi,
    88  		func(num protocol.StreamNum) streamI {
    89  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
    90  			return newStream(id, m.sender, m.newFlowController(id))
    91  		},
    92  		m.sender.queueControlFrame,
    93  	)
    94  	m.incomingBidiStreams = newIncomingStreamsMap(
    95  		protocol.StreamTypeBidi,
    96  		func(num protocol.StreamNum) streamI {
    97  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
    98  			return newStream(id, m.sender, m.newFlowController(id))
    99  		},
   100  		m.maxIncomingBidiStreams,
   101  		m.sender.queueControlFrame,
   102  	)
   103  	m.outgoingUniStreams = newOutgoingStreamsMap(
   104  		protocol.StreamTypeUni,
   105  		func(num protocol.StreamNum) sendStreamI {
   106  			id := num.StreamID(protocol.StreamTypeUni, m.perspective)
   107  			return newSendStream(id, m.sender, m.newFlowController(id))
   108  		},
   109  		m.sender.queueControlFrame,
   110  	)
   111  	m.incomingUniStreams = newIncomingStreamsMap(
   112  		protocol.StreamTypeUni,
   113  		func(num protocol.StreamNum) receiveStreamI {
   114  			id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite())
   115  			return newReceiveStream(id, m.sender, m.newFlowController(id))
   116  		},
   117  		m.maxIncomingUniStreams,
   118  		m.sender.queueControlFrame,
   119  	)
   120  }
   121  
   122  // PRIO_PACKS_TAG
   123  // OpenStream wrapper for priority setting
   124  func (m *streamsMap) OpenStreamWithPriority(priority priority_setting.Priority) (Stream, error) {
   125  	str, err := m.OpenStream()
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	str.SetPriority(priority)
   130  	return str, nil
   131  }
   132  
   133  func (m *streamsMap) OpenStream() (Stream, error) {
   134  	m.mutex.Lock()
   135  	reset := m.reset
   136  	mm := m.outgoingBidiStreams
   137  	m.mutex.Unlock()
   138  	if reset {
   139  		return nil, Err0RTTRejected
   140  	}
   141  	str, err := mm.OpenStream()
   142  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   143  }
   144  
   145  // PRIO_PACKS_TAG
   146  // OpenStream wrapper for priority setting
   147  func (m *streamsMap) OpenStreamSyncWithPriority(ctx context.Context, priority priority_setting.Priority) (Stream, error) {
   148  	str, err := m.OpenStreamSync(ctx)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  	str.SetPriority(priority)
   153  	return str, nil
   154  }
   155  
   156  func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
   157  	m.mutex.Lock()
   158  	reset := m.reset
   159  	mm := m.outgoingBidiStreams
   160  	m.mutex.Unlock()
   161  	if reset {
   162  		return nil, Err0RTTRejected
   163  	}
   164  	str, err := mm.OpenStreamSync(ctx)
   165  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   166  }
   167  
   168  // PRIO_PACKS_TAG
   169  // OpenStream wrapper for priority setting
   170  func (m *streamsMap) OpenUniStreamWithPriority(priority priority_setting.Priority) (SendStream, error) {
   171  	str, err := m.OpenUniStream()
   172  	if err != nil {
   173  		return nil, err
   174  	}
   175  	str.SetPriority(priority)
   176  	return str, nil
   177  }
   178  
   179  func (m *streamsMap) OpenUniStream() (SendStream, error) {
   180  	m.mutex.Lock()
   181  	reset := m.reset
   182  	mm := m.outgoingUniStreams
   183  	m.mutex.Unlock()
   184  	if reset {
   185  		return nil, Err0RTTRejected
   186  	}
   187  	str, err := mm.OpenStream()
   188  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   189  }
   190  
   191  // PRIO_PACKS_TAG
   192  // OpenStream wrapper for priority setting
   193  func (m *streamsMap) OpenUniStreamSyncWithPriority(ctx context.Context, priority priority_setting.Priority) (SendStream, error) {
   194  	str, err := m.OpenUniStreamSync(ctx)
   195  	if err != nil {
   196  		return nil, err
   197  	}
   198  	str.SetPriority(priority)
   199  	return str, nil
   200  }
   201  
   202  func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
   203  	m.mutex.Lock()
   204  	reset := m.reset
   205  	mm := m.outgoingUniStreams
   206  	m.mutex.Unlock()
   207  	if reset {
   208  		return nil, Err0RTTRejected
   209  	}
   210  	str, err := mm.OpenStreamSync(ctx)
   211  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   212  }
   213  
   214  func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
   215  	m.mutex.Lock()
   216  	reset := m.reset
   217  	mm := m.incomingBidiStreams
   218  	m.mutex.Unlock()
   219  	if reset {
   220  		return nil, Err0RTTRejected
   221  	}
   222  	str, err := mm.AcceptStream(ctx)
   223  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
   224  }
   225  
   226  func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
   227  	m.mutex.Lock()
   228  	reset := m.reset
   229  	mm := m.incomingUniStreams
   230  	m.mutex.Unlock()
   231  	if reset {
   232  		return nil, Err0RTTRejected
   233  	}
   234  	str, err := mm.AcceptStream(ctx)
   235  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
   236  }
   237  
   238  func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
   239  	num := id.StreamNum()
   240  	switch id.Type() {
   241  	case protocol.StreamTypeUni:
   242  		if id.InitiatedBy() == m.perspective {
   243  			return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
   244  		}
   245  		return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
   246  	case protocol.StreamTypeBidi:
   247  		if id.InitiatedBy() == m.perspective {
   248  			return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
   249  		}
   250  		return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
   251  	}
   252  	panic("")
   253  }
   254  
   255  func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   256  	str, err := m.getOrOpenReceiveStream(id)
   257  	if err != nil {
   258  		return nil, &qerr.TransportError{
   259  			ErrorCode:    qerr.StreamStateError,
   260  			ErrorMessage: err.Error(),
   261  		}
   262  	}
   263  	return str, nil
   264  }
   265  
   266  func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   267  	num := id.StreamNum()
   268  	switch id.Type() {
   269  	case protocol.StreamTypeUni:
   270  		if id.InitiatedBy() == m.perspective {
   271  			// an outgoing unidirectional stream is a send stream, not a receive stream
   272  			return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
   273  		}
   274  		str, err := m.incomingUniStreams.GetOrOpenStream(num)
   275  		return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   276  	case protocol.StreamTypeBidi:
   277  		var str receiveStreamI
   278  		var err error
   279  		if id.InitiatedBy() == m.perspective {
   280  			str, err = m.outgoingBidiStreams.GetStream(num)
   281  		} else {
   282  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   283  		}
   284  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   285  	}
   286  	panic("")
   287  }
   288  
   289  func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   290  	str, err := m.getOrOpenSendStream(id)
   291  	if err != nil {
   292  		return nil, &qerr.TransportError{
   293  			ErrorCode:    qerr.StreamStateError,
   294  			ErrorMessage: err.Error(),
   295  		}
   296  	}
   297  	return str, nil
   298  }
   299  
   300  func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   301  	num := id.StreamNum()
   302  	switch id.Type() {
   303  	case protocol.StreamTypeUni:
   304  		if id.InitiatedBy() == m.perspective {
   305  			str, err := m.outgoingUniStreams.GetStream(num)
   306  			return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   307  		}
   308  		// an incoming unidirectional stream is a receive stream, not a send stream
   309  		return nil, fmt.Errorf("peer attempted to open send stream %d", id)
   310  	case protocol.StreamTypeBidi:
   311  		var str sendStreamI
   312  		var err error
   313  		if id.InitiatedBy() == m.perspective {
   314  			str, err = m.outgoingBidiStreams.GetStream(num)
   315  		} else {
   316  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   317  		}
   318  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   319  	}
   320  	panic("")
   321  }
   322  
   323  func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
   324  	switch f.Type {
   325  	case protocol.StreamTypeUni:
   326  		m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
   327  	case protocol.StreamTypeBidi:
   328  		m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
   329  	}
   330  }
   331  
   332  func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
   333  	m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
   334  	m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
   335  	m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
   336  	m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
   337  }
   338  
   339  func (m *streamsMap) CloseWithError(err error) {
   340  	m.outgoingBidiStreams.CloseWithError(err)
   341  	m.outgoingUniStreams.CloseWithError(err)
   342  	m.incomingBidiStreams.CloseWithError(err)
   343  	m.incomingUniStreams.CloseWithError(err)
   344  }
   345  
   346  // ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are
   347  // 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
   348  // 2. reset to their initial state, such that we can immediately process new incoming stream data.
   349  // Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
   350  // until UseResetMaps() has been called.
   351  func (m *streamsMap) ResetFor0RTT() {
   352  	m.mutex.Lock()
   353  	defer m.mutex.Unlock()
   354  	m.reset = true
   355  	m.CloseWithError(Err0RTTRejected)
   356  	m.initMaps()
   357  }
   358  
   359  func (m *streamsMap) UseResetMaps() {
   360  	m.mutex.Lock()
   361  	m.reset = false
   362  	m.mutex.Unlock()
   363  }
   364  
   365  // PRIO_PACKS_TAG
   366  func (m *streamsMap) GetPriority(id protocol.StreamID) priority_setting.Priority {
   367  	str, err := m.getOrOpenSendStream(id)
   368  	if err != nil {
   369  		return priority_setting.NoPriority
   370  	}
   371  	return str.Priority()
   372  }