github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/streams_map_outgoing.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/daeuniverse/quic-go/internal/protocol"
     8  	"github.com/daeuniverse/quic-go/internal/wire"
     9  )
    10  
    11  type outgoingStream interface {
    12  	updateSendWindow(protocol.ByteCount)
    13  	closeForShutdown(error)
    14  }
    15  
    16  type outgoingStreamsMap[T outgoingStream] struct {
    17  	mutex sync.RWMutex
    18  
    19  	streamType protocol.StreamType
    20  	streams    map[protocol.StreamNum]T
    21  
    22  	openQueue      map[uint64]chan struct{}
    23  	lowestInQueue  uint64
    24  	highestInQueue uint64
    25  
    26  	nextStream         protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync)
    27  	maxStream          protocol.StreamNum // the maximum stream ID we're allowed to open
    28  	blockedSent        bool               // was a STREAMS_BLOCKED sent for the current maxStream
    29  	capabilityCallback func(n int64)
    30  
    31  	newStream            func(protocol.StreamNum) T
    32  	queueStreamIDBlocked func(*wire.StreamsBlockedFrame)
    33  
    34  	closeErr error
    35  }
    36  
    37  func newOutgoingStreamsMap[T outgoingStream](
    38  	streamType protocol.StreamType,
    39  	newStream func(protocol.StreamNum) T,
    40  	queueControlFrame func(wire.Frame),
    41  	capabilityCallback func(n int64),
    42  ) *outgoingStreamsMap[T] {
    43  	if capabilityCallback == nil {
    44  		capabilityCallback = func(n int64) {}
    45  	}
    46  	return &outgoingStreamsMap[T]{
    47  		streamType:           streamType,
    48  		streams:              make(map[protocol.StreamNum]T),
    49  		openQueue:            make(map[uint64]chan struct{}),
    50  		maxStream:            protocol.InvalidStreamNum,
    51  		nextStream:           1,
    52  		newStream:            newStream,
    53  		queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) },
    54  		capabilityCallback:   capabilityCallback,
    55  	}
    56  }
    57  
    58  func (m *outgoingStreamsMap[T]) OpenStream() (T, error) {
    59  	m.mutex.Lock()
    60  	defer m.mutex.Unlock()
    61  
    62  	if m.closeErr != nil {
    63  		return *new(T), m.closeErr
    64  	}
    65  
    66  	// if there are OpenStreamSync calls waiting, return an error here
    67  	if len(m.openQueue) > 0 || m.nextStream > m.maxStream {
    68  		m.maybeSendBlockedFrame()
    69  		return *new(T), streamOpenErr{errTooManyOpenStreams}
    70  	}
    71  	return m.openStream(), nil
    72  }
    73  
    74  func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) {
    75  	m.mutex.Lock()
    76  	defer m.mutex.Unlock()
    77  
    78  	if m.closeErr != nil {
    79  		return *new(T), m.closeErr
    80  	}
    81  
    82  	if err := ctx.Err(); err != nil {
    83  		return *new(T), err
    84  	}
    85  
    86  	if len(m.openQueue) == 0 && m.nextStream <= m.maxStream {
    87  		return m.openStream(), nil
    88  	}
    89  
    90  	waitChan := make(chan struct{}, 1)
    91  	queuePos := m.highestInQueue
    92  	m.highestInQueue++
    93  	if len(m.openQueue) == 0 {
    94  		m.lowestInQueue = queuePos
    95  	}
    96  	m.openQueue[queuePos] = waitChan
    97  	m.maybeSendBlockedFrame()
    98  
    99  	for {
   100  		m.mutex.Unlock()
   101  		select {
   102  		case <-ctx.Done():
   103  			m.mutex.Lock()
   104  			delete(m.openQueue, queuePos)
   105  			return *new(T), ctx.Err()
   106  		case <-waitChan:
   107  		}
   108  		m.mutex.Lock()
   109  
   110  		if m.closeErr != nil {
   111  			return *new(T), m.closeErr
   112  		}
   113  		if m.nextStream > m.maxStream {
   114  			// no stream available. Continue waiting
   115  			continue
   116  		}
   117  		str := m.openStream()
   118  		delete(m.openQueue, queuePos)
   119  		m.lowestInQueue = queuePos + 1
   120  		m.unblockOpenSync()
   121  		return str, nil
   122  	}
   123  }
   124  
   125  func (m *outgoingStreamsMap[T]) openStream() T {
   126  	s := m.newStream(m.nextStream)
   127  	m.streams[m.nextStream] = s
   128  	m.nextStream++
   129  	m.capabilityCallback(int64(m.maxStream - m.nextStream))
   130  	return s
   131  }
   132  
   133  // maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset,
   134  // if we haven't sent one for this offset yet
   135  func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() {
   136  	if m.blockedSent {
   137  		return
   138  	}
   139  
   140  	var streamNum protocol.StreamNum
   141  	if m.maxStream != protocol.InvalidStreamNum {
   142  		streamNum = m.maxStream
   143  	}
   144  	m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{
   145  		Type:        m.streamType,
   146  		StreamLimit: streamNum,
   147  	})
   148  	m.blockedSent = true
   149  }
   150  
   151  func (m *outgoingStreamsMap[T]) GetStream(num protocol.StreamNum) (T, error) {
   152  	m.mutex.RLock()
   153  	if num >= m.nextStream {
   154  		m.mutex.RUnlock()
   155  		return *new(T), streamError{
   156  			message: "peer attempted to open stream %d",
   157  			nums:    []protocol.StreamNum{num},
   158  		}
   159  	}
   160  	s := m.streams[num]
   161  	m.mutex.RUnlock()
   162  	return s, nil
   163  }
   164  
   165  func (m *outgoingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error {
   166  	m.mutex.Lock()
   167  	defer m.mutex.Unlock()
   168  
   169  	if _, ok := m.streams[num]; !ok {
   170  		return streamError{
   171  			message: "tried to delete unknown outgoing stream %d",
   172  			nums:    []protocol.StreamNum{num},
   173  		}
   174  	}
   175  	delete(m.streams, num)
   176  	return nil
   177  }
   178  
   179  func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) {
   180  	m.mutex.Lock()
   181  	defer m.mutex.Unlock()
   182  
   183  	if num <= m.maxStream {
   184  		return
   185  	}
   186  	m.maxStream = num
   187  	m.capabilityCallback(int64(m.maxStream - m.nextStream))
   188  	m.blockedSent = false
   189  	if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) {
   190  		m.maybeSendBlockedFrame()
   191  	}
   192  	m.unblockOpenSync()
   193  }
   194  
   195  // UpdateSendWindow is called when the peer's transport parameters are received.
   196  // Only in the case of a 0-RTT handshake will we have open streams at this point.
   197  // We might need to update the send window, in case the server increased it.
   198  func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) {
   199  	m.mutex.Lock()
   200  	for _, str := range m.streams {
   201  		str.updateSendWindow(limit)
   202  	}
   203  	m.mutex.Unlock()
   204  }
   205  
   206  // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream
   207  func (m *outgoingStreamsMap[T]) unblockOpenSync() {
   208  	if len(m.openQueue) == 0 {
   209  		return
   210  	}
   211  	for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ {
   212  		c, ok := m.openQueue[qp]
   213  		if !ok { // entry was deleted because the context was canceled
   214  			continue
   215  		}
   216  		// unblockOpenSync is called both from OpenStreamSync and from SetMaxStream.
   217  		// It's sufficient to only unblock OpenStreamSync once.
   218  		select {
   219  		case c <- struct{}{}:
   220  		default:
   221  		}
   222  		return
   223  	}
   224  }
   225  
   226  func (m *outgoingStreamsMap[T]) CloseWithError(err error) {
   227  	m.mutex.Lock()
   228  	m.closeErr = err
   229  	for _, str := range m.streams {
   230  		str.closeForShutdown(err)
   231  	}
   232  	for _, c := range m.openQueue {
   233  		if c != nil {
   234  			close(c)
   235  		}
   236  	}
   237  	m.mutex.Unlock()
   238  }