github.com/ooni/psiphon/tunnel-core@v0.0.0-20230105123940-fe12a24c96ee/oovendor/quic-go/streams_map.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  
    10  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/flowcontrol"
    11  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/protocol"
    12  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/qerr"
    13  	"github.com/ooni/psiphon/tunnel-core/oovendor/quic-go/internal/wire"
    14  )
    15  
    16  type streamError struct {
    17  	message string
    18  	nums    []protocol.StreamNum
    19  }
    20  
    21  func (e streamError) Error() string {
    22  	return e.message
    23  }
    24  
    25  func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error {
    26  	strError, ok := err.(streamError)
    27  	if !ok {
    28  		return err
    29  	}
    30  	ids := make([]interface{}, len(strError.nums))
    31  	for i, num := range strError.nums {
    32  		ids[i] = num.StreamID(stype, pers)
    33  	}
    34  	return fmt.Errorf(strError.Error(), ids...)
    35  }
    36  
    37  type streamOpenErr struct{ error }
    38  
    39  var _ net.Error = &streamOpenErr{}
    40  
    41  func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams }
    42  func (streamOpenErr) Timeout() bool     { return false }
    43  
    44  // errTooManyOpenStreams is used internally by the outgoing streams maps.
    45  var errTooManyOpenStreams = errors.New("too many open streams")
    46  
    47  type streamsMap struct {
    48  	perspective protocol.Perspective
    49  	version     protocol.VersionNumber
    50  
    51  	maxIncomingBidiStreams uint64
    52  	maxIncomingUniStreams  uint64
    53  
    54  	sender            streamSender
    55  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
    56  
    57  	mutex               sync.Mutex
    58  	outgoingBidiStreams *outgoingBidiStreamsMap
    59  	outgoingUniStreams  *outgoingUniStreamsMap
    60  	incomingBidiStreams *incomingBidiStreamsMap
    61  	incomingUniStreams  *incomingUniStreamsMap
    62  	reset               bool
    63  }
    64  
    65  var _ streamManager = &streamsMap{}
    66  
    67  func newStreamsMap(
    68  	sender streamSender,
    69  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
    70  	maxIncomingBidiStreams uint64,
    71  	maxIncomingUniStreams uint64,
    72  	perspective protocol.Perspective,
    73  	version protocol.VersionNumber,
    74  ) streamManager {
    75  	m := &streamsMap{
    76  		perspective:            perspective,
    77  		newFlowController:      newFlowController,
    78  		maxIncomingBidiStreams: maxIncomingBidiStreams,
    79  		maxIncomingUniStreams:  maxIncomingUniStreams,
    80  		sender:                 sender,
    81  		version:                version,
    82  	}
    83  	m.initMaps()
    84  	return m
    85  }
    86  
    87  func (m *streamsMap) initMaps() {
    88  	m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
    89  		func(num protocol.StreamNum) streamI {
    90  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective)
    91  			return newStream(id, m.sender, m.newFlowController(id), m.version)
    92  		},
    93  		m.sender.queueControlFrame,
    94  	)
    95  	m.incomingBidiStreams = newIncomingBidiStreamsMap(
    96  		func(num protocol.StreamNum) streamI {
    97  			id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite())
    98  			return newStream(id, m.sender, m.newFlowController(id), m.version)
    99  		},
   100  		m.maxIncomingBidiStreams,
   101  		m.sender.queueControlFrame,
   102  	)
   103  	m.outgoingUniStreams = newOutgoingUniStreamsMap(
   104  		func(num protocol.StreamNum) sendStreamI {
   105  			id := num.StreamID(protocol.StreamTypeUni, m.perspective)
   106  			return newSendStream(id, m.sender, m.newFlowController(id), m.version)
   107  		},
   108  		m.sender.queueControlFrame,
   109  	)
   110  	m.incomingUniStreams = newIncomingUniStreamsMap(
   111  		func(num protocol.StreamNum) receiveStreamI {
   112  			id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite())
   113  			return newReceiveStream(id, m.sender, m.newFlowController(id), m.version)
   114  		},
   115  		m.maxIncomingUniStreams,
   116  		m.sender.queueControlFrame,
   117  	)
   118  }
   119  
   120  func (m *streamsMap) OpenStream() (Stream, error) {
   121  	m.mutex.Lock()
   122  	reset := m.reset
   123  	mm := m.outgoingBidiStreams
   124  	m.mutex.Unlock()
   125  	if reset {
   126  		return nil, Err0RTTRejected
   127  	}
   128  	str, err := mm.OpenStream()
   129  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   130  }
   131  
   132  func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) {
   133  	m.mutex.Lock()
   134  	reset := m.reset
   135  	mm := m.outgoingBidiStreams
   136  	m.mutex.Unlock()
   137  	if reset {
   138  		return nil, Err0RTTRejected
   139  	}
   140  	str, err := mm.OpenStreamSync(ctx)
   141  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   142  }
   143  
   144  func (m *streamsMap) OpenUniStream() (SendStream, error) {
   145  	m.mutex.Lock()
   146  	reset := m.reset
   147  	mm := m.outgoingUniStreams
   148  	m.mutex.Unlock()
   149  	if reset {
   150  		return nil, Err0RTTRejected
   151  	}
   152  	str, err := mm.OpenStream()
   153  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective)
   154  }
   155  
   156  func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) {
   157  	m.mutex.Lock()
   158  	reset := m.reset
   159  	mm := m.outgoingUniStreams
   160  	m.mutex.Unlock()
   161  	if reset {
   162  		return nil, Err0RTTRejected
   163  	}
   164  	str, err := mm.OpenStreamSync(ctx)
   165  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   166  }
   167  
   168  func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) {
   169  	m.mutex.Lock()
   170  	reset := m.reset
   171  	mm := m.incomingBidiStreams
   172  	m.mutex.Unlock()
   173  	if reset {
   174  		return nil, Err0RTTRejected
   175  	}
   176  	str, err := mm.AcceptStream(ctx)
   177  	return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite())
   178  }
   179  
   180  func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) {
   181  	m.mutex.Lock()
   182  	reset := m.reset
   183  	mm := m.incomingUniStreams
   184  	m.mutex.Unlock()
   185  	if reset {
   186  		return nil, Err0RTTRejected
   187  	}
   188  	str, err := mm.AcceptStream(ctx)
   189  	return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite())
   190  }
   191  
   192  func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
   193  	num := id.StreamNum()
   194  	switch id.Type() {
   195  	case protocol.StreamTypeUni:
   196  		if id.InitiatedBy() == m.perspective {
   197  			return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective)
   198  		}
   199  		return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite())
   200  	case protocol.StreamTypeBidi:
   201  		if id.InitiatedBy() == m.perspective {
   202  			return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective)
   203  		}
   204  		return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite())
   205  	}
   206  	panic("")
   207  }
   208  
   209  func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   210  	str, err := m.getOrOpenReceiveStream(id)
   211  	if err != nil {
   212  		return nil, &qerr.TransportError{
   213  			ErrorCode:    qerr.StreamStateError,
   214  			ErrorMessage: err.Error(),
   215  		}
   216  	}
   217  	return str, nil
   218  }
   219  
   220  func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   221  	num := id.StreamNum()
   222  	switch id.Type() {
   223  	case protocol.StreamTypeUni:
   224  		if id.InitiatedBy() == m.perspective {
   225  			// an outgoing unidirectional stream is a send stream, not a receive stream
   226  			return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
   227  		}
   228  		str, err := m.incomingUniStreams.GetOrOpenStream(num)
   229  		return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   230  	case protocol.StreamTypeBidi:
   231  		var str receiveStreamI
   232  		var err error
   233  		if id.InitiatedBy() == m.perspective {
   234  			str, err = m.outgoingBidiStreams.GetStream(num)
   235  		} else {
   236  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   237  		}
   238  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   239  	}
   240  	panic("")
   241  }
   242  
   243  func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   244  	str, err := m.getOrOpenSendStream(id)
   245  	if err != nil {
   246  		return nil, &qerr.TransportError{
   247  			ErrorCode:    qerr.StreamStateError,
   248  			ErrorMessage: err.Error(),
   249  		}
   250  	}
   251  	return str, nil
   252  }
   253  
   254  func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   255  	num := id.StreamNum()
   256  	switch id.Type() {
   257  	case protocol.StreamTypeUni:
   258  		if id.InitiatedBy() == m.perspective {
   259  			str, err := m.outgoingUniStreams.GetStream(num)
   260  			return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective)
   261  		}
   262  		// an incoming unidirectional stream is a receive stream, not a send stream
   263  		return nil, fmt.Errorf("peer attempted to open send stream %d", id)
   264  	case protocol.StreamTypeBidi:
   265  		var str sendStreamI
   266  		var err error
   267  		if id.InitiatedBy() == m.perspective {
   268  			str, err = m.outgoingBidiStreams.GetStream(num)
   269  		} else {
   270  			str, err = m.incomingBidiStreams.GetOrOpenStream(num)
   271  		}
   272  		return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy())
   273  	}
   274  	panic("")
   275  }
   276  
   277  func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) {
   278  	switch f.Type {
   279  	case protocol.StreamTypeUni:
   280  		m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum)
   281  	case protocol.StreamTypeBidi:
   282  		m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum)
   283  	}
   284  }
   285  
   286  func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) {
   287  	m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote)
   288  	m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum)
   289  	m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni)
   290  	m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum)
   291  }
   292  
   293  func (m *streamsMap) CloseWithError(err error) {
   294  	m.outgoingBidiStreams.CloseWithError(err)
   295  	m.outgoingUniStreams.CloseWithError(err)
   296  	m.incomingBidiStreams.CloseWithError(err)
   297  	m.incomingUniStreams.CloseWithError(err)
   298  }
   299  
   300  // ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are
   301  // 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error.
   302  // 2. reset to their initial state, such that we can immediately process new incoming stream data.
   303  // Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error,
   304  // until UseResetMaps() has been called.
   305  func (m *streamsMap) ResetFor0RTT() {
   306  	m.mutex.Lock()
   307  	defer m.mutex.Unlock()
   308  	m.reset = true
   309  	m.CloseWithError(Err0RTTRejected)
   310  	m.initMaps()
   311  }
   312  
   313  func (m *streamsMap) UseResetMaps() {
   314  	m.mutex.Lock()
   315  	m.reset = false
   316  	m.mutex.Unlock()
   317  }