github.com/decred/dcrlnd@v0.7.6/kvdb/etcd/stm.go (about)

     1  //go:build kvdb_etcd
     2  // +build kvdb_etcd
     3  
     4  package etcd
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"math"
    10  	"strings"
    11  
    12  	"github.com/google/btree"
    13  	pb "go.etcd.io/etcd/api/v3/etcdserverpb"
    14  	v3 "go.etcd.io/etcd/client/v3"
    15  )
    16  
// CommitStats holds statistics about a commit attempt. It is passed to the
// callback installed with WithCommitStatsCallback.
type CommitStats struct {
	// Rset is the number of read set comparisons built for the commit.
	Rset    int
	// Wset is the number of write set comparisons built for the commit.
	Wset    int
	// Retries is the number of times the transaction was rolled back and
	// re-applied by the STM run loop.
	Retries int
}
    22  
// KV stores a key/value pair.
type KV struct {
	// key is the key of the pair.
	key string
	// val is the value stored under key.
	val string
}
    28  
// STM is an interface for software transactional memory.
// All calls that return an error will do so only if the STM is manually
// handled, and abort the apply closure otherwise. In both cases the returned
// error is a DatabaseError.
type STM interface {
	// Get returns the value for a key and inserts the key in the txn's read
	// set. Returns nil if there's no matching key, or the key is empty.
	Get(key string) ([]byte, error)

	// Put adds a value for a key to the txn's write set.
	Put(key, val string)

	// Del adds a delete operation for the key to the txn's write set.
	Del(key string)

	// First returns the first k/v that begins with prefix or nil if there's
	// no such k/v pair. If the key is found it is inserted to the txn's
	// read set. Returns nil if there's no match.
	First(prefix string) (*KV, error)

	// Last returns the last k/v that begins with prefix or nil if there's
	// no such k/v pair. If the key is found it is inserted to the txn's
	// read set. Returns nil if there's no match.
	Last(prefix string) (*KV, error)

	// Prev returns the previous k/v before key that begins with prefix or
	// nil if there's no such k/v. If the key is found it is inserted to the
	// read set. Returns nil if there's no match.
	Prev(prefix, key string) (*KV, error)

	// Next returns the next k/v after key that begins with prefix or nil
	// if there's no such k/v. If the key is found it is inserted to the
	// txn's read set. Returns nil if there's no match.
	Next(prefix, key string) (*KV, error)

	// Seek will return the k/v at key beginning with prefix. If the key
	// doesn't exist Seek will return the next k/v after key beginning with
	// prefix. If a matching k/v is found it is inserted to the txn's read
	// set. Returns nil if there's no match.
	Seek(prefix, key string) (*KV, error)

	// OnCommit calls the passed callback func upon commit.
	OnCommit(func())

	// Commit attempts to apply the txn's changes to the server.
	// Commit may return CommitError if the transaction is outdated and
	// needs to be retried.
	Commit() error

	// Rollback resets the read and write sets such that a subsequent commit
	// won't alter the database.
	Rollback()

	// Prefetch prefetches the passed keys and prefixes. For prefixes it'll
	// fetch the whole range.
	Prefetch(keys []string, prefix []string)
}
    85  
    86  // CommitError is used to check if there was an error
    87  // due to stale data in the transaction.
    88  type CommitError struct{}
    89  
    90  // Error returns a static string for CommitError for
    91  // debugging/logging purposes.
    92  func (e CommitError) Error() string {
    93  	return "commit failed"
    94  }
    95  
    96  // DatabaseError is used to wrap errors that are not
    97  // related to stale data in the transaction.
    98  type DatabaseError struct {
    99  	msg string
   100  	err error
   101  }
   102  
   103  // Unwrap returns the wrapped error in a DatabaseError.
   104  func (e *DatabaseError) Unwrap() error {
   105  	return e.err
   106  }
   107  
   108  // Error simply converts DatabaseError to a string that
   109  // includes both the message and the wrapped error.
   110  func (e DatabaseError) Error() string {
   111  	return fmt.Sprintf("etcd error: %v - %v", e.msg, e.err)
   112  }
   113  
// stmGet is the result of a read operation, a value and the mod revision of the
// key/value. Items with a zero revision are used in the read set to mark keys
// that were looked up but are not present.
type stmGet struct {
	// KV is the key/value pair that was read.
	KV
	// rev is the mod revision of the key/value, or zero for absent keys.
	rev int64
}
   120  
   121  // Less implements less operator for btree.BTree.
   122  func (c *stmGet) Less(than btree.Item) bool {
   123  	return c.key < than.(*stmGet).key
   124  }
   125  
