github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/streams_map.go (about)

     1  package gquic
     2  
     3  import (
     4  	"fmt"
     5  
     6  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/flowcontrol"
     7  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/handshake"
     8  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
     9  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/wire"
    10  )
    11  
    12  type streamType int
    13  
    14  const (
    15  	streamTypeOutgoingBidi streamType = iota
    16  	streamTypeIncomingBidi
    17  	streamTypeOutgoingUni
    18  	streamTypeIncomingUni
    19  )
    20  
    21  type streamsMap struct {
    22  	perspective protocol.Perspective
    23  
    24  	sender            streamSender
    25  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
    26  
    27  	outgoingBidiStreams *outgoingBidiStreamsMap
    28  	outgoingUniStreams  *outgoingUniStreamsMap
    29  	incomingBidiStreams *incomingBidiStreamsMap
    30  	incomingUniStreams  *incomingUniStreamsMap
    31  }
    32  
    33  var _ streamManager = &streamsMap{}
    34  
    35  func newStreamsMap(
    36  	sender streamSender,
    37  	newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
    38  	maxIncomingStreams int,
    39  	maxIncomingUniStreams int,
    40  	perspective protocol.Perspective,
    41  	version protocol.VersionNumber,
    42  ) streamManager {
    43  	m := &streamsMap{
    44  		perspective:       perspective,
    45  		newFlowController: newFlowController,
    46  		sender:            sender,
    47  	}
    48  	var firstOutgoingBidiStream, firstOutgoingUniStream, firstIncomingBidiStream, firstIncomingUniStream protocol.StreamID
    49  	if perspective == protocol.PerspectiveServer {
    50  		firstOutgoingBidiStream = 1
    51  		firstIncomingBidiStream = 4 // the crypto stream is handled separately
    52  		firstOutgoingUniStream = 3
    53  		firstIncomingUniStream = 2
    54  	} else {
    55  		firstOutgoingBidiStream = 4 // the crypto stream is handled separately
    56  		firstIncomingBidiStream = 1
    57  		firstOutgoingUniStream = 2
    58  		firstIncomingUniStream = 3
    59  	}
    60  	newBidiStream := func(id protocol.StreamID) streamI {
    61  		return newStream(id, m.sender, m.newFlowController(id), version)
    62  	}
    63  	newUniSendStream := func(id protocol.StreamID) sendStreamI {
    64  		return newSendStream(id, m.sender, m.newFlowController(id), version)
    65  	}
    66  	newUniReceiveStream := func(id protocol.StreamID) receiveStreamI {
    67  		return newReceiveStream(id, m.sender, m.newFlowController(id), version)
    68  	}
    69  	m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
    70  		firstOutgoingBidiStream,
    71  		newBidiStream,
    72  		sender.queueControlFrame,
    73  	)
    74  	m.incomingBidiStreams = newIncomingBidiStreamsMap(
    75  		firstIncomingBidiStream,
    76  		protocol.MaxBidiStreamID(maxIncomingStreams, perspective),
    77  		maxIncomingStreams,
    78  		sender.queueControlFrame,
    79  		newBidiStream,
    80  	)
    81  	m.outgoingUniStreams = newOutgoingUniStreamsMap(
    82  		firstOutgoingUniStream,
    83  		newUniSendStream,
    84  		sender.queueControlFrame,
    85  	)
    86  	m.incomingUniStreams = newIncomingUniStreamsMap(
    87  		firstIncomingUniStream,
    88  		protocol.MaxUniStreamID(maxIncomingUniStreams, perspective),
    89  		maxIncomingUniStreams,
    90  		sender.queueControlFrame,
    91  		newUniReceiveStream,
    92  	)
    93  	return m
    94  }
    95  
    96  func (m *streamsMap) getStreamType(id protocol.StreamID) streamType {
    97  	if m.perspective == protocol.PerspectiveServer {
    98  		switch id % 4 {
    99  		case 0:
   100  			return streamTypeIncomingBidi
   101  		case 1:
   102  			return streamTypeOutgoingBidi
   103  		case 2:
   104  			return streamTypeIncomingUni
   105  		case 3:
   106  			return streamTypeOutgoingUni
   107  		}
   108  	} else {
   109  		switch id % 4 {
   110  		case 0:
   111  			return streamTypeOutgoingBidi
   112  		case 1:
   113  			return streamTypeIncomingBidi
   114  		case 2:
   115  			return streamTypeOutgoingUni
   116  		case 3:
   117  			return streamTypeIncomingUni
   118  		}
   119  	}
   120  	panic("")
   121  }
   122  
   123  func (m *streamsMap) OpenStream() (Stream, error) {
   124  	return m.outgoingBidiStreams.OpenStream()
   125  }
   126  
   127  func (m *streamsMap) OpenStreamSync() (Stream, error) {
   128  	return m.outgoingBidiStreams.OpenStreamSync()
   129  }
   130  
   131  func (m *streamsMap) OpenUniStream() (SendStream, error) {
   132  	return m.outgoingUniStreams.OpenStream()
   133  }
   134  
   135  func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
   136  	return m.outgoingUniStreams.OpenStreamSync()
   137  }
   138  
   139  func (m *streamsMap) AcceptStream() (Stream, error) {
   140  	return m.incomingBidiStreams.AcceptStream()
   141  }
   142  
   143  func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) {
   144  	return m.incomingUniStreams.AcceptStream()
   145  }
   146  
   147  func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
   148  	switch m.getStreamType(id) {
   149  	case streamTypeIncomingBidi:
   150  		return m.incomingBidiStreams.DeleteStream(id)
   151  	case streamTypeOutgoingBidi:
   152  		return m.outgoingBidiStreams.DeleteStream(id)
   153  	case streamTypeIncomingUni:
   154  		return m.incomingUniStreams.DeleteStream(id)
   155  	case streamTypeOutgoingUni:
   156  		return m.outgoingUniStreams.DeleteStream(id)
   157  	default:
   158  		panic("invalid stream type")
   159  	}
   160  }
   161  
   162  func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
   163  	switch m.getStreamType(id) {
   164  	case streamTypeOutgoingBidi:
   165  		return m.outgoingBidiStreams.GetStream(id)
   166  	case streamTypeIncomingBidi:
   167  		return m.incomingBidiStreams.GetOrOpenStream(id)
   168  	case streamTypeIncomingUni:
   169  		return m.incomingUniStreams.GetOrOpenStream(id)
   170  	case streamTypeOutgoingUni:
   171  		// an outgoing unidirectional stream is a send stream, not a receive stream
   172  		return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
   173  	default:
   174  		panic("invalid stream type")
   175  	}
   176  }
   177  
   178  func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
   179  	switch m.getStreamType(id) {
   180  	case streamTypeOutgoingBidi:
   181  		return m.outgoingBidiStreams.GetStream(id)
   182  	case streamTypeIncomingBidi:
   183  		return m.incomingBidiStreams.GetOrOpenStream(id)
   184  	case streamTypeOutgoingUni:
   185  		return m.outgoingUniStreams.GetStream(id)
   186  	case streamTypeIncomingUni:
   187  		// an incoming unidirectional stream is a receive stream, not a send stream
   188  		return nil, fmt.Errorf("peer attempted to open send stream %d", id)
   189  	default:
   190  		panic("invalid stream type")
   191  	}
   192  }
   193  
   194  func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
   195  	id := f.StreamID
   196  	switch m.getStreamType(id) {
   197  	case streamTypeOutgoingBidi:
   198  		m.outgoingBidiStreams.SetMaxStream(id)
   199  		return nil
   200  	case streamTypeOutgoingUni:
   201  		m.outgoingUniStreams.SetMaxStream(id)
   202  		return nil
   203  	default:
   204  		return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id)
   205  	}
   206  }
   207  
   208  func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) {
   209  	// Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open.
   210  	// Invert the perspective to determine the value that we are allowed to open.
   211  	peerPers := protocol.PerspectiveServer
   212  	if m.perspective == protocol.PerspectiveServer {
   213  		peerPers = protocol.PerspectiveClient
   214  	}
   215  	m.outgoingBidiStreams.SetMaxStream(protocol.MaxBidiStreamID(int(p.MaxBidiStreams), peerPers))
   216  	m.outgoingUniStreams.SetMaxStream(protocol.MaxUniStreamID(int(p.MaxUniStreams), peerPers))
   217  }
   218  
   219  func (m *streamsMap) CloseWithError(err error) {
   220  	m.outgoingBidiStreams.CloseWithError(err)
   221  	m.outgoingUniStreams.CloseWithError(err)
   222  	m.incomingBidiStreams.CloseWithError(err)
   223  	m.incomingUniStreams.CloseWithError(err)
   224  }