github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/streams_map_legacy.go (about)

     1  package gquic
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sync"
     7  
     8  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/handshake"
     9  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    10  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
    11  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
    12  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/qerr"
    13  )
    14  
    15  type streamsMapLegacy struct {
    16  	mutex sync.RWMutex
    17  
    18  	perspective protocol.Perspective
    19  
    20  	streams map[protocol.StreamID]streamI
    21  
    22  	nextStreamToOpen          protocol.StreamID // StreamID of the next Stream that will be returned by OpenStream()
    23  	highestStreamOpenedByPeer protocol.StreamID
    24  	nextStreamOrErrCond       sync.Cond
    25  	openStreamOrErrCond       sync.Cond
    26  
    27  	closeErr           error
    28  	nextStreamToAccept protocol.StreamID
    29  
    30  	newStream func(protocol.StreamID) streamI
    31  
    32  	numOutgoingStreams uint32
    33  	numIncomingStreams uint32
    34  	maxIncomingStreams uint32
    35  	maxOutgoingStreams uint32
    36  }
    37  
    38  var _ streamManager = &streamsMapLegacy{}
    39  
    40  var errMapAccess = errors.New("streamsMap: Error accessing the streams map")
    41  
    42  func newStreamsMapLegacy(newStream func(protocol.StreamID) streamI, maxStreams int, pers protocol.Perspective) streamManager {
    43  	// add some tolerance to the maximum incoming streams value
    44  	maxIncomingStreams := utils.MaxUint32(
    45  		uint32(maxStreams)+protocol.MaxStreamsMinimumIncrement,
    46  		uint32(float64(maxStreams)*float64(protocol.MaxStreamsMultiplier)),
    47  	)
    48  	sm := streamsMapLegacy{
    49  		perspective:        pers,
    50  		streams:            make(map[protocol.StreamID]streamI),
    51  		newStream:          newStream,
    52  		maxIncomingStreams: maxIncomingStreams,
    53  	}
    54  	sm.nextStreamOrErrCond.L = &sm.mutex
    55  	sm.openStreamOrErrCond.L = &sm.mutex
    56  
    57  	nextServerInitiatedStream := protocol.StreamID(2)
    58  	nextClientInitiatedStream := protocol.StreamID(3)
    59  	if pers == protocol.PerspectiveServer {
    60  		sm.highestStreamOpenedByPeer = 1
    61  	}
    62  	if pers == protocol.PerspectiveServer {
    63  		sm.nextStreamToOpen = nextServerInitiatedStream
    64  		sm.nextStreamToAccept = nextClientInitiatedStream
    65  	} else {
    66  		sm.nextStreamToOpen = nextClientInitiatedStream
    67  		sm.nextStreamToAccept = nextServerInitiatedStream
    68  	}
    69  	return &sm
    70  }
    71  
    72  // getStreamPerspective says which side should initiate a stream
    73  func (m *streamsMapLegacy) streamInitiatedBy(id protocol.StreamID) protocol.Perspective {
    74  	if id%2 == 0 {
    75  		return protocol.PerspectiveServer
    76  	}
    77  	return protocol.PerspectiveClient
    78  }
    79  
    80  func (m *streamsMapLegacy) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
    81  	// every bidirectional stream is also a receive stream
    82  	return m.getOrOpenStream(id)
    83  }
    84  
    85  func (m *streamsMapLegacy) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
    86  	// every bidirectional stream is also a send stream
    87  	return m.getOrOpenStream(id)
    88  }
    89  
    90  // getOrOpenStream either returns an existing stream, a newly opened stream, or nil if a stream with the provided ID is already closed.
    91  // Newly opened streams should only originate from the client. To open a stream from the server, OpenStream should be used.
    92  func (m *streamsMapLegacy) getOrOpenStream(id protocol.StreamID) (streamI, error) {
    93  	m.mutex.RLock()
    94  	s, ok := m.streams[id]
    95  	m.mutex.RUnlock()
    96  	if ok {
    97  		return s, nil
    98  	}
    99  
   100  	// ... we don't have an existing stream
   101  	m.mutex.Lock()
   102  	defer m.mutex.Unlock()
   103  	// We need to check whether another invocation has already created a stream (between RUnlock() and Lock()).
   104  	s, ok = m.streams[id]
   105  	if ok {
   106  		return s, nil
   107  	}
   108  
   109  	if m.perspective == m.streamInitiatedBy(id) {
   110  		if id <= m.nextStreamToOpen { // this is a stream opened by us. Must have been closed already
   111  			return nil, nil
   112  		}
   113  		return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))
   114  	}
   115  	if id <= m.highestStreamOpenedByPeer { // this is a peer-initiated stream that doesn't exist anymore. Must have been closed already
   116  		return nil, nil
   117  	}
   118  
   119  	for sid := m.highestStreamOpenedByPeer + 2; sid <= id; sid += 2 {
   120  		if _, err := m.openRemoteStream(sid); err != nil {
   121  			return nil, err
   122  		}
   123  	}
   124  
   125  	m.nextStreamOrErrCond.Broadcast()
   126  	return m.streams[id], nil
   127  }
   128  
   129  func (m *streamsMapLegacy) openRemoteStream(id protocol.StreamID) (streamI, error) {
   130  	if m.numIncomingStreams >= m.maxIncomingStreams {
   131  		return nil, qerr.TooManyOpenStreams
   132  	}
   133  	// maxNewStreamIDDelta is the maximum difference between and a newly opened Stream and the highest StreamID that a client has ever opened
   134  	// note that the number of streams is half this value, since the client can only open streams with open StreamID
   135  	maxStreamIDDelta := protocol.StreamID(4 * m.maxIncomingStreams)
   136  	if id+maxStreamIDDelta < m.highestStreamOpenedByPeer {
   137  		return nil, qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("attempted to open stream %d, which is a lot smaller than the highest opened stream, %d", id, m.highestStreamOpenedByPeer))
   138  	}
   139  
   140  	m.numIncomingStreams++
   141  	if id > m.highestStreamOpenedByPeer {
   142  		m.highestStreamOpenedByPeer = id
   143  	}
   144  
   145  	s := m.newStream(id)
   146  	return s, m.putStream(s)
   147  }
   148  
   149  func (m *streamsMapLegacy) openStreamImpl() (streamI, error) {
   150  	if m.numOutgoingStreams >= m.maxOutgoingStreams {
   151  		return nil, qerr.TooManyOpenStreams
   152  	}
   153  
   154  	m.numOutgoingStreams++
   155  	s := m.newStream(m.nextStreamToOpen)
   156  	m.nextStreamToOpen += 2
   157  	return s, m.putStream(s)
   158  }
   159  
   160  // OpenStream opens the next available stream
   161  func (m *streamsMapLegacy) OpenStream() (Stream, error) {
   162  	m.mutex.Lock()
   163  	defer m.mutex.Unlock()
   164  
   165  	if m.closeErr != nil {
   166  		return nil, m.closeErr
   167  	}
   168  	return m.openStreamImpl()
   169  }
   170  
   171  func (m *streamsMapLegacy) OpenStreamSync() (Stream, error) {
   172  	m.mutex.Lock()
   173  	defer m.mutex.Unlock()
   174  
   175  	for {
   176  		if m.closeErr != nil {
   177  			return nil, m.closeErr
   178  		}
   179  		str, err := m.openStreamImpl()
   180  		if err == nil {
   181  			return str, err
   182  		}
   183  		if err != nil && err != qerr.TooManyOpenStreams {
   184  			return nil, err
   185  		}
   186  		m.openStreamOrErrCond.Wait()
   187  	}
   188  }
   189  
   190  func (m *streamsMapLegacy) OpenUniStream() (SendStream, error) {
   191  	return nil, errors.New("gQUIC doesn't support unidirectional streams")
   192  }
   193  
   194  func (m *streamsMapLegacy) OpenUniStreamSync() (SendStream, error) {
   195  	return nil, errors.New("gQUIC doesn't support unidirectional streams")
   196  }
   197  
   198  // AcceptStream returns the next stream opened by the peer
   199  // it blocks until a new stream is opened
   200  func (m *streamsMapLegacy) AcceptStream() (Stream, error) {
   201  	m.mutex.Lock()
   202  	defer m.mutex.Unlock()
   203  	var str streamI
   204  	for {
   205  		var ok bool
   206  		if m.closeErr != nil {
   207  			return nil, m.closeErr
   208  		}
   209  		str, ok = m.streams[m.nextStreamToAccept]
   210  		if ok {
   211  			break
   212  		}
   213  		m.nextStreamOrErrCond.Wait()
   214  	}
   215  	m.nextStreamToAccept += 2
   216  	return str, nil
   217  }
   218  
   219  func (m *streamsMapLegacy) AcceptUniStream() (ReceiveStream, error) {
   220  	return nil, errors.New("gQUIC doesn't support unidirectional streams")
   221  }
   222  
   223  func (m *streamsMapLegacy) DeleteStream(id protocol.StreamID) error {
   224  	m.mutex.Lock()
   225  	defer m.mutex.Unlock()
   226  	_, ok := m.streams[id]
   227  	if !ok {
   228  		return errMapAccess
   229  	}
   230  	delete(m.streams, id)
   231  	if m.streamInitiatedBy(id) == m.perspective {
   232  		m.numOutgoingStreams--
   233  	} else {
   234  		m.numIncomingStreams--
   235  	}
   236  	m.openStreamOrErrCond.Signal()
   237  	return nil
   238  }
   239  
   240  func (m *streamsMapLegacy) putStream(s streamI) error {
   241  	id := s.StreamID()
   242  	if _, ok := m.streams[id]; ok {
   243  		return fmt.Errorf("a stream with ID %d already exists", id)
   244  	}
   245  	m.streams[id] = s
   246  	return nil
   247  }
   248  
   249  func (m *streamsMapLegacy) CloseWithError(err error) {
   250  	m.mutex.Lock()
   251  	defer m.mutex.Unlock()
   252  	m.closeErr = err
   253  	m.nextStreamOrErrCond.Broadcast()
   254  	m.openStreamOrErrCond.Broadcast()
   255  	for _, s := range m.streams {
   256  		s.closeForShutdown(err)
   257  	}
   258  }
   259  
   260  // TODO(#952): this won't be needed when gQUIC supports stateless handshakes
   261  func (m *streamsMapLegacy) UpdateLimits(params *handshake.TransportParameters) {
   262  	m.mutex.Lock()
   263  	m.maxOutgoingStreams = params.MaxStreams
   264  	for id, str := range m.streams {
   265  		str.handleMaxStreamDataFrame(&wire.MaxStreamDataFrame{
   266  			StreamID:   id,
   267  			ByteOffset: params.StreamFlowControlWindow,
   268  		})
   269  	}
   270  	m.mutex.Unlock()
   271  	m.openStreamOrErrCond.Broadcast()
   272  }
   273  
   274  // should never be called, since MAX_STREAM_ID frames can only be unpacked for IETF QUIC
   275  func (m *streamsMapLegacy) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
   276  	return errors.New("gQUIC doesn't have MAX_STREAM_ID frames")
   277  }