// readSet stores all reads done in an STM.
type readSet struct {
	// tree stores the items in the read set, ordered by key.
	tree *btree.BTree

	// fullRanges stores the prefixes for which the whole range has been
	// fetched into the tree.
	fullRanges map[string]struct{}
}
   134  
// stmPut stores a value and an operation (put/delete).
type stmPut struct {
	// val is the value to write; it is empty for deletes.
	val string
	// op is the etcd operation executed for this entry on commit.
	op  v3.Op
}
   140  
// writeSet stores all writes done in an STM, keyed by the written key.
type writeSet map[string]stmPut
   143  
// stm implements repeatable-read software transactional memory
// over etcd.
type stm struct {
	// client is an etcd client handling all RPC communications
	// to the etcd instance/cluster.
	client *v3.Client

	// manual is set to true for manual transactions which don't
	// execute in the STM run loop.
	manual bool

	// txQueue is a lightweight contention manager, which is used to detect
	// transaction conflicts and reduce retries.
	txQueue *commitQueue

	// options stores optional settings passed by the user.
	options *STMOptions

	// rset holds read key values and revisions.
	rset *readSet

	// wset holds overwritten keys and their values.
	wset writeSet

	// getOpts are the opts used for gets. They are nil until the first
	// successful fetch or prefetch, after which they pin reads to the
	// snapshot revision and make them serializable.
	getOpts []v3.OpOption

	// revision stores the snapshot revision after first read.
	revision int64

	// onCommit gets called upon commit.
	onCommit func()

	// callCount tracks the number of times we called into etcd.
	callCount int
}
   180  
// STMOptions can be used to pass optional settings
// when an STM is created.
type STMOptions struct {
	// ctx holds an externally provided abort context.
	ctx                 context.Context
	// commitStatsCallback, when set, is called by the STM run loop once the
	// transaction finished, with a success flag and the commit statistics.
	commitStatsCallback func(bool, CommitStats)
}
   188  
// STMOptionFunc is a function that updates the passed STMOptions. Values of
// this type are passed as functional options when creating an STM.
type STMOptionFunc func(*STMOptions)
   191  
   192  // WithAbortContext specifies the context for permanently
   193  // aborting the transaction.
   194  func WithAbortContext(ctx context.Context) STMOptionFunc {
   195  	return func(so *STMOptions) {
   196  		so.ctx = ctx
   197  	}
   198  }
   199  
   200  func WithCommitStatsCallback(cb func(bool, CommitStats)) STMOptionFunc {
   201  	return func(so *STMOptions) {
   202  		so.commitStatsCallback = cb
   203  	}
   204  }
   205  
   206  // RunSTM runs the apply function by creating an STM using serializable snapshot
   207  // isolation, passing it to the apply and handling commit errors and retries.
   208  func RunSTM(cli *v3.Client, apply func(STM) error, txQueue *commitQueue,
   209  	so ...STMOptionFunc) (int, error) {
   210  
   211  	stm := makeSTM(cli, false, txQueue, so...)
   212  	err := runSTM(stm, apply)
   213  
   214  	return stm.callCount, err
   215  }
   216  
   217  // NewSTM creates a new STM instance, using serializable snapshot isolation.
   218  func NewSTM(cli *v3.Client, txQueue *commitQueue, so ...STMOptionFunc) STM {
   219  	return makeSTM(cli, true, txQueue, so...)
   220  }
   221  
   222  // makeSTM is the actual constructor of the stm. It first apply all passed
   223  // options then creates the stm object and resets it before returning.
   224  func makeSTM(cli *v3.Client, manual bool, txQueue *commitQueue,
   225  	so ...STMOptionFunc) *stm {
   226  
   227  	opts := &STMOptions{
   228  		ctx: cli.Ctx(),
   229  	}
   230  
   231  	// Apply all functional options.
   232  	for _, fo := range so {
   233  		fo(opts)
   234  	}
   235  
   236  	s := &stm{
   237  		client:  cli,
   238  		manual:  manual,
   239  		txQueue: txQueue,
   240  		options: opts,
   241  		rset:    newReadSet(),
   242  	}
   243  
   244  	// Reset read and write set.
   245  	s.rollback(true)
   246  
   247  	return s
   248  }
   249  
