github.com/KinWaiYuen/client-go/v2@v2.5.4/oracle/oracles/pd.go (about)

     1  // Copyright 2021 TiKV Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // NOTE: The code in this file is based on code from the
    16  // TiDB project, licensed under the Apache License v 2.0
    17  //
    18  // https://github.com/pingcap/tidb/tree/cc5e161ac06827589c4966674597c137cc9e809c/store/tikv/oracle/oracles/pd.go
    19  //
    20  
    21  // Copyright 2016 PingCAP, Inc.
    22  //
    23  // Licensed under the Apache License, Version 2.0 (the "License");
    24  // you may not use this file except in compliance with the License.
    25  // You may obtain a copy of the License at
    26  //
    27  //     http://www.apache.org/licenses/LICENSE-2.0
    28  //
    29  // Unless required by applicable law or agreed to in writing, software
    30  // distributed under the License is distributed on an "AS IS" BASIS,
    31  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    32  // See the License for the specific language governing permissions and
    33  // limitations under the License.
    34  
    35  package oracles
    36  
    37  import (
    38  	"context"
    39  	"strings"
    40  	"sync"
    41  	"sync/atomic"
    42  	"time"
    43  
    44  	"github.com/KinWaiYuen/client-go/v2/internal/logutil"
    45  	"github.com/KinWaiYuen/client-go/v2/metrics"
    46  	"github.com/KinWaiYuen/client-go/v2/oracle"
    47  	"github.com/pingcap/errors"
    48  	pd "github.com/tikv/pd/client"
    49  	"go.uber.org/zap"
    50  )
    51  
    52  var _ oracle.Oracle = &pdOracle{}
    53  
    54  const slowDist = 30 * time.Millisecond
    55  
