github.com/sagernet/quic-go@v0.43.1-beta.1/ech/streams_map.go (about) 1 package quic 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "net" 8 "sync" 9 10 "github.com/sagernet/quic-go/internal/flowcontrol" 11 "github.com/sagernet/quic-go/internal/protocol" 12 "github.com/sagernet/quic-go/internal/qerr" 13 "github.com/sagernet/quic-go/internal/wire" 14 ) 15 16 type streamError struct { 17 message string 18 nums []protocol.StreamNum 19 } 20 21 func (e streamError) Error() string { 22 return e.message 23 } 24 25 func convertStreamError(err error, stype protocol.StreamType, pers protocol.Perspective) error { 26 strError, ok := err.(streamError) 27 if !ok { 28 return err 29 } 30 ids := make([]interface{}, len(strError.nums)) 31 for i, num := range strError.nums { 32 ids[i] = num.StreamID(stype, pers) 33 } 34 return fmt.Errorf(strError.Error(), ids...) 35 } 36 37 type streamOpenErr struct{ error } 38 39 var _ net.Error = &streamOpenErr{} 40 41 func (e streamOpenErr) Temporary() bool { return e.error == errTooManyOpenStreams } 42 func (streamOpenErr) Timeout() bool { return false } 43 44 // errTooManyOpenStreams is used internally by the outgoing streams maps. 45 var errTooManyOpenStreams = errors.New("too many open streams") 46 47 type streamsMap struct { 48 ctx context.Context // not used for cancellations, but carries the values associated with the connection 49 perspective protocol.Perspective 50 51 maxIncomingBidiStreams uint64 52 maxIncomingUniStreams uint64 53 54 sender streamSender 55 newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController 56 57 mutex sync.Mutex 58 outgoingBidiStreams *outgoingStreamsMap[streamI] 59 outgoingUniStreams *outgoingStreamsMap[sendStreamI] 60 incomingBidiStreams *incomingStreamsMap[streamI] 61 incomingUniStreams *incomingStreamsMap[receiveStreamI] 62 reset bool 63 } 64 65 var _ streamManager = &streamsMap{} 66 67 func newStreamsMap( 68 ctx context.Context, 69 sender streamSender, 70 newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController, 71 maxIncomingBidiStreams uint64, 72 maxIncomingUniStreams uint64, 73 perspective protocol.Perspective, 74 ) streamManager { 75 m := &streamsMap{ 76 ctx: ctx, 77 perspective: perspective, 78 newFlowController: newFlowController, 79 maxIncomingBidiStreams: maxIncomingBidiStreams, 80 maxIncomingUniStreams: maxIncomingUniStreams, 81 sender: sender, 82 } 83 m.initMaps() 84 return m 85 } 86 87 func (m *streamsMap) initMaps() { 88 m.outgoingBidiStreams = newOutgoingStreamsMap( 89 protocol.StreamTypeBidi, 90 func(num protocol.StreamNum) streamI { 91 id := num.StreamID(protocol.StreamTypeBidi, m.perspective) 92 return newStream(m.ctx, id, m.sender, m.newFlowController(id)) 93 }, 94 m.sender.queueControlFrame, 95 ) 96 m.incomingBidiStreams = newIncomingStreamsMap( 97 protocol.StreamTypeBidi, 98 func(num protocol.StreamNum) streamI { 99 id := num.StreamID(protocol.StreamTypeBidi, m.perspective.Opposite()) 100 return newStream(m.ctx, id, m.sender, m.newFlowController(id)) 101 }, 102 m.maxIncomingBidiStreams, 103 m.sender.queueControlFrame, 104 ) 105 m.outgoingUniStreams = newOutgoingStreamsMap( 106 protocol.StreamTypeUni, 107 func(num protocol.StreamNum) sendStreamI { 108 id := num.StreamID(protocol.StreamTypeUni, m.perspective) 109 return newSendStream(m.ctx, id, m.sender, m.newFlowController(id)) 110 }, 111 m.sender.queueControlFrame, 112 ) 113 m.incomingUniStreams = newIncomingStreamsMap( 114 protocol.StreamTypeUni, 115 func(num protocol.StreamNum) receiveStreamI { 116 id := num.StreamID(protocol.StreamTypeUni, m.perspective.Opposite()) 117 return newReceiveStream(id, m.sender, m.newFlowController(id)) 118 }, 119 m.maxIncomingUniStreams, 120 m.sender.queueControlFrame, 121 ) 122 } 123 124 func (m *streamsMap) OpenStream() (Stream, error) { 125 m.mutex.Lock() 126 reset := m.reset 127 mm := m.outgoingBidiStreams 128 m.mutex.Unlock() 129 if reset { 130 return nil, Err0RTTRejected 131 } 132 str, err := mm.OpenStream() 133 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) 134 } 135 136 func (m *streamsMap) OpenStreamSync(ctx context.Context) (Stream, error) { 137 m.mutex.Lock() 138 reset := m.reset 139 mm := m.outgoingBidiStreams 140 m.mutex.Unlock() 141 if reset { 142 return nil, Err0RTTRejected 143 } 144 str, err := mm.OpenStreamSync(ctx) 145 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) 146 } 147 148 func (m *streamsMap) OpenUniStream() (SendStream, error) { 149 m.mutex.Lock() 150 reset := m.reset 151 mm := m.outgoingUniStreams 152 m.mutex.Unlock() 153 if reset { 154 return nil, Err0RTTRejected 155 } 156 str, err := mm.OpenStream() 157 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective) 158 } 159 160 func (m *streamsMap) OpenUniStreamSync(ctx context.Context) (SendStream, error) { 161 m.mutex.Lock() 162 reset := m.reset 163 mm := m.outgoingUniStreams 164 m.mutex.Unlock() 165 if reset { 166 return nil, Err0RTTRejected 167 } 168 str, err := mm.OpenStreamSync(ctx) 169 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) 170 } 171 172 func (m *streamsMap) AcceptStream(ctx context.Context) (Stream, error) { 173 m.mutex.Lock() 174 reset := m.reset 175 mm := m.incomingBidiStreams 176 m.mutex.Unlock() 177 if reset { 178 return nil, Err0RTTRejected 179 } 180 str, err := mm.AcceptStream(ctx) 181 return str, convertStreamError(err, protocol.StreamTypeBidi, m.perspective.Opposite()) 182 } 183 184 func (m *streamsMap) AcceptUniStream(ctx context.Context) (ReceiveStream, error) { 185 m.mutex.Lock() 186 reset := m.reset 187 mm := m.incomingUniStreams 188 m.mutex.Unlock() 189 if reset { 190 return nil, Err0RTTRejected 191 } 192 str, err := mm.AcceptStream(ctx) 193 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective.Opposite()) 194 } 195 196 func (m *streamsMap) DeleteStream(id protocol.StreamID) error { 197 num := id.StreamNum() 198 switch id.Type() { 199 case protocol.StreamTypeUni: 200 if id.InitiatedBy() == m.perspective { 201 return convertStreamError(m.outgoingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective) 202 } 203 return convertStreamError(m.incomingUniStreams.DeleteStream(num), protocol.StreamTypeUni, m.perspective.Opposite()) 204 case protocol.StreamTypeBidi: 205 if id.InitiatedBy() == m.perspective { 206 return convertStreamError(m.outgoingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective) 207 } 208 return convertStreamError(m.incomingBidiStreams.DeleteStream(num), protocol.StreamTypeBidi, m.perspective.Opposite()) 209 } 210 panic("") 211 } 212 213 func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { 214 str, err := m.getOrOpenReceiveStream(id) 215 if err != nil { 216 return nil, &qerr.TransportError{ 217 ErrorCode: qerr.StreamStateError, 218 ErrorMessage: err.Error(), 219 } 220 } 221 return str, nil 222 } 223 224 func (m *streamsMap) getOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) { 225 num := id.StreamNum() 226 switch id.Type() { 227 case protocol.StreamTypeUni: 228 if id.InitiatedBy() == m.perspective { 229 // an outgoing unidirectional stream is a send stream, not a receive stream 230 return nil, fmt.Errorf("peer attempted to open receive stream %d", id) 231 } 232 str, err := m.incomingUniStreams.GetOrOpenStream(num) 233 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) 234 case protocol.StreamTypeBidi: 235 var str receiveStreamI 236 var err error 237 if id.InitiatedBy() == m.perspective { 238 str, err = m.outgoingBidiStreams.GetStream(num) 239 } else { 240 str, err = m.incomingBidiStreams.GetOrOpenStream(num) 241 } 242 return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) 243 } 244 panic("") 245 } 246 247 func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { 248 str, err := m.getOrOpenSendStream(id) 249 if err != nil { 250 return nil, &qerr.TransportError{ 251 ErrorCode: qerr.StreamStateError, 252 ErrorMessage: err.Error(), 253 } 254 } 255 return str, nil 256 } 257 258 func (m *streamsMap) getOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) { 259 num := id.StreamNum() 260 switch id.Type() { 261 case protocol.StreamTypeUni: 262 if id.InitiatedBy() == m.perspective { 263 str, err := m.outgoingUniStreams.GetStream(num) 264 return str, convertStreamError(err, protocol.StreamTypeUni, m.perspective) 265 } 266 // an incoming unidirectional stream is a receive stream, not a send stream 267 return nil, fmt.Errorf("peer attempted to open send stream %d", id) 268 case protocol.StreamTypeBidi: 269 var str sendStreamI 270 var err error 271 if id.InitiatedBy() == m.perspective { 272 str, err = m.outgoingBidiStreams.GetStream(num) 273 } else { 274 str, err = m.incomingBidiStreams.GetOrOpenStream(num) 275 } 276 return str, convertStreamError(err, protocol.StreamTypeBidi, id.InitiatedBy()) 277 } 278 panic("") 279 } 280 281 func (m *streamsMap) HandleMaxStreamsFrame(f *wire.MaxStreamsFrame) { 282 switch f.Type { 283 case protocol.StreamTypeUni: 284 m.outgoingUniStreams.SetMaxStream(f.MaxStreamNum) 285 case protocol.StreamTypeBidi: 286 m.outgoingBidiStreams.SetMaxStream(f.MaxStreamNum) 287 } 288 } 289 290 func (m *streamsMap) UpdateLimits(p *wire.TransportParameters) { 291 m.outgoingBidiStreams.UpdateSendWindow(p.InitialMaxStreamDataBidiRemote) 292 m.outgoingBidiStreams.SetMaxStream(p.MaxBidiStreamNum) 293 m.outgoingUniStreams.UpdateSendWindow(p.InitialMaxStreamDataUni) 294 m.outgoingUniStreams.SetMaxStream(p.MaxUniStreamNum) 295 } 296 297 func (m *streamsMap) CloseWithError(err error) { 298 m.outgoingBidiStreams.CloseWithError(err) 299 m.outgoingUniStreams.CloseWithError(err) 300 m.incomingBidiStreams.CloseWithError(err) 301 m.incomingUniStreams.CloseWithError(err) 302 } 303 304 // ResetFor0RTT resets is used when 0-RTT is rejected. In that case, the streams maps are 305 // 1. closed with an Err0RTTRejected, making calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream return that error. 306 // 2. reset to their initial state, such that we can immediately process new incoming stream data. 307 // Afterwards, calls to Open{Uni}Stream{Sync} / Accept{Uni}Stream will continue to return the error, 308 // until UseResetMaps() has been called. 309 func (m *streamsMap) ResetFor0RTT() { 310 m.mutex.Lock() 311 defer m.mutex.Unlock() 312 m.reset = true 313 m.CloseWithError(Err0RTTRejected) 314 m.initMaps() 315 } 316 317 func (m *streamsMap) UseResetMaps() { 318 m.mutex.Lock() 319 m.reset = false 320 m.mutex.Unlock() 321 }