// runSTM implements the run loop of the STM, running the apply func, catching
// errors and handling commit. The loop will quit on every error except
// CommitError which is used to indicate a necessary retry.
//
// The apply func is first run once to build the read and write sets. The
// commit/retry loop is then handed to the commit queue, and runSTM waits until
// that loop finishes or the abort context is done.
//
// It returns the error of the initial apply if that fails, context.Canceled if
// the abort context is done before the queued loop signals completion, and
// otherwise the last error of the commit/retry loop (nil on success). The
// commit stats callback is only invoked when the loop ran to completion.
func runSTM(s *stm, apply func(STM) error) error {
	var (
		retries    int
		stats      CommitStats
		executeErr error
	)

	// done is closed when execute returns.
	done := make(chan struct{})

	// execute is the commit/retry loop that is run by the commit queue. It
	// reports its outcome through stats, retries and executeErr.
	execute := func() {
		defer close(done)

		for {
			select {
			// Check if the STM is aborted and break the retry loop
			// if it is.
			case <-s.options.ctx.Done():
				executeErr = fmt.Errorf("aborted")
				return

			default:
			}

			stats, executeErr = s.commit()

			// Re-apply only upon commit error (meaning the
			// keys were changed).
			if _, ok := executeErr.(CommitError); !ok {
				// Anything that's not a CommitError
				// aborts the transaction.
				return
			}

			// Rollback the write set before trying to re-apply.
			// Upon commit we retrieved the latest version of all
			// previously fetched keys and ranges so we don't need
			// to rollback the read set.
			s.rollback(false)
			retries++

			// Re-apply the transaction closure.
			if executeErr = apply(s); executeErr != nil {
				return
			}
		}
	}

	// Run the tx closure to construct the read and write sets.
	// Also we expect that if there are no conflicting transactions
	// in the queue, then we only run apply once.
	if preApplyErr := apply(s); preApplyErr != nil {
		return preApplyErr
	}

	// Make a copy of the read/write set keys here. The reason why we need
	// to do this is because subsequent applies may change (shrink) these
	// sets and so when we decrease reference counts in the commit queue in
	// done(...) we'd potentially miss removing references which would
	// result in queueing up transactions and contending DB access.
	// Copying these strings is cheap due to Go's immutable string which is
	// always a reference.
	rkeys := make([]string, s.rset.tree.Len())
	wkeys := make([]string, len(s.wset))

	i := 0
	s.rset.tree.Ascend(func(item btree.Item) bool {
		rkeys[i] = item.(*stmGet).key
		i++

		return true
	})

	i = 0
	for key := range s.wset {
		wkeys[i] = key
		i++
	}

	// Queue up the transaction for execution.
	s.txQueue.Add(execute, rkeys, wkeys)

	// Wait for the transaction to execute, or break if aborted. Note that
	// on abort we return without waiting for execute to finish.
	select {
	case <-done:
	case <-s.options.ctx.Done():
		return context.Canceled
	}

	// Report the outcome to the optional stats callback.
	if s.options.commitStatsCallback != nil {
		stats.Retries = retries
		s.options.commitStatsCallback(executeErr == nil, stats)
	}

	return executeErr
}
   348  
   349  func newReadSet() *readSet {
   350  	return &readSet{
   351  		tree:       btree.New(5),
   352  		fullRanges: make(map[string]struct{}),
   353  	}
   354  }
   355  
   356  // add inserts key/values to to read set.
   357  func (rs *readSet) add(responses []*pb.ResponseOp) {
   358  	for _, resp := range responses {
   359  		getResp := resp.GetResponseRange()
   360  		for _, kv := range getResp.Kvs {
   361  			rs.addItem(
   362  				string(kv.Key), string(kv.Value), kv.ModRevision,
   363  			)
   364  		}
   365  	}
   366  }
   367  
   368  // addFullRange adds all full ranges to the read set.
   369  func (rs *readSet) addFullRange(prefixes []string, responses []*pb.ResponseOp) {
   370  	for i, resp := range responses {
   371  		getResp := resp.GetResponseRange()
   372  		for _, kv := range getResp.Kvs {
   373  			rs.addItem(
   374  				string(kv.Key), string(kv.Value), kv.ModRevision,
   375  			)
   376  		}
   377  
   378  		rs.fullRanges[prefixes[i]] = struct{}{}
   379  	}
   380  }
   381  
   382  // presetItem presets a key to zero revision if not already present in the read
   383  // set.
   384  func (rs *readSet) presetItem(key string) {
   385  	item := &stmGet{
   386  		KV: KV{
   387  			key: key,
   388  		},
   389  		rev: 0,
   390  	}
   391  
   392  	if !rs.tree.Has(item) {
   393  		rs.tree.ReplaceOrInsert(item)
   394  	}
   395  }
   396  
   397  // addItem adds a single new key/value to the read set (if not already present).
   398  func (rs *readSet) addItem(key, val string, modRevision int64) {
   399  	item := &stmGet{
   400  		KV: KV{
   401  			key: key,
   402  			val: val,
   403  		},
   404  		rev: modRevision,
   405  	}
   406  
   407  	rs.tree.ReplaceOrInsert(item)
   408  }
   409  
   410  // hasFullRange checks if the read set has a full range prefetched.
   411  func (rs *readSet) hasFullRange(prefix string) bool {
   412  	_, ok := rs.fullRanges[prefix]
   413  	return ok
   414  }
   415  
   416  // next returns the pre-fetched next value of the prefix. If matchKey is true,
   417  // it'll simply return the key/value that matches the passed key.
   418  func (rs *readSet) next(prefix, key string, matchKey bool) (*stmGet, bool) {
   419  	pivot := &stmGet{
   420  		KV: KV{
   421  			key: key,
   422  		},
   423  	}
   424  
   425  	var result *stmGet
   426  	rs.tree.AscendGreaterOrEqual(
   427  		pivot,
   428  		func(item btree.Item) bool {
   429  			next := item.(*stmGet)
   430  			if (!matchKey && next.key == key) || next.rev == 0 {
   431  				return true
   432  			}
   433  
   434  			if strings.HasPrefix(next.key, prefix) {
   435  				result = next
   436  			}
   437  
   438  			return false
   439  		},
   440  	)
   441  
   442  	return result, result != nil
   443  }
   444  
   445  // prev returns the pre-fetched prev key/value of the prefix from key.
   446  func (rs *readSet) prev(prefix, key string) (*stmGet, bool) {
   447  	pivot := &stmGet{
   448  		KV: KV{
   449  			key: key,
   450  		},
   451  	}
   452  
   453  	var result *stmGet
   454  
   455  	rs.tree.DescendLessOrEqual(
   456  		pivot, func(item btree.Item) bool {
   457  			prev := item.(*stmGet)
   458  			if prev.key == key || prev.rev == 0 {
   459  				return true
   460  			}
   461  
   462  			if strings.HasPrefix(prev.key, prefix) {
   463  				result = prev
   464  			}
   465  
   466  			return false
   467  		},
   468  	)
   469  
   470  	return result, result != nil
   471  }
   472  
   473  // last returns the last key/value of the passed range (if prefetched).
   474  func (rs *readSet) last(prefix string) (*stmGet, bool) {
   475  	// We create an artificial key here that is just one step away from the
   476  	// prefix. This way when we try to get the first item with our prefix
   477  	// before this newly crafted key we'll make sure it's the last element
   478  	// of our range.
   479  	key := []byte(prefix)
   480  	key[len(key)-1] += 1
   481  
   482  	return rs.prev(prefix, string(key))
   483  }
   484  
   485  // clear completely clears the readset.
   486  func (rs *readSet) clear() {
   487  	rs.tree.Clear(false)
   488  	rs.fullRanges = make(map[string]struct{})
   489  }
   490  
   491  // getItem returns the matching key/value from the readset.
   492  func (rs *readSet) getItem(key string) (*stmGet, bool) {
   493  	pivot := &stmGet{
   494  		KV: KV{
   495  			key: key,
   496  		},
   497  		rev: 0,
   498  	}
   499  	item := rs.tree.Get(pivot)
   500  	if item != nil {
   501  		return item.(*stmGet), true
   502  	}
   503  
   504  	// It's possible that although this key isn't in the read set, we
   505  	// fetched a full range the key is prefixed with. In this case we'll
   506  	// insert the key with zero revision.
   507  	for prefix := range rs.fullRanges {
   508  		if strings.HasPrefix(key, prefix) {
   509  			rs.tree.ReplaceOrInsert(pivot)
   510  			return pivot, true
   511  		}
   512  	}
   513  
   514  	return nil, false
   515  }
   516  
   517  // prefetchSet is a helper to create an op slice of all OpGet's that represent
   518  // fetched keys appended with a slice of all OpGet's representing all prefetched
   519  // full ranges.
   520  func (rs *readSet) prefetchSet() []v3.Op {
   521  	ops := make([]v3.Op, 0, rs.tree.Len())
   522  
   523  	rs.tree.Ascend(func(item btree.Item) bool {
   524  		key := item.(*stmGet).key
   525  		for prefix := range rs.fullRanges {
   526  			// Do not add the key if it has been prefetched in a
   527  			// full range.
   528  			if strings.HasPrefix(key, prefix) {
   529  				return true
   530  			}
   531  		}
   532  
   533  		ops = append(ops, v3.OpGet(key))
   534  		return true
   535  	})
   536  
   537  	for prefix := range rs.fullRanges {
   538  		ops = append(ops, v3.OpGet(prefix, v3.WithPrefix()))
   539  	}
   540  
   541  	return ops
   542  }
   543  
   544  // getFullRanges returns all prefixes that we prefetched.
   545  func (rs *readSet) getFullRanges() []string {
   546  	prefixes := make([]string, 0, len(rs.fullRanges))
   547  
   548  	for prefix := range rs.fullRanges {
   549  		prefixes = append(prefixes, prefix)
   550  	}
   551  
   552  	return prefixes
   553  }
   554  
   555  // cmps returns a compare list which will serve as a precondition testing that
   556  // the values in the read set didn't change.
   557  func (rs *readSet) cmps() []v3.Cmp {
   558  	cmps := make([]v3.Cmp, 0, rs.tree.Len())
   559  
   560  	rs.tree.Ascend(func(item btree.Item) bool {
   561  		get := item.(*stmGet)
   562  		cmps = append(
   563  			cmps, v3.Compare(v3.ModRevision(get.key), "=", get.rev),
   564  		)
   565  
   566  		return true
   567  	})
   568  
   569  	return cmps
   570  }
   571  
   572  // cmps returns a cmp list testing no writes have happened past rev.
   573  func (ws writeSet) cmps(rev int64) []v3.Cmp {
   574  	cmps := make([]v3.Cmp, 0, len(ws))
   575  	for key := range ws {
   576  		cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev))
   577  	}
   578  
   579  	return cmps
   580  }
   581  
   582  // puts is the list of ops for all pending writes.
   583  func (ws writeSet) puts() []v3.Op {
   584  	puts := make([]v3.Op, 0, len(ws))
   585  	for _, v := range ws {
   586  		puts = append(puts, v.op)
   587  	}
   588  
   589  	return puts
   590  }
   591  
   592  // fetch is a helper to fetch key/value given options. If a value is returned
   593  // then fetch will try to fix the STM's snapshot revision (if not already set).
   594  // We'll also cache the returned key/value in the read set.
   595  func (s *stm) fetch(key string, opts ...v3.OpOption) ([]KV, error) {
   596  	s.callCount++
   597  	resp, err := s.client.Get(
   598  		s.options.ctx, key, append(opts, s.getOpts...)...,
   599  	)
   600  	if err != nil {
   601  		return nil, DatabaseError{
   602  			msg: "stm.fetch() failed",
   603  			err: err,
   604  		}
   605  	}
   606  
   607  	// Set revision and serializable options upon first fetch
   608  	// for any subsequent fetches.
   609  	if s.getOpts == nil {
   610  		s.revision = resp.Header.Revision
   611  		s.getOpts = []v3.OpOption{
   612  			v3.WithRev(s.revision),
   613  			v3.WithSerializable(),
   614  		}
   615  	}
   616  
   617  	if len(resp.Kvs) == 0 {
   618  		// Add assertion to the read set which will extend our commit
   619  		// constraint such that the commit will fail if the key is
   620  		// present in the database.
   621  		s.rset.addItem(key, "", 0)
   622  	}
   623  
   624  	var result []KV
   625  
   626  	// Fill the read set with key/values returned.
   627  	for _, kv := range resp.Kvs {
   628  		key := string(kv.Key)
   629  		val := string(kv.Value)
   630  
   631  		// Add to read set.
   632  		s.rset.addItem(key, val, kv.ModRevision)
   633  
   634  		result = append(result, KV{key, val})
   635  	}
   636  
   637  	return result, nil
   638  }
   639  
   640  // Get returns the value for key. If there's no such
   641  // key/value in the database or the passed key is empty
   642  // Get will return nil.
   643  func (s *stm) Get(key string) ([]byte, error) {
   644  	if key == "" {
   645  		return nil, nil
   646  	}
   647  
   648  	// Return freshly written value if present.
   649  	if put, ok := s.wset[key]; ok {
   650  		if put.op.IsDelete() {
   651  			return nil, nil
   652  		}
   653  
   654  		return []byte(put.val), nil
   655  	}
   656  
   657  	// Return value if alread in read set.
   658  	if getValue, ok := s.rset.getItem(key); ok {
   659  		// Return the value if the rset contains an existing key.
   660  		if getValue.rev != 0 {
   661  			return []byte(getValue.val), nil
   662  		} else {
   663  			return nil, nil
   664  		}
   665  	}
   666  
   667  	// Fetch and return value.
   668  	kvs, err := s.fetch(key)
   669  	if err != nil {
   670  		return nil, err
   671  	}
   672  
   673  	if len(kvs) > 0 {
   674  		return []byte(kvs[0].val), nil
   675  	}
   676  
   677  	// Return empty result if key not in DB.
   678  	return nil, nil
   679  }
   680  
   681  // First returns the first key/value matching prefix. If there's no key starting
   682  // with prefix, Last will return nil.
   683  func (s *stm) First(prefix string) (*KV, error) {
   684  	return s.next(prefix, prefix, true)
   685  }
   686  
   687  // Last returns the last key/value with prefix. If there's no key starting with
   688  // prefix, Last will return nil.
   689  func (s *stm) Last(prefix string) (*KV, error) {
   690  	var (
   691  		kv    KV
   692  		found bool
   693  	)
   694  
   695  	if s.rset.hasFullRange(prefix) {
   696  		if item, ok := s.rset.last(prefix); ok {
   697  			kv = item.KV
   698  			found = true
   699  		}
   700  	} else {
   701  		// As we don't know the full range, fetch the last
   702  		// key/value with this prefix first.
   703  		resp, err := s.fetch(prefix, v3.WithLastKey()...)
   704  		if err != nil {
   705  			return nil, err
   706  		}
   707  
   708  		if len(resp) > 0 {
   709  			kv = resp[0]
   710  			found = true
   711  		}
   712  	}
   713  
   714  	// Now make sure there's nothing in the write set
   715  	// that is a better match, meaning it has the same
   716  	// prefix but is greater or equal than the current
   717  	// best candidate. Note that this is not efficient
   718  	// when the write set is large!
   719  	for k, put := range s.wset {
   720  		if put.op.IsDelete() {
   721  			continue
   722  		}
   723  
   724  		if strings.HasPrefix(k, prefix) && k >= kv.key {
   725  			kv.key = k
   726  			kv.val = put.val
   727  			found = true
   728  		}
   729  	}
   730  
   731  	if found {
   732  		return &kv, nil
   733  	}
   734  
   735  	return nil, nil
   736  }
   737  
