github.com/MerlinKodo/quic-go@v0.39.2/streams_map.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/MerlinKodo/quic-go/internal/flowcontrol"
    11  	"github.com/MerlinKodo/quic-go/internal/protocol"
    12  	"github.com/MerlinKodo/quic-go/internal/qerr"
    13  	"github.com/MerlinKodo/quic-go/internal/wire"
    14  )
    15  
    16  type streamError struct {
    17  	message string
    18  	nums    []protocol.StreamNum
    19  }
    20  
    21  func (e streamError) Error() string {
    22  	return e.message
    23  }
    24  
    25  func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
    26  	strError, ok := err.(streamError)
    27  	if !ok {
    28  		return err
    29  	}
    30  	ids := make([]interface{}, len(strError.nums))
    31  	for i, num := range strError.nums {
    32  		ids[i] = num.StreamID(stype, pers)
    33  	}
    34  	return fmt.Errorf(strError.Error(), ids...)
    35  }
    36  
    37  type streamOpenErr struct{ error }
    38  
    39  var _ net.Error = &streamOpenErr{}
    40  
    41  func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams }
    42  func (streamOpenErr) Timeout() bool     { return false }
    43  
    44  // errTooManyOpenStreams is used internally by the outgoing streams maps.
    45  var errTooManyOpenStreams = errors.New("too many open streams")
    46  
    47  type streamsMap struct {
    48  	perspective protocol.Perspective
    49  
    50  	maxIncomingBidiStreams uint64
    51  	maxIncomingUniStreams  uint64
    52  
    53  	sender            streamSender
    54  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
    55  
    56  	mutex               sync.Mutex
    57  	outgoingBidiStreams *outgoingStreamsMap[streamI]
    58  	outgoingUniStreams  *outgoingStreamsMap[sendStreamI]
    59  	incomingBidiStreams *incomingStreamsMap[streamI]
    60  	incomingUniStreams  *incomingStreamsMap[receiveStreamI]
    61  	reset               bool
    62  }
    63  
    64  var _ streamManager = &streamsMap{}
    65  
    66  func newStreamsMap(
    67  	sender streamSender,
    68  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
    69  	maxIncomingBidiStreams uint64,
    70  	maxIncomingUniStreams uint64,
    71  	perspective protocol.Perspective,
    72  ) streamManager {
    73  	m := &streamsMap{
    74  		perspective:            perspective,
    75  		newFlowController:      newFlowController,
    76  		maxIncomingBidiStreams: maxIncomingBidiStreams,
    77  		maxIncomingUniStreams:  maxIncomingUniStreams,
    78  		sender:                 sender,
    79  	}
    80  	m.initMaps()
    81  	return m
    82  }
    83  
    84  func (m *streamsMap) initMaps() {
    85  	m.outgoingBidiStreams = newOutgoingStreamsMap(
    86  		protocol.StreamTypeBidi,
    87  		func(num protocol.StreamNum) streamI {
    88  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
    89  			return newStream(id, m.sender, m.newFlowController(id))
    90  		},
    91  		m.sender.queueControlFrame,
    92  	)
    93  	m.incomingBidiStreams = newIncomingStreamsMap(
    94  		protocol.StreamTypeBidi,
    95  		func(num protocol.StreamNum) streamI {
    96  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
    97  			return newStream(id, m.sender, m.newFlowController(id))
    98  		},
    99  		m.maxIncomingBidiStreams,
   100  		m.sender.queueControlFrame,
   101  	)
   102  	m.outgoingUniStreams = newOutgoingStreamsMap(
   103  		protocol.StreamTypeUni,
   104  		func(num protocol.StreamNum) sendStreamI {
   105  			id := num.StreamID(protocol.StreamTypeUni, m.perspective)
   106  			return newSendStream(id, m.sender, m.newFlowController(id))
   107  		},
   108  		m.sender.queueControlFrame,
   109  	)
   110  	m.incomingUniStreams = newIncomingStreamsMap(
   111  		protocol.StreamTypeUni,
   112  		func(num protocol.StreamNum) receiveStreamI {
   113  			id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite())
   114  			return newReceiveStream(id, m.sender, m.newFlowController(id))
   115  		},
   116  		m.maxIncomingUniStreams,
   117  		m.sender.queueControlFrame,
   118  	)
   119  }
   120  
   121  func (m *streamsMap) OpenStream() (Stream, error) {
   122  	m.mutex.Lock()
   123  	reset := m.reset
   124  	mm := m.outgoingBidiStreams
   125  	m.mutex.Unlock()
   126  	if reset {
   127  		return nil, Err0RTTRejected
   128  	}
   129  	str, err := mm.OpenStream()
   130  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   131  }
   132  
   133  func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
   134  	m.mutex.Lock()
   135  	reset := m.reset
   136  	mm := m.outgoingBidiStreams
   137  	m.mutex.Unlock()
   138  	if reset {
   139  		return nil, Err0RTTRejected
   140  	}
   141  	str, err := mm.OpenStreamSync(ctx)
   142  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   143  }
   144  
   145  func (m *streamsMap) OpenUniStream() (SendStream, error) {
   146  	m.mutex.Lock()
   147  	reset := m.reset
   148  	mm := m.outgoingUniStreams
   149  	m.mutex.Unlock()
   150  	if reset {
   151  		return nil, Err0RTTRejected
   152  	}
   153  	str, err := mm.OpenStream()
   154  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   155  }
   156  
   157  func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
   158  	m.mutex.Lock()
   159  	reset := m.reset
   160  	mm := m.outgoingUniStreams
   161  	m.mutex.Unlock()
   162  	if reset {
   163  		return nil, Err0RTTRejected
   164  	}
   165  	str, err := mm.OpenStreamSync(ctx)
   166  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   167  }
   168  
   169  func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
   170  	m.mutex.Lock()
   171  	reset := m.reset
   172  	mm := m.incomingBidiStreams
   173  	m.mutex.Unlock()
   174  	if reset {
   175  		return nil, Err0RTTRejected
   176  	}
   177  	str, err := mm.AcceptStream(ctx)
   178  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
   179  }
   180  
   181  func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
   182  	m.mutex.Lock()
   183  	reset := m.reset
   184  	mm := m.incomingUniStreams
   185  	m.mutex.Unlock()
   186  	if reset {
   187  		return nil, Err0RTTRejected
   188  	}
   189  	str, err := mm.AcceptStream(ctx)
   190  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
   191  }
   192  
   193  func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
   194  	num := id.StreamNum()
   195  	switch id.Type() {
   196  	case protocol.StreamTypeUni:
   197  		if id.InitiatedBy() == m.perspective {
   198  			return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
   199  		}
   200  		return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
   201  	case protocol.StreamTypeBidi:
   202  		if id.InitiatedBy() == m.perspective {
   203  			return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
   204  		}
   205  		return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
   206  	}
   207  	panic("")
   208  }
   209  
   210  func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   211  	str, err := m.getOrOpenReceiveStream(id)
   212  	if err != nil {
   213  		return nil, &qerr.TransportError{
   214  			ErrorCode:    qerr.StreamStateError,
   215  			ErrorMessage: err.Error(),
   216  		}
   217  	}
   218  	return str, nil
   219  }
   220  
   221  func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   222  	num := id.StreamNum()
   223  	switch id.Type() {
   224  	case protocol.StreamTypeUni:
   225  		if id.InitiatedBy() == m.perspective {
   226  			// an outgoing unidirectional stream is a send stream, not a receive stream
   227  			return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
   228  		}
   229  		str, err := m.incomingUniStreams.GetOrOpenStream(num)
   230  		return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   231  	case protocol.StreamTypeBidi:
   232  		var str receiveStreamI
   233  		var err error
   234  		if id.InitiatedBy() == m.perspective {
   235  			str, err = m.outgoingBidiStreams.GetStream(num)
   236  		} else {
   237  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   238  		}
   239  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   240  	}
   241  	panic("")
   242  }
   243  
   244  func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   245  	str, err := m.getOrOpenSendStream(id)
   246  	if err != nil {
   247  		return nil, &qerr.TransportError{
   248  			ErrorCode:    qerr.StreamStateError,
   249  			ErrorMessage: err.Error(),
   250  		}
   251  	}
   252  	return str, nil
   253  }
   254  
   255  func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   256  	num := id.StreamNum()
   257  	switch id.Type() {
   258  	case protocol.StreamTypeUni:
   259  		if id.InitiatedBy() == m.perspective {
   260  			str, err := m.outgoingUniStreams.GetStream(num)
   261  			return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   262  		}
   263  		// an incoming unidirectional stream is a receive stream, not a send stream
   264  		return nil, fmt.Errorf("peer attempted to open send stream %d", id)
   265  	case protocol.StreamTypeBidi:
   266  		var str sendStreamI
   267  		var err error
   268  		if id.InitiatedBy() == m.perspective {
   269  			str, err = m.outgoingBidiStreams.GetStream(num)
   270  		} else {
   271  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   272  		}
   273  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   274  	}
   275  	panic("")
   276  }
   277  
   278  func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
   279  	switch f.Type {
   280  	case protocol.StreamTypeUni:
   281  		m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
   282  	case protocol.StreamTypeBidi:
   283  		m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
   284  	}
   285  }
   286  
   287  func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
   288  	m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
   289  	m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
   290  	m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
   291  	m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
   292  }
   293  
   294  func (m *streamsMap) CloseWithError(err error) {
   295  	m.outgoingBidiStreams.CloseWithError(err)
   296  	m.outgoingUniStreams.CloseWithError(err)
   297  	m.incomingBidiStreams.CloseWithError(err)
   298  	m.incomingUniStreams.CloseWithError(err)
   299  }
   300  
   301  // ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are
   302  // 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
   303  // 2. reset to their initial state, such that we can immediately process new incoming stream data.
   304  // Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
   305  // until UseResetMaps() has been called.
   306  func (m *streamsMap) ResetFor0RTT() {
   307  	m.mutex.Lock()
   308  	defer m.mutex.Unlock()
   309  	m.reset = true
   310  	m.CloseWithError(Err0RTTRejected)
   311  	m.initMaps()
   312  }
   313  
   314  func (m *streamsMap) UseResetMaps() {
   315  	m.mutex.Lock()
   316  	m.reset = false
   317  	m.mutex.Unlock()
   318  }