// pdOracle is an Oracle that uses a placement driver client as source.
type pdOracle struct {
	// c is the PD client that timestamps are requested from.
	c pd.Client
	// txn_scope (string) -> lastTSPointer (*uint64)
	// The pointed-to value is only ever advanced (see setLastTS).
	lastTSMap sync.Map
	// txn_scope (string) -> lastArrivalTSPointer (*uint64)
	// The pointed-to value is the local-clock timestamp recorded by setLastTS
	// each time it advances lastTS for that scope.
	lastArrivalTSMap sync.Map
	quit             chan struct{} // closed by Close to stop the updateTS goroutine
}
    65  
    66  // NewPdOracle create an Oracle that uses a pd client source.
    67  // Refer https://github.com/tikv/pd/blob/master/client/client.go for more details.
    68  // PdOracle mantains `lastTS` to store the last timestamp got from PD server. If
    69  // `GetTimestamp()` is not called after `updateInterval`, it will be called by
    70  // itself to keep up with the timestamp on PD server.
    71  func NewPdOracle(pdClient pd.Client, updateInterval time.Duration) (oracle.Oracle, error) {
    72  	o := &pdOracle{
    73  		c:    pdClient,
    74  		quit: make(chan struct{}),
    75  	}
    76  	ctx := context.TODO()
    77  	go o.updateTS(ctx, updateInterval)
    78  	// Initialize the timestamp of the global txnScope by Get.
    79  	_, err := o.GetTimestamp(ctx, &oracle.Option{TxnScope: oracle.GlobalTxnScope})
    80  	if err != nil {
    81  		o.Close()
    82  		return nil, errors.Trace(err)
    83  	}
    84  	return o, nil
    85  }
    86  
    87  // IsExpired returns whether lockTS+TTL is expired, both are ms. It uses `lastTS`
    88  // to compare, may return false negative result temporarily.
    89  func (o *pdOracle) IsExpired(lockTS, TTL uint64, opt *oracle.Option) bool {
    90  	lastTS, exist := o.getLastTS(opt.TxnScope)
    91  	if !exist {
    92  		return true
    93  	}
    94  	return oracle.ExtractPhysical(lastTS) >= oracle.ExtractPhysical(lockTS)+int64(TTL)
    95  }
    96  
    97  // GetTimestamp gets a new increasing time.
    98  func (o *pdOracle) GetTimestamp(ctx context.Context, opt *oracle.Option) (uint64, error) {
    99  	ts, err := o.getTimestamp(ctx, opt.TxnScope)
   100  	if err != nil {
   101  		return 0, errors.Trace(err)
   102  	}
   103  	o.setLastTS(ts, opt.TxnScope)
   104  	return ts, nil
   105  }
   106  
   107  type tsFuture struct {
   108  	pd.TSFuture
   109  	o        *pdOracle
   110  	txnScope string
   111  }
   112  
   113  // Wait implements the oracle.Future interface.
   114  func (f *tsFuture) Wait() (uint64, error) {
   115  	now := time.Now()
   116  	physical, logical, err := f.TSFuture.Wait()
   117  	metrics.TiKVTSFutureWaitDuration.Observe(time.Since(now).Seconds())
   118  	if err != nil {
   119  		return 0, errors.Trace(err)
   120  	}
   121  	ts := oracle.ComposeTS(physical, logical)
   122  	f.o.setLastTS(ts, f.txnScope)
   123  	return ts, nil
   124  }
   125  
   126  func (o *pdOracle) GetTimestampAsync(ctx context.Context, opt *oracle.Option) oracle.Future {
   127  	var ts pd.TSFuture
   128  	if opt.TxnScope == oracle.GlobalTxnScope || opt.TxnScope == "" {
   129  		ts = o.c.GetTSAsync(ctx)
   130  	} else {
   131  		ts = o.c.GetLocalTSAsync(ctx, opt.TxnScope)
   132  	}
   133  	return &tsFuture{ts, o, opt.TxnScope}
   134  }
   135  
   136  func (o *pdOracle) getTimestamp(ctx context.Context, txnScope string) (uint64, error) {
   137  	now := time.Now()
   138  	var (
   139  		physical, logical int64
   140  		err               error
   141  	)
   142  	if txnScope == oracle.GlobalTxnScope || txnScope == "" {
   143  		physical, logical, err = o.c.GetTS(ctx)
   144  	} else {
   145  		physical, logical, err = o.c.GetLocalTS(ctx, txnScope)
   146  	}
   147  	if err != nil {
   148  		return 0, errors.Trace(err)
   149  	}
   150  	dist := time.Since(now)
   151  	if dist > slowDist {
   152  		logutil.Logger(ctx).Warn("get timestamp too slow",
   153  			zap.Duration("cost time", dist))
   154  	}
   155  	return oracle.ComposeTS(physical, logical), nil
   156  }
   157  
   158  func (o *pdOracle) getArrivalTimestamp() uint64 {
   159  	return oracle.GoTimeToTS(time.Now())
   160  }
   161  
   162  func (o *pdOracle) setLastTS(ts uint64, txnScope string) {
   163  	if txnScope == "" {
   164  		txnScope = oracle.GlobalTxnScope
   165  	}
   166  	lastTSInterface, ok := o.lastTSMap.Load(txnScope)
   167  	if !ok {
   168  		lastTSInterface, _ = o.lastTSMap.LoadOrStore(txnScope, new(uint64))
   169  	}
   170  	lastTSPointer := lastTSInterface.(*uint64)
   171  	for {
   172  		lastTS := atomic.LoadUint64(lastTSPointer)
   173  		if ts <= lastTS {
   174  			return
   175  		}
   176  		if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
   177  			break
   178  		}
   179  	}
   180  	o.setLastArrivalTS(o.getArrivalTimestamp(), txnScope)
   181  }
   182  
   183  func (o *pdOracle) setLastArrivalTS(ts uint64, txnScope string) {
   184  	if txnScope == "" {
   185  		txnScope = oracle.GlobalTxnScope
   186  	}
   187  	lastTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
   188  	if !ok {
   189  		lastTSInterface, _ = o.lastArrivalTSMap.LoadOrStore(txnScope, new(uint64))
   190  	}
   191  	lastTSPointer := lastTSInterface.(*uint64)
   192  	for {
   193  		lastTS := atomic.LoadUint64(lastTSPointer)
   194  		if ts <= lastTS {
   195  			return
   196  		}
   197  		if atomic.CompareAndSwapUint64(lastTSPointer, lastTS, ts) {
   198  			return
   199  		}
   200  	}
   201  }
   202  
   203  func (o *pdOracle) getLastTS(txnScope string) (uint64, bool) {
   204  	if txnScope == "" {
   205  		txnScope = oracle.GlobalTxnScope
   206  	}
   207  	lastTSInterface, ok := o.lastTSMap.Load(txnScope)
   208  	if !ok {
   209  		return 0, false
   210  	}
   211  	return atomic.LoadUint64(lastTSInterface.(*uint64)), true
   212  }
   213  
   214  func (o *pdOracle) getLastArrivalTS(txnScope string) (uint64, bool) {
   215  	if txnScope == "" {
   216  		txnScope = oracle.GlobalTxnScope
   217  	}
   218  	lastArrivalTSInterface, ok := o.lastArrivalTSMap.Load(txnScope)
   219  	if !ok {
   220  		return 0, false
   221  	}
   222  	return atomic.LoadUint64(lastArrivalTSInterface.(*uint64)), true
   223  }
   224  
   225  func (o *pdOracle) updateTS(ctx context.Context, interval time.Duration) {
   226  	ticker := time.NewTicker(interval)
   227  	defer ticker.Stop()
   228  	for {
   229  		select {
   230  		case <-ticker.C:
   231  			// Update the timestamp for each txnScope
   232  			o.lastTSMap.Range(func(key, _ interface{}) bool {
   233  				txnScope := key.(string)
   234  				ts, err := o.getTimestamp(ctx, txnScope)
   235  				if err != nil {
   236  					logutil.Logger(ctx).Error("updateTS error", zap.String("txnScope", txnScope), zap.Error(err))
   237  					return true
   238  				}
   239  				o.setLastTS(ts, txnScope)
   240  				return true
   241  			})
   242  		case <-o.quit:
   243  			return
   244  		}
   245  	}
   246  }
   247  
   248  // UntilExpired implement oracle.Oracle interface.
   249  func (o *pdOracle) UntilExpired(lockTS uint64, TTL uint64, opt *oracle.Option) int64 {
   250  	lastTS, ok := o.getLastTS(opt.TxnScope)
   251  	if !ok {
   252  		return 0
   253  	}
   254  	return oracle.ExtractPhysical(lockTS) + int64(TTL) - oracle.ExtractPhysical(lastTS)
   255  }
   256  
