github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/inflight.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package client
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"sync"
    21  	"sync/atomic"
    22  	"time"
    23  
    24  	"github.com/rs/zerolog/log"
    25  
    26  	"github.com/datastax/go-cassandra-native-protocol/frame"
    27  	"github.com/datastax/go-cassandra-native-protocol/message"
    28  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    29  )
    30  
    31  type inFlightRequestsHandler struct {
    32  	connectionId string
    33  	ctx          context.Context
    34  	maxInFlight  int
    35  	maxPending   int
    36  	timeout      time.Duration
    37  	streamIds    chan int16
    38  	inFlight     map[int16]*inFlightRequest
    39  	inFlightLock *sync.RWMutex
    40  	closed       int32
    41  }
    42  
    43  func (h *inFlightRequestsHandler) String() string {
    44  	return fmt.Sprintf("%v: [in-flight handler]", h.connectionId)
    45  }
    46  
    47  func newInFlightRequestsHandler(
    48  	connectionId string,
    49  	ctx context.Context,
    50  	maxInFlight int,
    51  	maxPending int,
    52  	timeout time.Duration,
    53  ) *inFlightRequestsHandler {
    54  	handler := &inFlightRequestsHandler{
    55  		connectionId: connectionId,
    56  		ctx:          ctx,
    57  		maxInFlight:  maxInFlight,
    58  		maxPending:   maxPending,
    59  		timeout:      timeout,
    60  		streamIds:    make(chan int16, maxInFlight),
    61  		inFlight:     make(map[int16]*inFlightRequest, maxInFlight),
    62  		inFlightLock: &sync.RWMutex{},
    63  	}
    64  	for i := 1; i <= maxInFlight; i++ {
    65  		handler.streamIds <- int16(i)
    66  	}
    67  	return handler
    68  }
    69  
    70  func (h *inFlightRequestsHandler) onOutgoingFrameEnqueued(f *frame.Frame) (InFlightRequest, error) {
    71  	if h.isClosed() {
    72  		return nil, fmt.Errorf("%v: handler closed", h)
    73  	}
    74  	var err error
    75  	streamId := f.Header.StreamId
    76  	managedStreamId := streamId == ManagedStreamId
    77  	if managedStreamId {
    78  		if streamId, err = h.borrowStreamId(); err != nil {
    79  			return nil, err
    80  		} else {
    81  			f.Header.StreamId = streamId
    82  		}
    83  	}
    84  	h.inFlightLock.RLock()
    85  	if len(h.inFlight) == h.maxInFlight {
    86  		err = fmt.Errorf("%v: too many in-flight requests: %v", h, h.maxInFlight)
    87  	} else if _, found := h.inFlight[streamId]; found {
    88  		err = fmt.Errorf("%v: stream id already in use: %d", h, streamId)
    89  	}
    90  	h.inFlightLock.RUnlock()
    91  	if err == nil {
    92  		var inFlight *inFlightRequest
    93  		inFlight, err = h.addInFlight(streamId, managedStreamId)
    94  		if err == nil {
    95  			inFlight.startTimeout()
    96  			return inFlight, nil
    97  		}
    98  	}
    99  	return nil, err
   100  }
   101  
   102  func (h *inFlightRequestsHandler) onIncomingFrameReceived(f *frame.Frame) error {
   103  	if h.isClosed() {
   104  		return fmt.Errorf("%v: handler closed", h)
   105  	}
   106  	streamId := f.Header.StreamId
   107  	var err error
   108  	var inFlight *inFlightRequest
   109  	var found bool
   110  	h.inFlightLock.RLock()
   111  	if inFlight, found = h.inFlight[streamId]; !found {
   112  		err = fmt.Errorf("%v: unknown stream id: %d", h, streamId)
   113  	}
   114  	h.inFlightLock.RUnlock()
   115  	if err == nil {
   116  		if isLastFrame(f) {
   117  			h.removeInFlight(streamId)
   118  			if inFlight.managedStreamId {
   119  				if err := h.releaseStreamId(streamId); err != nil {
   120  					return err
   121  				}
   122  			}
   123  		}
   124  		err = inFlight.onFrameReceived(f)
   125  	}
   126  	return err
   127  }
   128  
   129  func (h *inFlightRequestsHandler) addInFlight(streamId int16, managedStreamId bool) (*inFlightRequest, error) {
   130  	inFlight := newInFlightRequest(h.String(), streamId, managedStreamId, h.ctx, h.maxPending, h.timeout)
   131  	h.inFlightLock.Lock()
   132  	defer h.inFlightLock.Unlock()
   133  	if h.isClosed() {
   134  		return nil, fmt.Errorf("%v: handler closed", h)
   135  	}
   136  	h.inFlight[streamId] = inFlight
   137  	return inFlight, nil
   138  }
   139  
   140  func (h *inFlightRequestsHandler) removeInFlight(streamId int16) {
   141  	h.inFlightLock.Lock()
   142  	defer h.inFlightLock.Unlock()
   143  	if _, found := h.inFlight[streamId]; found {
   144  		delete(h.inFlight, streamId)
   145  	}
   146  }
   147  
   148  func (h *inFlightRequestsHandler) borrowStreamId() (int16, error) {
   149  	if h.isClosed() {
   150  		return -1, fmt.Errorf("%v: handler closed", h)
   151  	}
   152  	select {
   153  	case id, ok := <-h.streamIds:
   154  		if !ok {
   155  			return -1, fmt.Errorf("%v: handler closed", h)
   156  		}
   157  		log.Debug().Msgf("%v: borrowed stream id: %v", h, id)
   158  		return id, nil
   159  	default:
   160  		return -1, fmt.Errorf("%v: no stream id available", h)
   161  	}
   162  }
   163  
   164  func (h *inFlightRequestsHandler) releaseStreamId(id int16) error {
   165  	if h.isClosed() {
   166  		return fmt.Errorf("%v: handler closed", h)
   167  	}
   168  	select {
   169  	case h.streamIds <- id:
   170  		log.Debug().Msgf("%v: released stream id: %v", h, id)
   171  		return nil
   172  	default:
   173  		return fmt.Errorf("%v: stream id %d: release failed", h, id)
   174  	}
   175  }
   176  
   177  func (h *inFlightRequestsHandler) isClosed() bool {
   178  	return atomic.LoadInt32(&h.closed) == 1
   179  }
   180  
   181  func (h *inFlightRequestsHandler) setClosed() bool {
   182  	return atomic.CompareAndSwapInt32(&h.closed, 0, 1)
   183  }
   184  
   185  func (h *inFlightRequestsHandler) close() {
   186  	if h.setClosed() {
   187  		log.Trace().Msgf("%v: closing", h)
   188  		h.inFlightLock.Lock()
   189  		for streamId, inFlight := range h.inFlight {
   190  			delete(h.inFlight, streamId)
   191  			inFlight.close(fmt.Errorf("%v: handler closed", h))
   192  		}
   193  		h.inFlightLock.Unlock()
   194  		streamIds := h.streamIds
   195  		h.streamIds = nil
   196  		close(streamIds)
   197  		log.Trace().Msgf("%v: successfully closed", h)
   198  	}
   199  }
   200  
   201  type inFlightRequest struct {
   202  	handlerId       string
   203  	streamId        int16
   204  	managedStreamId bool
   205  	_incoming       chan *frame.Frame // used internally; will be set to nil on close
   206  	incoming        chan *frame.Frame // exposed externally; never nil
   207  	err             error
   208  	done            bool
   209  	timeout         time.Duration
   210  	ctx             context.Context
   211  	cancel          context.CancelFunc
   212  	timeoutCtx      context.Context
   213  	timeoutCancel   context.CancelFunc
   214  
   215  	// lock guards the closing of incoming chan and the assignment of done and err;
   216  	// required to fulfill the interface contract:
   217  	// if Incoming is closed, IsDone must return true; if it was closed because of an error,
   218  	// Err must return that error.
   219  	lock *sync.RWMutex
   220  }
   221  
   222  func (r *inFlightRequest) StreamId() int16 {
   223  	return r.streamId
   224  }
   225  
   226  func (r *inFlightRequest) Incoming() <-chan *frame.Frame {
   227  	r.lock.RLock()
   228  	defer r.lock.RUnlock()
   229  	return r.incoming
   230  }
   231  
   232  func (r *inFlightRequest) IsDone() bool {
   233  	r.lock.RLock()
   234  	defer r.lock.RUnlock()
   235  	return r.done
   236  }
   237  
   238  func (r *inFlightRequest) Err() error {
   239  	r.lock.RLock()
   240  	defer r.lock.RUnlock()
   241  	return r.err
   242  }
   243  
   244  func newInFlightRequest(
   245  	handlerId string,
   246  	streamId int16,
   247  	managedStreamId bool,
   248  	ctx context.Context,
   249  	maxPending int,
   250  	timeout time.Duration,
   251  ) *inFlightRequest {
   252  	ctx, cancel := context.WithCancel(ctx)
   253  	incoming := make(chan *frame.Frame, maxPending)
   254  	return &inFlightRequest{
   255  		handlerId:       handlerId,
   256  		streamId:        streamId,
   257  		managedStreamId: managedStreamId,
   258  		_incoming:       incoming,
   259  		incoming:        incoming,
   260  		timeout:         timeout,
   261  		ctx:             ctx,
   262  		cancel:          cancel,
   263  		lock:            &sync.RWMutex{},
   264  	}
   265  }
   266  
   267  func (r *inFlightRequest) String() string {
   268  	return fmt.Sprintf("%v [stream id %d]", r.handlerId, r.streamId)
   269  }
   270  
   271  func (r *inFlightRequest) onFrameReceived(f *frame.Frame) error {
   272  	select {
   273  	case r._incoming <- f:
   274  		if isLastFrame(f) {
   275  			r.stopTimeout()
   276  			r.close(nil)
   277  		} else {
   278  			r.resetTimeout()
   279  		}
   280  		return nil
   281  	case <-r.ctx.Done():
   282  		return fmt.Errorf("%v: request closed", r)
   283  	default:
   284  		err := fmt.Errorf("%v: too many pending incoming frames: %d", r, len(r.incoming))
   285  		r.close(err)
   286  		return err
   287  	}
   288  }
   289  
   290  func (r *inFlightRequest) startTimeout() {
   291  	r.timeoutCtx, r.timeoutCancel = context.WithTimeout(r.ctx, r.timeout)
   292  	log.Trace().Msgf("%v: timeout started", r)
   293  	go func() {
   294  		select {
   295  		case <-r.timeoutCtx.Done():
   296  			switch r.timeoutCtx.Err() {
   297  			case context.DeadlineExceeded:
   298  				err := fmt.Errorf("%v: timed out waiting for incoming frames", r)
   299  				r.close(err)
   300  			case context.Canceled:
   301  				log.Trace().Msgf("%v: timeout canceled", r)
   302  			}
   303  		}
   304  	}()
   305  }
   306  
   307  func (r *inFlightRequest) stopTimeout() {
   308  	if r.timeoutCancel != nil {
   309  		r.timeoutCancel()
   310  	}
   311  }
   312  
   313  func (r inFlightRequest) resetTimeout() {
   314  	r.stopTimeout()
   315  	r.startTimeout()
   316  }
   317  
   318  func (r *inFlightRequest) close(err error) {
   319  	// need to hold the lock to keep the 3 states in sync: done, incoming and err
   320  	r.lock.Lock()
   321  	if !r.done {
   322  		log.Trace().Msgf("%v: closing", r)
   323  		r.cancel()
   324  		// set _incoming to nil first to avoid potential panic in onFrameReceived
   325  		r._incoming = nil
   326  		close(r.incoming)
   327  		r.err = err
   328  		r.done = true
   329  	}
   330  	r.lock.Unlock()
   331  	log.Trace().Msgf("%v: successfully closed", r)
   332  }
   333  
   334  func isLastFrame(f *frame.Frame) bool {
   335  	if f.Header.OpCode == primitive.OpCodeResult {
   336  		result := f.Body.Message.(message.Result)
   337  		if result.GetResultType() == primitive.ResultTypeRows {
   338  			rows := result.(*message.RowsResult)
   339  			if rows.Metadata.Flags()&primitive.RowsFlagDseContinuousPaging != 0 {
   340  				return rows.Metadata.LastContinuousPage
   341  			}
   342  		}
   343  	}
   344  	return true
   345  }