// Prev returns the prior key/value before key (with prefix). If there's no such
// key Prev will return nil.
//
// The lookup first finds the closest stored key before startKey that isn't
// pending deletion, using the read set if the full range of prefix has been
// prefetched and etcd otherwise. It then checks the write set for pending puts
// that are an even closer match.
func (s *stm) Prev(prefix, startKey string) (*KV, error) {
	var kv, result KV

	// fetchKey is the exclusive upper bound of the current lookup. It is
	// moved down whenever the found key turns out to be deleted.
	fetchKey := startKey
	matchFound := false

	for {
		if s.rset.hasFullRange(prefix) {
			if item, ok := s.rset.prev(prefix, fetchKey); ok {
				kv = item.KV
			} else {
				break
			}
		} else {

			// Ask etcd to retrieve one key that is a
			// match in descending order from the passed key.
			opts := []v3.OpOption{
				v3.WithRange(fetchKey),
				v3.WithSort(v3.SortByKey, v3.SortDescend),
				v3.WithLimit(1),
			}

			kvs, err := s.fetch(prefix, opts...)
			if err != nil {
				return nil, err
			}

			if len(kvs) == 0 {
				break
			}

			kv = kvs[0]
		}

		// WithRange and WithPrefix can't be used
		// together, so check prefix here. If the
		// returned key no longer has the prefix,
		// then break out.
		if !strings.HasPrefix(kv.key, prefix) {
			break
		}

		// Fetch the prior key if this is deleted.
		if put, ok := s.wset[kv.key]; ok && put.op.IsDelete() {
			fetchKey = kv.key
			continue
		}

		result = kv
		matchFound = true

		break
	}

	// Closure holding all checks to find a possibly
	// better match: the key must have the prefix, sort
	// before startKey and not sort before the best
	// match found so far.
	matches := func(key string) bool {
		if !strings.HasPrefix(key, prefix) {
			return false
		}

		if !matchFound {
			return key < startKey
		}

		// matchFound == true
		return result.key <= key && key < startKey
	}

	// Now go through the write set and check
	// if there's an even better match.
	for k, put := range s.wset {
		if !put.op.IsDelete() && matches(k) {
			result.key = k
			result.val = put.val
			matchFound = true
		}
	}

	if !matchFound {
		return nil, nil
	}

	return &result, nil
}
   826  
   827  // Next returns the next key/value after key (with prefix). If there's no such
   828  // key Next will return nil.
   829  func (s *stm) Next(prefix string, key string) (*KV, error) {
   830  	return s.next(prefix, key, false)
   831  }
   832  
   833  // Seek "seeks" to the key (with prefix). If the key doesn't exists it'll get
   834  // the next key with the same prefix. If no key fills this criteria, Seek will
   835  // return nil.
   836  func (s *stm) Seek(prefix, key string) (*KV, error) {
   837  	return s.next(prefix, key, true)
   838  }
   839  
