github.com/daeuniverse/quic-go@v0.0.0-20240413031024-943f218e0810/streams_map_outgoing.go (about) 1 package quic 2 3 import ( 4 "context" 5 "sync" 6 7 "github.com/daeuniverse/quic-go/internal/protocol" 8 "github.com/daeuniverse/quic-go/internal/wire" 9 ) 10 11 type outgoingStream interface { 12 updateSendWindow(protocol.ByteCount) 13 closeForShutdown(error) 14 } 15 16 type outgoingStreamsMap[T outgoingStream] struct { 17 mutex sync.RWMutex 18 19 streamType protocol.StreamType 20 streams map[protocol.StreamNum]T 21 22 openQueue map[uint64]chan struct{} 23 lowestInQueue uint64 24 highestInQueue uint64 25 26 nextStream protocol.StreamNum // stream ID of the stream returned by OpenStream(Sync) 27 maxStream protocol.StreamNum // the maximum stream ID we're allowed to open 28 blockedSent bool // was a STREAMS_BLOCKED sent for the current maxStream 29 capabilityCallback func(n int64) 30 31 newStream func(protocol.StreamNum) T 32 queueStreamIDBlocked func(*wire.StreamsBlockedFrame) 33 34 closeErr error 35 } 36 37 func newOutgoingStreamsMap[T outgoingStream]( 38 streamType protocol.StreamType, 39 newStream func(protocol.StreamNum) T, 40 queueControlFrame func(wire.Frame), 41 capabilityCallback func(n int64), 42 ) *outgoingStreamsMap[T] { 43 if capabilityCallback == nil { 44 capabilityCallback = func(n int64) {} 45 } 46 return &outgoingStreamsMap[T]{ 47 streamType: streamType, 48 streams: make(map[protocol.StreamNum]T), 49 openQueue: make(map[uint64]chan struct{}), 50 maxStream: protocol.InvalidStreamNum, 51 nextStream: 1, 52 newStream: newStream, 53 queueStreamIDBlocked: func(f *wire.StreamsBlockedFrame) { queueControlFrame(f) }, 54 capabilityCallback: capabilityCallback, 55 } 56 } 57 58 func (m *outgoingStreamsMap[T]) OpenStream() (T, error) { 59 m.mutex.Lock() 60 defer m.mutex.Unlock() 61 62 if m.closeErr != nil { 63 return *new(T), m.closeErr 64 } 65 66 // if there are OpenStreamSync calls waiting, return an error here 67 if len(m.openQueue) > 0 || m.nextStream > m.maxStream { 68 m.maybeSendBlockedFrame() 69 return *new(T), streamOpenErr{errTooManyOpenStreams} 70 } 71 return m.openStream(), nil 72 } 73 74 func (m *outgoingStreamsMap[T]) OpenStreamSync(ctx context.Context) (T, error) { 75 m.mutex.Lock() 76 defer m.mutex.Unlock() 77 78 if m.closeErr != nil { 79 return *new(T), m.closeErr 80 } 81 82 if err := ctx.Err(); err != nil { 83 return *new(T), err 84 } 85 86 if len(m.openQueue) == 0 && m.nextStream <= m.maxStream { 87 return m.openStream(), nil 88 } 89 90 waitChan := make(chan struct{}, 1) 91 queuePos := m.highestInQueue 92 m.highestInQueue++ 93 if len(m.openQueue) == 0 { 94 m.lowestInQueue = queuePos 95 } 96 m.openQueue[queuePos] = waitChan 97 m.maybeSendBlockedFrame() 98 99 for { 100 m.mutex.Unlock() 101 select { 102 case <-ctx.Done(): 103 m.mutex.Lock() 104 delete(m.openQueue, queuePos) 105 return *new(T), ctx.Err() 106 case <-waitChan: 107 } 108 m.mutex.Lock() 109 110 if m.closeErr != nil { 111 return *new(T), m.closeErr 112 } 113 if m.nextStream > m.maxStream { 114 // no stream available. Continue waiting 115 continue 116 } 117 str := m.openStream() 118 delete(m.openQueue, queuePos) 119 m.lowestInQueue = queuePos + 1 120 m.unblockOpenSync() 121 return str, nil 122 } 123 } 124 125 func (m *outgoingStreamsMap[T]) openStream() T { 126 s := m.newStream(m.nextStream) 127 m.streams[m.nextStream] = s 128 m.nextStream++ 129 m.capabilityCallback(int64(m.maxStream - m.nextStream)) 130 return s 131 } 132 133 // maybeSendBlockedFrame queues a STREAMS_BLOCKED frame for the current stream offset, 134 // if we haven't sent one for this offset yet 135 func (m *outgoingStreamsMap[T]) maybeSendBlockedFrame() { 136 if m.blockedSent { 137 return 138 } 139 140 var streamNum protocol.StreamNum 141 if m.maxStream != protocol.InvalidStreamNum { 142 streamNum = m.maxStream 143 } 144 m.queueStreamIDBlocked(&wire.StreamsBlockedFrame{ 145 Type: m.streamType, 146 StreamLimit: streamNum, 147 }) 148 m.blockedSent = true 149 } 150 151 func (m *outgoingStreamsMap[T]) GetStream(num protocol.StreamNum) (T, error) { 152 m.mutex.RLock() 153 if num >= m.nextStream { 154 m.mutex.RUnlock() 155 return *new(T), streamError{ 156 message: "peer attempted to open stream %d", 157 nums: []protocol.StreamNum{num}, 158 } 159 } 160 s := m.streams[num] 161 m.mutex.RUnlock() 162 return s, nil 163 } 164 165 func (m *outgoingStreamsMap[T]) DeleteStream(num protocol.StreamNum) error { 166 m.mutex.Lock() 167 defer m.mutex.Unlock() 168 169 if _, ok := m.streams[num]; !ok { 170 return streamError{ 171 message: "tried to delete unknown outgoing stream %d", 172 nums: []protocol.StreamNum{num}, 173 } 174 } 175 delete(m.streams, num) 176 return nil 177 } 178 179 func (m *outgoingStreamsMap[T]) SetMaxStream(num protocol.StreamNum) { 180 m.mutex.Lock() 181 defer m.mutex.Unlock() 182 183 if num <= m.maxStream { 184 return 185 } 186 m.maxStream = num 187 m.capabilityCallback(int64(m.maxStream - m.nextStream)) 188 m.blockedSent = false 189 if m.maxStream < m.nextStream-1+protocol.StreamNum(len(m.openQueue)) { 190 m.maybeSendBlockedFrame() 191 } 192 m.unblockOpenSync() 193 } 194 195 // UpdateSendWindow is called when the peer's transport parameters are received. 196 // Only in the case of a 0-RTT handshake will we have open streams at this point. 197 // We might need to update the send window, in case the server increased it. 198 func (m *outgoingStreamsMap[T]) UpdateSendWindow(limit protocol.ByteCount) { 199 m.mutex.Lock() 200 for _, str := range m.streams { 201 str.updateSendWindow(limit) 202 } 203 m.mutex.Unlock() 204 } 205 206 // unblockOpenSync unblocks the next OpenStreamSync go-routine to open a new stream 207 func (m *outgoingStreamsMap[T]) unblockOpenSync() { 208 if len(m.openQueue) == 0 { 209 return 210 } 211 for qp := m.lowestInQueue; qp <= m.highestInQueue; qp++ { 212 c, ok := m.openQueue[qp] 213 if !ok { // entry was deleted because the context was canceled 214 continue 215 } 216 // unblockOpenSync is called both from OpenStreamSync and from SetMaxStream. 217 // It's sufficient to only unblock OpenStreamSync once. 218 select { 219 case c <- struct{}{}: 220 default: 221 } 222 return 223 } 224 } 225 226 func (m *outgoingStreamsMap[T]) CloseWithError(err error) { 227 m.mutex.Lock() 228 m.closeErr = err 229 for _, str := range m.streams { 230 str.closeForShutdown(err) 231 } 232 for _, c := range m.openQueue { 233 if c != nil { 234 close(c) 235 } 236 } 237 m.mutex.Unlock() 238 }