// Close stops the background updateTS goroutine by closing the quit channel.
// It must be called at most once: closing an already-closed channel panics.
func (o *pdOracle) Close() {
	close(o.quit)
}
   260  
   261  // A future that resolves immediately to a low resolution timestamp.
   262  type lowResolutionTsFuture struct {
   263  	ts  uint64
   264  	err error
   265  }
   266  
   267  // Wait implements the oracle.Future interface.
   268  func (f lowResolutionTsFuture) Wait() (uint64, error) {
   269  	return f.ts, f.err
   270  }
   271  
   272  // GetLowResolutionTimestamp gets a new increasing time.
   273  func (o *pdOracle) GetLowResolutionTimestamp(ctx context.Context, opt *oracle.Option) (uint64, error) {
   274  	lastTS, ok := o.getLastTS(opt.TxnScope)
   275  	if !ok {
   276  		return 0, errors.Errorf("get low resolution timestamp fail, invalid txnScope = %s", opt.TxnScope)
   277  	}
   278  	return lastTS, nil
   279  }
   280  
   281  func (o *pdOracle) GetLowResolutionTimestampAsync(ctx context.Context, opt *oracle.Option) oracle.Future {
   282  	lastTS, ok := o.getLastTS(opt.TxnScope)
   283  	if !ok {
   284  		return lowResolutionTsFuture{
   285  			ts:  0,
   286  			err: errors.Errorf("get low resolution timestamp async fail, invalid txnScope = %s", opt.TxnScope),
   287  		}
   288  	}
   289  	return lowResolutionTsFuture{
   290  		ts:  lastTS,
   291  		err: nil,
   292  	}
   293  }
   294  
   295  func (o *pdOracle) getStaleTimestamp(txnScope string, prevSecond uint64) (uint64, error) {
   296  	ts, ok := o.getLastTS(txnScope)
   297  	if !ok {
   298  		return 0, errors.Errorf("get stale timestamp fail, txnScope: %s", txnScope)
   299  	}
   300  	arrivalTS, ok := o.getLastArrivalTS(txnScope)
   301  	if !ok {
   302  		return 0, errors.Errorf("get stale arrival timestamp fail, txnScope: %s", txnScope)
   303  	}
   304  	arrivalTime := oracle.GetTimeFromTS(arrivalTS)
   305  	physicalTime := oracle.GetTimeFromTS(ts)
   306  	if uint64(physicalTime.Unix()) <= prevSecond {
   307  		return 0, errors.Errorf("invalid prevSecond %v", prevSecond)
   308  	}
   309  
   310  	staleTime := physicalTime.Add(-arrivalTime.Sub(time.Now().Add(-time.Duration(prevSecond) * time.Second)))
   311  
   312  	return oracle.GoTimeToTS(staleTime), nil
   313  }
   314  
   315  // GetStaleTimestamp generate a TSO which represents for the TSO prevSecond secs ago.
   316  func (o *pdOracle) GetStaleTimestamp(ctx context.Context, txnScope string, prevSecond uint64) (ts uint64, err error) {
   317  	ts, err = o.getStaleTimestamp(txnScope, prevSecond)
   318  	if err != nil {
   319  		if !strings.HasPrefix(err.Error(), "invalid prevSecond") {
   320  			// If any error happened, we will try to fetch tso and set it as last ts.
   321  			_, tErr := o.GetTimestamp(ctx, &oracle.Option{TxnScope: txnScope})
   322  			if tErr != nil {
   323  				return 0, errors.Trace(tErr)
   324  			}
   325  		}
   326  		return 0, errors.Trace(err)
   327  	}
   328  	return ts, nil
   329  }