// next will try to retrieve the next match that has prefix and starts with the
// passed startKey. If includeStartKey is set to true, it'll return the value
// of startKey (essentially implementing seek).
//
// The lookup first finds the closest stored key from startKey on that isn't
// pending deletion, using the read set if the full range of prefix has been
// prefetched and etcd otherwise. It then checks the write set for pending puts
// that are an even closer match. Returns nil if there's no match.
func (s *stm) next(prefix, startKey string, includeStartKey bool) (*KV, error) {
	var kv, result KV

	// fetchKey is the lower bound of the current lookup. It is moved up
	// whenever the found key turns out to be deleted.
	fetchKey := startKey
	// firstFetch is true until the first lookup; only that lookup may
	// return the start key itself.
	firstFetch := true
	matchFound := false

	for {
		if s.rset.hasFullRange(prefix) {
			matchKey := includeStartKey && firstFetch
			firstFetch = false
			if item, ok := s.rset.next(
				prefix, fetchKey, matchKey,
			); ok {
				kv = item.KV
			} else {
				break
			}
		} else {
			// Ask etcd to retrieve one key that is a
			// match in ascending order from the passed key.
			opts := []v3.OpOption{
				v3.WithFromKey(),
				v3.WithSort(v3.SortByKey, v3.SortAscend),
				v3.WithLimit(1),
			}

			// By default we include the start key too
			// if it is a full match.
			if includeStartKey && firstFetch {
				firstFetch = false
			} else {
				// If we'd like to retrieve the first key
				// after the start key.
				fetchKey += "\x00"
			}

			kvs, err := s.fetch(fetchKey, opts...)
			if err != nil {
				return nil, err
			}

			if len(kvs) == 0 {
				break
			}

			kv = kvs[0]

			// The fetch is done from the key on rather
			// than by prefix, so check prefix here. If the
			// returned key no longer has the prefix,
			// then break the fetch loop.
			if !strings.HasPrefix(kv.key, prefix) {
				break
			}
		}

		// Move on to fetch starting with the next
		// key if this one is marked deleted.
		if put, ok := s.wset[kv.key]; ok && put.op.IsDelete() {
			fetchKey = kv.key
			continue
		}

		result = kv
		matchFound = true

		break
	}

	// Closure holding all checks to find a possibly
	// better match: the key must have the prefix, sort
	// after startKey (or equal it when includeStartKey
	// is set) and not sort after the best match found
	// so far.
	matches := func(k string) bool {
		if !strings.HasPrefix(k, prefix) {
			return false
		}

		if includeStartKey && !matchFound {
			return startKey <= k
		}

		if !includeStartKey && !matchFound {
			return startKey < k
		}

		if includeStartKey && matchFound {
			return startKey <= k && k <= result.key
		}

		// !includeStartKey && matchFound.
		return startKey < k && k <= result.key
	}

	// Now go through the write set and check
	// if there's an even better match.
	for k, put := range s.wset {
		if !put.op.IsDelete() && matches(k) {
			result.key = k
			result.val = put.val
			matchFound = true
		}
	}

	if !matchFound {
		return nil, nil
	}

	return &result, nil
}
   952  
   953  // Put sets the value of the passed key. The actual put will happen upon commit.
   954  func (s *stm) Put(key, val string) {
   955  	s.wset[key] = stmPut{
   956  		val: val,
   957  		op:  v3.OpPut(key, val),
   958  	}
   959  }
   960  
   961  // Del marks a key as deleted. The actual delete will happen upon commit.
   962  func (s *stm) Del(key string) {
   963  	s.wset[key] = stmPut{
   964  		val: "",
   965  		op:  v3.OpDelete(key),
   966  	}
   967  }
   968  
