github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/inflight.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package client 16 17 import ( 18 "context" 19 "fmt" 20 "sync" 21 "sync/atomic" 22 "time" 23 24 "github.com/rs/zerolog/log" 25 26 "github.com/datastax/go-cassandra-native-protocol/frame" 27 "github.com/datastax/go-cassandra-native-protocol/message" 28 "github.com/datastax/go-cassandra-native-protocol/primitive" 29 ) 30 31 type inFlightRequestsHandler struct { 32 connectionId string 33 ctx context.Context 34 maxInFlight int 35 maxPending int 36 timeout time.Duration 37 streamIds chan int16 38 inFlight map[int16]*inFlightRequest 39 inFlightLock *sync.RWMutex 40 closed int32 41 } 42 43 func (h *inFlightRequestsHandler) String() string { 44 return fmt.Sprintf("%v: [in-flight handler]", h.connectionId) 45 } 46 47 func newInFlightRequestsHandler( 48 connectionId string, 49 ctx context.Context, 50 maxInFlight int, 51 maxPending int, 52 timeout time.Duration, 53 ) *inFlightRequestsHandler { 54 handler := &inFlightRequestsHandler{ 55 connectionId: connectionId, 56 ctx: ctx, 57 maxInFlight: maxInFlight, 58 maxPending: maxPending, 59 timeout: timeout, 60 streamIds: make(chan int16, maxInFlight), 61 inFlight: make(map[int16]*inFlightRequest, maxInFlight), 62 inFlightLock: &sync.RWMutex{}, 63 } 64 for i := 1; i <= maxInFlight; i++ { 65 handler.streamIds <- int16(i) 66 } 67 return handler 68 } 69 70 func (h *inFlightRequestsHandler) onOutgoingFrameEnqueued(f *frame.Frame) (InFlightRequest, error) { 71 if h.isClosed() { 72 return nil, fmt.Errorf("%v: handler closed", h) 73 } 74 var err error 75 streamId := f.Header.StreamId 76 managedStreamId := streamId == ManagedStreamId 77 if managedStreamId { 78 if streamId, err = h.borrowStreamId(); err != nil { 79 return nil, err 80 } else { 81 f.Header.StreamId = streamId 82 } 83 } 84 h.inFlightLock.RLock() 85 if len(h.inFlight) == h.maxInFlight { 86 err = fmt.Errorf("%v: too many in-flight requests: %v", h, h.maxInFlight) 87 } else if _, found := h.inFlight[streamId]; found { 88 err = fmt.Errorf("%v: stream id already in use: %d", h, streamId) 89 } 90 h.inFlightLock.RUnlock() 91 if err == nil { 92 var inFlight *inFlightRequest 93 inFlight, err = h.addInFlight(streamId, managedStreamId) 94 if err == nil { 95 inFlight.startTimeout() 96 return inFlight, nil 97 } 98 } 99 return nil, err 100 } 101 102 func (h *inFlightRequestsHandler) onIncomingFrameReceived(f *frame.Frame) error { 103 if h.isClosed() { 104 return fmt.Errorf("%v: handler closed", h) 105 } 106 streamId := f.Header.StreamId 107 var err error 108 var inFlight *inFlightRequest 109 var found bool 110 h.inFlightLock.RLock() 111 if inFlight, found = h.inFlight[streamId]; !found { 112 err = fmt.Errorf("%v: unknown stream id: %d", h, streamId) 113 } 114 h.inFlightLock.RUnlock() 115 if err == nil { 116 if isLastFrame(f) { 117 h.removeInFlight(streamId) 118 if inFlight.managedStreamId { 119 if err := h.releaseStreamId(streamId); err != nil { 120 return err 121 } 122 } 123 } 124 err = inFlight.onFrameReceived(f) 125 } 126 return err 127 } 128 129 func (h *inFlightRequestsHandler) addInFlight(streamId int16, managedStreamId bool) (*inFlightRequest, error) { 130 inFlight := newInFlightRequest(h.String(), streamId, managedStreamId, h.ctx, h.maxPending, h.timeout) 131 h.inFlightLock.Lock() 132 defer h.inFlightLock.Unlock() 133 if h.isClosed() { 134 return nil, fmt.Errorf("%v: handler closed", h) 135 } 136 h.inFlight[streamId] = inFlight 137 return inFlight, nil 138 } 139 140 func (h *inFlightRequestsHandler) removeInFlight(streamId int16) { 141 h.inFlightLock.Lock() 142 defer h.inFlightLock.Unlock() 143 if _, found := h.inFlight[streamId]; found { 144 delete(h.inFlight, streamId) 145 } 146 } 147 148 func (h *inFlightRequestsHandler) borrowStreamId() (int16, error) { 149 if h.isClosed() { 150 return -1, fmt.Errorf("%v: handler closed", h) 151 } 152 select { 153 case id, ok := <-h.streamIds: 154 if !ok { 155 return -1, fmt.Errorf("%v: handler closed", h) 156 } 157 log.Debug().Msgf("%v: borrowed stream id: %v", h, id) 158 return id, nil 159 default: 160 return -1, fmt.Errorf("%v: no stream id available", h) 161 } 162 } 163 164 func (h *inFlightRequestsHandler) releaseStreamId(id int16) error { 165 if h.isClosed() { 166 return fmt.Errorf("%v: handler closed", h) 167 } 168 select { 169 case h.streamIds <- id: 170 log.Debug().Msgf("%v: released stream id: %v", h, id) 171 return nil 172 default: 173 return fmt.Errorf("%v: stream id %d: release failed", h, id) 174 } 175 } 176 177 func (h *inFlightRequestsHandler) isClosed() bool { 178 return atomic.LoadInt32(&h.closed) == 1 179 } 180 181 func (h *inFlightRequestsHandler) setClosed() bool { 182 return atomic.CompareAndSwapInt32(&h.closed, 0, 1) 183 } 184 185 func (h *inFlightRequestsHandler) close() { 186 if h.setClosed() { 187 log.Trace().Msgf("%v: closing", h) 188 h.inFlightLock.Lock() 189 for streamId, inFlight := range h.inFlight { 190 delete(h.inFlight, streamId) 191 inFlight.close(fmt.Errorf("%v: handler closed", h)) 192 } 193 h.inFlightLock.Unlock() 194 streamIds := h.streamIds 195 h.streamIds = nil 196 close(streamIds) 197 log.Trace().Msgf("%v: successfully closed", h) 198 } 199 } 200 201 type inFlightRequest struct { 202 handlerId string 203 streamId int16 204 managedStreamId bool 205 _incoming chan *frame.Frame // used internally; will be set to nil on close 206 incoming chan *frame.Frame // exposed externally; never nil 207 err error 208 done bool 209 timeout time.Duration 210 ctx context.Context 211 cancel context.CancelFunc 212 timeoutCtx context.Context 213 timeoutCancel context.CancelFunc 214 215 // lock guards the closing of incoming chan and the assignment of done and err; 216 // required to fulfill the interface contract: 217 // if Incoming is closed, IsDone must return true; if it was closed because of an error, 218 // Err must return that error. 219 lock *sync.RWMutex 220 } 221 222 func (r *inFlightRequest) StreamId() int16 { 223 return r.streamId 224 } 225 226 func (r *inFlightRequest) Incoming() <-chan *frame.Frame { 227 r.lock.RLock() 228 defer r.lock.RUnlock() 229 return r.incoming 230 } 231 232 func (r *inFlightRequest) IsDone() bool { 233 r.lock.RLock() 234 defer r.lock.RUnlock() 235 return r.done 236 } 237 238 func (r *inFlightRequest) Err() error { 239 r.lock.RLock() 240 defer r.lock.RUnlock() 241 return r.err 242 } 243 244 func newInFlightRequest( 245 handlerId string, 246 streamId int16, 247 managedStreamId bool, 248 ctx context.Context, 249 maxPending int, 250 timeout time.Duration, 251 ) *inFlightRequest { 252 ctx, cancel := context.WithCancel(ctx) 253 incoming := make(chan *frame.Frame, maxPending) 254 return &inFlightRequest{ 255 handlerId: handlerId, 256 streamId: streamId, 257 managedStreamId: managedStreamId, 258 _incoming: incoming, 259 incoming: incoming, 260 timeout: timeout, 261 ctx: ctx, 262 cancel: cancel, 263 lock: &sync.RWMutex{}, 264 } 265 } 266 267 func (r *inFlightRequest) String() string { 268 return fmt.Sprintf("%v [stream id %d]", r.handlerId, r.streamId) 269 } 270 271 func (r *inFlightRequest) onFrameReceived(f *frame.Frame) error { 272 select { 273 case r._incoming <- f: 274 if isLastFrame(f) { 275 r.stopTimeout() 276 r.close(nil) 277 } else { 278 r.resetTimeout() 279 } 280 return nil 281 case <-r.ctx.Done(): 282 return fmt.Errorf("%v: request closed", r) 283 default: 284 err := fmt.Errorf("%v: too many pending incoming frames: %d", r, len(r.incoming)) 285 r.close(err) 286 return err 287 } 288 } 289 290 func (r *inFlightRequest) startTimeout() { 291 r.timeoutCtx, r.timeoutCancel = context.WithTimeout(r.ctx, r.timeout) 292 log.Trace().Msgf("%v: timeout started", r) 293 go func() { 294 select { 295 case <-r.timeoutCtx.Done(): 296 switch r.timeoutCtx.Err() { 297 case context.DeadlineExceeded: 298 err := fmt.Errorf("%v: timed out waiting for incoming frames", r) 299 r.close(err) 300 case context.Canceled: 301 log.Trace().Msgf("%v: timeout canceled", r) 302 } 303 } 304 }() 305 } 306 307 func (r *inFlightRequest) stopTimeout() { 308 if r.timeoutCancel != nil { 309 r.timeoutCancel() 310 } 311 } 312 313 func (r inFlightRequest) resetTimeout() { 314 r.stopTimeout() 315 r.startTimeout() 316 } 317 318 func (r *inFlightRequest) close(err error) { 319 // need to hold the lock to keep the 3 states in sync: done, incoming and err 320 r.lock.Lock() 321 if !r.done { 322 log.Trace().Msgf("%v: closing", r) 323 r.cancel() 324 // set _incoming to nil first to avoid potential panic in onFrameReceived 325 r._incoming = nil 326 close(r.incoming) 327 r.err = err 328 r.done = true 329 } 330 r.lock.Unlock() 331 log.Trace().Msgf("%v: successfully closed", r) 332 } 333 334 func isLastFrame(f *frame.Frame) bool { 335 if f.Header.OpCode == primitive.OpCodeResult { 336 result := f.Body.Message.(message.Result) 337 if result.GetResultType() == primitive.ResultTypeRows { 338 rows := result.(*message.RowsResult) 339 if rows.Metadata.Flags()&primitive.RowsFlagDseContinuousPaging != 0 { 340 return rows.Metadata.LastContinuousPage 341 } 342 } 343 } 344 return true 345 }