go.temporal.io/server@v1.23.0/common/persistence/nosql/nosqlplugin/cassandra/gocql/session.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package gocql
    26  
import (
	"context"
	"errors"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/gocql/gocql"

	"go.temporal.io/server/common"
	"go.temporal.io/server/common/log"
	"go.temporal.io/server/common/log/tag"
	"go.temporal.io/server/common/metrics"
)
    41  
    42  var _ Session = (*session)(nil)
    43  
    44  const (
    45  	sessionRefreshMinInternal = 5 * time.Second
    46  )
    47  
    48  const (
    49  	refreshThrottleTagValue = "throttle"
    50  	refreshErrorTagValue    = "error"
    51  )
    52  
    53  type (
    54  	session struct {
    55  		status               int32
    56  		newClusterConfigFunc func() (*gocql.ClusterConfig, error)
    57  		atomic.Value         // *gocql.Session
    58  		logger               log.Logger
    59  
    60  		sync.Mutex
    61  		sessionInitTime time.Time
    62  		metricsHandler  metrics.Handler
    63  	}
    64  )
    65  
    66  func NewSession(
    67  	newClusterConfigFunc func() (*gocql.ClusterConfig, error),
    68  	logger log.Logger,
    69  	metricsHandler metrics.Handler,
    70  ) (*session, error) {
    71  
    72  	gocqlSession, err := initSession(newClusterConfigFunc, metricsHandler)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	session := &session{
    78  		status:               common.DaemonStatusStarted,
    79  		newClusterConfigFunc: newClusterConfigFunc,
    80  		logger:               logger,
    81  		metricsHandler:       metricsHandler,
    82  
    83  		sessionInitTime: time.Now().UTC(),
    84  	}
    85  	session.Value.Store(gocqlSession)
    86  	return session, nil
    87  }
    88  
    89  func (s *session) refresh() {
    90  	if atomic.LoadInt32(&s.status) != common.DaemonStatusStarted {
    91  		return
    92  	}
    93  
    94  	s.Lock()
    95  	defer s.Unlock()
    96  
    97  	if time.Now().UTC().Sub(s.sessionInitTime) < sessionRefreshMinInternal {
    98  		s.logger.Warn("gocql wrapper: did not refresh gocql session because the last refresh was too close",
    99  			tag.NewDurationTag("min_refresh_interval_seconds", sessionRefreshMinInternal))
   100  		handler := s.metricsHandler.WithTags(metrics.FailureTag(refreshThrottleTagValue))
   101  		handler.Counter(metrics.CassandraSessionRefreshFailures.Name()).Record(1)
   102  		return
   103  	}
   104  
   105  	newSession, err := initSession(s.newClusterConfigFunc, s.metricsHandler)
   106  	if err != nil {
   107  		s.logger.Error("gocql wrapper: unable to refresh gocql session", tag.Error(err))
   108  		handler := s.metricsHandler.WithTags(metrics.FailureTag(refreshErrorTagValue))
   109  		handler.Counter(metrics.CassandraSessionRefreshFailures.Name()).Record(1)
   110  		return
   111  	}
   112  
   113  	s.sessionInitTime = time.Now().UTC()
   114  	oldSession := s.Value.Load().(*gocql.Session)
   115  	s.Value.Store(newSession)
   116  	go oldSession.Close()
   117  	s.logger.Warn("gocql wrapper: successfully refreshed gocql session")
   118  }
   119  
   120  func initSession(
   121  	newClusterConfigFunc func() (*gocql.ClusterConfig, error),
   122  	metricsHandler metrics.Handler,
   123  ) (*gocql.Session, error) {
   124  	cluster, err := newClusterConfigFunc()
   125  	if err != nil {
   126  		return nil, err
   127  	}
   128  	start := time.Now()
   129  	defer func() {
   130  		metricsHandler.Timer(metrics.CassandraInitSessionLatency.Name()).Record(time.Since(start))
   131  	}()
   132  	return cluster.CreateSession()
   133  }
   134  
   135  func (s *session) Query(
   136  	stmt string,
   137  	values ...interface{},
   138  ) Query {
   139  	q := s.Value.Load().(*gocql.Session).Query(stmt, values...)
   140  	if q == nil {
   141  		return nil
   142  	}
   143  
   144  	return &query{
   145  		session:    s,
   146  		gocqlQuery: q,
   147  	}
   148  }
   149  
   150  func (s *session) NewBatch(
   151  	batchType BatchType,
   152  ) Batch {
   153  	b := s.Value.Load().(*gocql.Session).NewBatch(mustConvertBatchType(batchType))
   154  	if b == nil {
   155  		return nil
   156  	}
   157  	return &batch{
   158  		session:    s,
   159  		gocqlBatch: b,
   160  	}
   161  }
   162  
   163  func (s *session) ExecuteBatch(
   164  	b Batch,
   165  ) (retError error) {
   166  	defer func() { s.handleError(retError) }()
   167  
   168  	return s.Value.Load().(*gocql.Session).ExecuteBatch(b.(*batch).gocqlBatch)
   169  }
   170  
   171  func (s *session) MapExecuteBatchCAS(
   172  	b Batch,
   173  	previous map[string]interface{},
   174  ) (_ bool, _ Iter, retError error) {
   175  	defer func() { s.handleError(retError) }()
   176  
   177  	applied, iter, err := s.Value.Load().(*gocql.Session).MapExecuteBatchCAS(b.(*batch).gocqlBatch, previous)
   178  	return applied, iter, err
   179  }
   180  
   181  func (s *session) AwaitSchemaAgreement(
   182  	ctx context.Context,
   183  ) (retError error) {
   184  	defer func() { s.handleError(retError) }()
   185  
   186  	return s.Value.Load().(*gocql.Session).AwaitSchemaAgreement(ctx)
   187  }
   188  
   189  func (s *session) Close() {
   190  	if !atomic.CompareAndSwapInt32(
   191  		&s.status,
   192  		common.DaemonStatusStarted,
   193  		common.DaemonStatusStopped,
   194  	) {
   195  		return
   196  	}
   197  	s.Value.Load().(*gocql.Session).Close()
   198  }
   199  
   200  func (s *session) handleError(
   201  	err error,
   202  ) {
   203  	switch err {
   204  	case gocql.ErrNoConnections,
   205  		gocql.ErrSessionClosed,
   206  		gocql.ErrConnectionClosed,
   207  		syscall.ECONNRESET:
   208  		s.refresh()
   209  	default:
   210  		// noop
   211  	}
   212  }