// OnCommit sets the callback that is called upon committing the STM
// transaction. Setting a new callback replaces the previously set one.
func (s *stm) OnCommit(cb func()) {
	s.onCommit = cb
}
   974  
   975  // Prefetch will prefetch the passed keys and prefixes in one transaction.
   976  // Keys and prefixes that we already have will be skipped.
   977  func (s *stm) Prefetch(keys []string, prefixes []string) {
   978  	fetchKeys := make([]string, 0, len(keys))
   979  	for _, key := range keys {
   980  		if _, ok := s.rset.getItem(key); !ok {
   981  			fetchKeys = append(fetchKeys, key)
   982  		}
   983  	}
   984  
   985  	fetchPrefixes := make([]string, 0, len(prefixes))
   986  	for _, prefix := range prefixes {
   987  		if s.rset.hasFullRange(prefix) {
   988  			continue
   989  		}
   990  		fetchPrefixes = append(fetchPrefixes, prefix)
   991  	}
   992  
   993  	if len(fetchKeys) == 0 && len(fetchPrefixes) == 0 {
   994  		return
   995  	}
   996  
   997  	prefixOpts := append(
   998  		[]v3.OpOption{v3.WithPrefix()}, s.getOpts...,
   999  	)
  1000  
  1001  	txn := s.client.Txn(s.options.ctx)
  1002  	ops := make([]v3.Op, 0, len(fetchKeys)+len(fetchPrefixes))
  1003  
  1004  	for _, key := range fetchKeys {
  1005  		ops = append(ops, v3.OpGet(key, s.getOpts...))
  1006  	}
  1007  	for _, key := range fetchPrefixes {
  1008  		ops = append(ops, v3.OpGet(key, prefixOpts...))
  1009  	}
  1010  
  1011  	txn.Then(ops...)
  1012  	txnresp, err := txn.Commit()
  1013  	s.callCount++
  1014  
  1015  	if err != nil {
  1016  		return
  1017  	}
  1018  
  1019  	// Set revision and serializable options upon first fetch for any
  1020  	// subsequent fetches.
  1021  	if s.getOpts == nil {
  1022  		s.revision = txnresp.Header.Revision
  1023  		s.getOpts = []v3.OpOption{
  1024  			v3.WithRev(s.revision),
  1025  			v3.WithSerializable(),
  1026  		}
  1027  	}
  1028  
  1029  	// Preset keys to "not-present" (revision set to zero).
  1030  	for _, key := range fetchKeys {
  1031  		s.rset.presetItem(key)
  1032  	}
  1033  
  1034  	// Set prefetched keys.
  1035  	s.rset.add(txnresp.Responses[:len(fetchKeys)])
  1036  
  1037  	// Set prefetched ranges.
  1038  	s.rset.addFullRange(fetchPrefixes, txnresp.Responses[len(fetchKeys):])
  1039  }
  1040  
  1041  // commit builds the final transaction and tries to execute it. If commit fails
  1042  // because the keys have changed return a CommitError, otherwise return a
  1043  // DatabaseError.
  1044  func (s *stm) commit() (CommitStats, error) {
  1045  	rset := s.rset.cmps()
  1046  	wset := s.wset.cmps(s.revision + 1)
  1047  
  1048  	stats := CommitStats{
  1049  		Rset: len(rset),
  1050  		Wset: len(wset),
  1051  	}
  1052  
  1053  	// Create the compare set.
  1054  	cmps := append(rset, wset...)
  1055  	// Create a transaction with the optional abort context.
  1056  	txn := s.client.Txn(s.options.ctx)
  1057  
  1058  	// If the compare set holds, try executing the puts.
  1059  	txn = txn.If(cmps...)
  1060  	txn = txn.Then(s.wset.puts()...)
  1061  
  1062  	// Prefetch keys and ranges in case of conflict to save as many
  1063  	// round-trips as possible.
  1064  	txn = txn.Else(s.rset.prefetchSet()...)
  1065  
  1066  	s.callCount++
  1067  	txnresp, err := txn.Commit()
  1068  	if err != nil {
  1069  		return stats, DatabaseError{
  1070  			msg: "stm.Commit() failed",
  1071  			err: err,
  1072  		}
  1073  	}
  1074  
  1075  	// Call the commit callback if the transaction was successful.
  1076  	if txnresp.Succeeded {
  1077  		if s.onCommit != nil {
  1078  			s.onCommit()
  1079  		}
  1080  
  1081  		return stats, nil
  1082  	}
  1083  
  1084  	// Determine where our fetched full ranges begin in the response.
  1085  	prefixes := s.rset.getFullRanges()
  1086  	firstPrefixResp := len(txnresp.Responses) - len(prefixes)
  1087  
  1088  	// Clear reload and preload it with the prefetched keys and ranges.
  1089  	s.rset.clear()
  1090  	s.rset.add(txnresp.Responses[:firstPrefixResp])
  1091  	s.rset.addFullRange(prefixes, txnresp.Responses[firstPrefixResp:])
  1092  
  1093  	// Set our revision boundary.
  1094  	s.revision = txnresp.Header.Revision
  1095  	s.getOpts = []v3.OpOption{
  1096  		v3.WithRev(s.revision),
  1097  		v3.WithSerializable(),
  1098  	}
  1099  
  1100  	// Return CommitError indicating that the transaction can be retried.
  1101  	return stats, CommitError{}
  1102  }
  1103  
  1104  // Commit simply calls commit and the commit stats callback if set.
  1105  func (s *stm) Commit() error {
  1106  	stats, err := s.commit()
  1107  
  1108  	if s.options.commitStatsCallback != nil {
  1109  		s.options.commitStatsCallback(err == nil, stats)
  1110  	}
  1111  
  1112  	return err
  1113  }
  1114  
  1115  // Rollback resets the STM. This is useful for uncommitted transaction rollback
  1116  // and also used in the STM main loop to reset state if commit fails.
  1117  func (s *stm) Rollback() {
  1118  	s.rollback(true)
  1119  }
  1120  
  1121  // rollback will reset the read and write sets. If clearReadSet is false we'll
  1122  // only reset the the write set.
  1123  func (s *stm) rollback(clearReadSet bool) {
  1124  	if clearReadSet {
  1125  		s.rset.clear()
  1126  		s.revision = math.MaxInt64 - 1
  1127  		s.getOpts = nil
  1128  	}
  1129  
  1130  	s.wset = make(map[string]stmPut)
  1131  }