github.com/sagernet/quic-go@v0.43.1-beta.1/ech/streams_map_outgoing.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/sagernet/quic-go/internal/protocol"
     8  	"github.com/sagernet/quic-go/internal/wire"
     9  )
    10  
    11  type outgoingStream interface {
    12  	updateSendWindow(protocol.ByteCount)
    13  	closeForShutdown(error)
    14  }
    15  
    16  type outgoingStreamsMap[T outgoingStream] struct {
    17  	mutex sync.RWMutex
    18  
    19  	streamType protocol.StreamType
    20  	streams    map[protocol.StreamNum]T
    21  
    22  	openQueue      map[uint64]chan struct{}
    23  	lowestInQueue  uint64
    24  	highestInQueue uint64
    25  
    26  	nextStream  protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
    27  	maxStream   protocol.StreamNum // the maximum stream ID we're allowed to open
    28  	blockedSent bool               // was a STREAMS_BLOCKED sent for the current maxStream
    29  
    30  	newStream            func(protocol.StreamNum) T
    31  	queueStreamIDBlocked func(*wire.StreamsBlockedFrame)
    32  
    33  	closeErr error
    34  }
    35  
    36  func newOutgoingStreamsMap[T outgoingStream](
    37  	streamType protocol.StreamType,
    38  	newStream func(protocol.StreamNum) T,
    39  	queueControlFrame func(wire.Frame),
    40  ) *outgoingStreamsMap[T] {
    41  	return &outgoingStreamsMap[T]{
    42  		streamType:           streamType,
    43  		streams:              make(map[protocol.StreamNum]T),
    44  		openQueue:            make(map[uint64]chan struct{}),
    45  		maxStream:            protocol.InvalidStreamNum,
    46  		nextStream:           1,
    47  		newStream:            newStream,
    48  		queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
    49  	}
    50  }
    51  
    52  func (m *outgoingStreamsMap[T]) OpenStream() (T, error) {
    53  	m.mutex.Lock()
    54  	defer m.mutex.Unlock()
    55  
    56  	if m.closeErr != nil {
    57  		return *new(T), m.closeErr
    58  	}
    59  
    60  	// if there are OpenStreamSync calls waiting, return an error here
    61  	if len(m.openQueue) > 0 || m.nextStream > m.maxStream {
    62  		m.maybeSendBlockedFrame()
    63  		return *new(T), streamOpenErr{errTooManyOpenStreams}
    64  	}
    65  	return m.openStream(), nil
    66  }
    67  
    68  func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
    69  	m.mutex.Lock()
    70  	defer m.mutex.Unlock()
    71  
    72  	if m.closeErr != nil {
    73  		return *new(T), m.closeErr
    74  	}
    75  
    76  	if err := ctx.Err(); err != nil {
    77  		return *new(T), err
    78  	}
    79  
    80  	if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
    81  		return m.openStream(), nil
    82  	}
    83  
    84  	waitChan := make(chan struct{}, 1)
    85  	queuePos := m.highestInQueue
    86  	m.highestInQueue++
    87  	if len(m.openQueue) == 0 {
    88  		m.lowestInQueue = queuePos
    89  	}
    90  	m.openQueue[queuePos] = waitChan
    91  	m.maybeSendBlockedFrame()
    92  
    93  	for {
    94  		m.mutex.Unlock()
    95  		select {
    96  		case <-ctx.Done():
    97  			m.mutex.Lock()
    98  			delete(m.openQueue, queuePos)
    99  			return *new(T), ctx.Err()
   100  		case <-waitChan:
   101  		}
   102  		m.mutex.Lock()
   103  
   104  		if m.closeErr != nil {
   105  			return *new(T), m.closeErr
   106  		}
   107  		if m.nextStream > m.maxStream {
   108  			// no stream available. Continue waiting
   109  			continue
   110  		}
   111  		str := m.openStream()
   112  		delete(m.openQueue, queuePos)
   113  		m.lowestInQueue = queuePos + 1
   114  		m.unblockOpenSync()
   115  		return str, nil
   116  	}
   117  }
   118  
   119  func (m *outgoingStreamsMap[T]) openStream() T {
   120  	s := m.newStream(m.nextStream)
   121  	m.streams[m.nextStream] = s
   122  	m.nextStream++
   123  	return s
   124  }
   125  
   126  // maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset,
   127  // if we haven't sent one for this offset yet
   128  func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() {
   129  	if m.blockedSent {
   130  		return
   131  	}
   132  
   133  	var streamNum protocol.StreamNum
   134  	if m.maxStream != protocol.InvalidStreamNum {
   135  		streamNum = m.maxStream
   136  	}
   137  	m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
   138  		Type:        m.streamType,
   139  		StreamLimit: streamNum,
   140  	})
   141  	m.blockedSent = true
   142  }
   143  
   144  func (m *outgoingStreamsMap[T]) GetStream(num protocol.StreamNum) (T, error) {
   145  	m.mutex.RLock()
   146  	if num >= m.nextStream {
   147  		m.mutex.RUnlock()
   148  		return *new(T), streamError{
   149  			message: "peer attempted to open stream %d",
   150  			nums:    []protocol.StreamNum{num},
   151  		}
   152  	}
   153  	s := m.streams[num]
   154  	m.mutex.RUnlock()
   155  	return s, nil
   156  }
   157  
   158  func (m *outgoingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error {
   159  	m.mutex.Lock()
   160  	defer m.mutex.Unlock()
   161  
   162  	if _, ok := m.streams[num]; !ok {
   163  		return streamError{
   164  			message: "tried to delete unknown outgoing stream %d",
   165  			nums:    []protocol.StreamNum{num},
   166  		}
   167  	}
   168  	delete(m.streams, num)
   169  	return nil
   170  }
   171  
   172  func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) {
   173  	m.mutex.Lock()
   174  	defer m.mutex.Unlock()
   175  
   176  	if num <= m.maxStream {
   177  		return
   178  	}
   179  	m.maxStream = num
   180  	m.blockedSent = false
   181  	if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) {
   182  		m.maybeSendBlockedFrame()
   183  	}
   184  	m.unblockOpenSync()
   185  }
   186  
   187  // UpdateSendWindow is called when the peer's transport parameters are received.
   188  // Only in the case of a 0-RTT handshake will we have open streams at this point.
   189  // We might need to update the send window, in case the server increased it.
   190  func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) {
   191  	m.mutex.Lock()
   192  	for _, str := range m.streams {
   193  		str.updateSendWindow(limit)
   194  	}
   195  	m.mutex.Unlock()
   196  }
   197  
   198  // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream
   199  func (m *outgoingStreamsMap[T]) unblockOpenSync() {
   200  	if len(m.openQueue) == 0 {
   201  		return
   202  	}
   203  	for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
   204  		c, ok := m.openQueue[qp]
   205  		if !ok { // entry was deleted because the context was canceled
   206  			continue
   207  		}
   208  		// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
   209  		// It's sufficient to only unblock OpenStreamSync once.
   210  		select {
   211  		case c <- struct{}{}:
   212  		default:
   213  		}
   214  		return
   215  	}
   216  }
   217  
   218  func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
   219  	m.mutex.Lock()
   220  	m.closeErr = err
   221  	for _, str := range m.streams {
   222  		str.closeForShutdown(err)
   223  	}
   224  	for _, c := range m.openQueue {
   225  		if c != nil {
   226  			close(c)
   227  		}
   228  	}
   229  	m.mutex.Unlock()
   230  }