github.com/pachyderm/pachyderm@v1.13.4/src/server/pkg/collection/transaction.go (about)

     1  package collection
     2  
     3  // Copyright 2016 The etcd Authors
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  // We copy this code from etcd because the etcd implementation of STM does
    18  // not have the DelAll method, which we need.
    19  
    20  import (
    21  	"bytes"
    22  	"sort"
    23  	"strings"
    24  	"sync"
    25  
    26  	v3 "github.com/coreos/etcd/clientv3"
    27  	"github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes"
    28  	"github.com/pachyderm/pachyderm/src/client/pkg/errors"
    29  	"github.com/pachyderm/pachyderm/src/client/pkg/tracing"
    30  	"golang.org/x/net/context"
    31  )
    32  
// STM is an interface for software transactional memory over etcd.
//
// Within one attempt, reads are recorded in a read set and writes (Put, Del,
// DelAll) are only buffered. commit sends the buffered writes to etcd, guarded
// by comparisons on the read set. runSTM re-runs the whole attempt when that
// guard fails.
type STM interface {
	// Get returns the value for a key, inserting the key in the txn's read
	// set when it has to be read from etcd. A key that does not exist yields
	// ErrNotFound. Errors from etcd itself are not returned: they panic with
	// stmError, which runSTM recovers and returns as the transaction's error.
	Get(key string) (string, error)
	// Put adds a value for a key to the write set. If ttl > 0 the key is
	// attached to a lease with that TTL; writes in one attempt that use the
	// same TTL share a lease. ptr is recorded for IsSafePut. It returns an
	// error if a new lease cannot be granted.
	Put(key, val string, ttl int64, ptr uintptr) error
	// Rev returns the ModRevision of a key as read from etcd through the read
	// set, or 0 if the key does not exist.
	Rev(key string) int64
	// Del adds a delete of a key to the write set.
	Del(key string)
	// TTL returns the remaining time to live for 'key', or 0 if 'key' has no
	// TTL. For a key with a pending Put it returns the TTL that Put used.
	TTL(key string) (int64, error)
	// DelAll deletes all keys with the given prefix
	// Note that the current implementation of DelAll is incomplete.
	// To use DelAll safely, do not issue any Get/Put operations after
	// DelAll is called.
	DelAll(key string)
	// Context returns the context the transaction was started with.
	Context() context.Context
	// SetSafePutCheck records ptr against key's pending write, if there is
	// one. IsSafePut compares against it later.
	SetSafePutCheck(key string, ptr uintptr)
	// IsSafePut reports whether a put of key on behalf of ptr is safe. It
	// returns false only if key has a pending write whose recorded pointer
	// is non-zero and differs from ptr.
	IsSafePut(key string, ptr uintptr) bool

	// commit attempts to apply the txn's changes to the server. It returns
	// nil when the read-set guard failed, and the caller retries.
	commit() *v3.TxnResponse
	// reset clears the per-attempt read and write state before a (re)try.
	reset()
	// fetch reads a key through the read-set cache.
	fetch(key string) *v3.GetResponse
}
    62  
    63  // stmError safely passes STM errors through panic to the STM error channel.
    64  type stmError struct{ err error }
    65  
    66  // NewSTM intiates a new STM operation. It uses a serializable model.
    67  func NewSTM(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) {
    68  	return newSTMSerializable(ctx, c, apply, false)
    69  }
    70  
    71  // NewDryrunSTM intiates a new STM operation, but the final commit is skipped.
    72  // It uses a serializable model.
    73  func NewDryrunSTM(ctx context.Context, c *v3.Client, apply func(STM) error) error {
    74  	_, err := newSTMSerializable(ctx, c, apply, true)
    75  	return err
    76  }
    77  
    78  // newSTMSerializable initiates a new serialized transaction; reads within the
    79  // same transaction attempt to return data from the revision of the first read.
    80  func newSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error, dryrun bool) (*v3.TxnResponse, error) {
    81  	s := &stmSerializable{
    82  		stm:      stm{client: c, ctx: ctx},
    83  		prefetch: make(map[string]*v3.GetResponse),
    84  	}
    85  	return runSTM(s, apply, dryrun)
    86  }
    87  
    88  type stmResponse struct {
    89  	resp *v3.TxnResponse
    90  	err  error
    91  }
    92  
    93  func runSTM(s STM, apply func(STM) error, dryrun bool) (*v3.TxnResponse, error) {
    94  	outc := make(chan stmResponse, 1)
    95  	go func() {
    96  		defer func() {
    97  			if r := recover(); r != nil {
    98  				e, ok := r.(stmError)
    99  				if !ok {
   100  					// client apply panicked
   101  					panic(r)
   102  				}
   103  				outc <- stmResponse{nil, e.err}
   104  			}
   105  		}()
   106  		var out stmResponse
   107  		for {
   108  			s.reset()
   109  			if out.err = apply(s); out.err != nil {
   110  				break
   111  			}
   112  			if dryrun {
   113  				break
   114  			} else if out.resp = s.commit(); out.resp != nil {
   115  				break
   116  			}
   117  		}
   118  		outc <- out
   119  	}()
   120  	r := <-outc
   121  	return r.resp, r.err
   122  }
   123  
// stm implements repeatable-read software transactional memory over etcd
//
// Get, Put, Del, DelAll, Rev, TTL and the safe-put checks take the embedded
// Mutex before touching the per-attempt fields below. reset reinitializes
// rset, wset, deletedPrefixes, ttlset and newLeases at the start of every
// attempt. client, ctx and getOpts are not touched by reset.
type stm struct {
	// client performs every etcd call made by the STM. Those calls use
	// contexts obtained from ctx via tracing.AddSpanToAnyExisting.
	client *v3.Client
	ctx    context.Context
	// rset holds read key values and revisions
	rset map[string]*v3.GetResponse
	// wset holds overwritten keys and their values
	wset map[string]stmPut
	// deletedPrefixes holds the set of prefixes that have been deleted
	deletedPrefixes []string
	// getOpts are the opts used for gets. Includes revision of first read for
	// stmSerializable
	getOpts []v3.OpOption
	// ttlset is a cache from key to lease TTL. It's similar to rset in that it
	// caches leases that have already been read, but each may contain keys not in
	// the other (ttlset in particular caches the TTL of all keys associated with
	// a lease after reading that lease, even if the other keys haven't been read)
	ttlset map[string]int64
	// newLeases is a map from TTL to lease ID; it caches new leases used for this
	// write. We de-dupe leases by TTL (values written with the same TTL get the
	// same lease) so that kvs in a collection and its indexes all share a lease.
	// It's similar to wset for TTLs.
	newLeases map[int64]v3.LeaseID
	// mutex for concurrent access
	sync.Mutex
}
   150  
// stmPut is one entry of the write set: a pending put (created by Put) or a
// pending delete (created by Del). Fields:
//   - val: the buffered value ("" for a delete)
//   - ttl: the TTL passed to Put (0 when none, and always 0 for a delete)
//   - op: the etcd operation sent at commit time
//   - safePutPtr: the pointer recorded by Put or SetSafePutCheck; 0 means
//     IsSafePut does not restrict writes to this key
type stmPut struct {
	val        string
	ttl        int64
	op         v3.Op
	safePutPtr uintptr
}
   157  
   158  func (s *stm) Context() context.Context {
   159  	return s.ctx
   160  }
   161  
   162  func (s *stm) Get(key string) (string, error) {
   163  	s.Lock()
   164  	defer s.Unlock()
   165  	if wv, ok := s.wset[key]; ok {
   166  		return wv.val, nil
   167  	}
   168  	if s.isKeyRangeDeleted(key) {
   169  		return "", ErrNotFound{Key: key}
   170  	}
   171  	return respToValue(key, s.fetch(key))
   172  }
   173  
   174  func (s *stm) SetSafePutCheck(key string, ptr uintptr) {
   175  	s.Lock()
   176  	defer s.Unlock()
   177  	if wv, ok := s.wset[key]; ok {
   178  		wv.safePutPtr = ptr
   179  		s.wset[key] = wv
   180  	}
   181  }
   182  
   183  func (s *stm) IsSafePut(key string, ptr uintptr) bool {
   184  	s.Lock()
   185  	defer s.Unlock()
   186  	if _, ok := s.wset[key]; ok && s.wset[key].safePutPtr != 0 && ptr != s.wset[key].safePutPtr {
   187  		return false
   188  	}
   189  	return true
   190  }
   191  
   192  func (s *stm) isKeyRangeDeleted(key string) bool {
   193  	for _, prefix := range s.deletedPrefixes {
   194  		if strings.HasPrefix(key, prefix) {
   195  			return true
   196  		}
   197  	}
   198  	return false
   199  }
   200  
   201  func (s *stm) Put(key, val string, ttl int64, ptr uintptr) error {
   202  	s.Lock()
   203  	defer s.Unlock()
   204  	var options []v3.OpOption
   205  	if ttl > 0 {
   206  		lease, ok := s.newLeases[ttl]
   207  		if !ok {
   208  			span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd/GrantLease")
   209  			defer tracing.FinishAnySpan(span)
   210  			leaseResp, err := s.client.Grant(ctx, ttl)
   211  			if err != nil {
   212  				return errors.Wrapf(err, "error granting lease")
   213  			}
   214  			lease = leaseResp.ID
   215  			s.newLeases[ttl] = lease
   216  			s.ttlset[key] = ttl // cache key->ttl, in case it's read later
   217  		}
   218  		options = append(options, v3.WithLease(lease))
   219  	}
   220  	s.wset[key] = stmPut{val, ttl, v3.OpPut(key, val, options...), ptr}
   221  	return nil
   222  }
   223  
   224  func (s *stm) Del(key string) {
   225  	s.Lock()
   226  	defer s.Unlock()
   227  	s.wset[key] = stmPut{"", 0, v3.OpDelete(key), 0}
   228  }
   229  
   230  func (s *stm) DelAll(prefix string) {
   231  	s.Lock()
   232  	defer s.Unlock()
   233  	// Remove any eclipsed deletes then add the new delete
   234  	isEclipsed := false
   235  	i := 0
   236  	for _, deletedPrefix := range s.deletedPrefixes {
   237  		if strings.HasPrefix(prefix, deletedPrefix) {
   238  			isEclipsed = true
   239  		}
   240  		if !strings.HasPrefix(deletedPrefix, prefix) {
   241  			s.deletedPrefixes[i] = deletedPrefix
   242  			i++
   243  		}
   244  	}
   245  	s.deletedPrefixes = s.deletedPrefixes[:i]
   246  
   247  	// If the new DelAll prefix is eclipsed by an already-deleted prefix, don't
   248  	// add it to the set, but still clean up any eclipsed writes.
   249  	if !isEclipsed {
   250  		s.deletedPrefixes = append(s.deletedPrefixes, prefix)
   251  	}
   252  
   253  	for k := range s.wset {
   254  		if strings.HasPrefix(k, prefix) {
   255  			delete(s.wset, k)
   256  		}
   257  	}
   258  }
   259  
   260  func (s *stm) Rev(key string) int64 {
   261  	s.Lock()
   262  	defer s.Unlock()
   263  	if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 {
   264  		return resp.Kvs[0].ModRevision
   265  	}
   266  	return 0
   267  }
   268  
   269  func (s *stm) commit() *v3.TxnResponse {
   270  	span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd/Txn")
   271  	defer tracing.FinishAnySpan(span)
   272  
   273  	cmps := s.cmps()
   274  	writes := s.writes()
   275  	txnresp, err := s.client.Txn(ctx).If(cmps...).Then(writes...).Commit()
   276  	if errors.Is(err, rpctypes.ErrTooManyOps) {
   277  		panic(stmError{
   278  			errors.Errorf(
   279  				"%v (%d comparisons, %d writes: hint: set --max-txn-ops on the "+
   280  					"ETCD cluster to at least the largest of those values)",
   281  				err, len(cmps), len(writes)),
   282  		})
   283  	} else if err != nil {
   284  		panic(stmError{err})
   285  	}
   286  	if txnresp.Succeeded {
   287  		return txnresp
   288  	}
   289  	return nil
   290  }
   291  
   292  // cmps guards the txn from updates to read set
   293  func (s *stm) cmps() []v3.Cmp {
   294  	cmps := make([]v3.Cmp, 0, len(s.rset))
   295  	for k, rk := range s.rset {
   296  		cmps = append(cmps, isKeyCurrent(k, rk))
   297  	}
   298  	return cmps
   299  }
   300  
   301  func (s *stm) fetch(key string) *v3.GetResponse {
   302  	if resp, ok := s.rset[key]; ok {
   303  		return resp
   304  	}
   305  
   306  	span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd.stm/Get", "key", key)
   307  	defer tracing.FinishAnySpan(span)
   308  	resp, err := s.client.Get(ctx, key, s.getOpts...)
   309  	if err != nil {
   310  		panic(stmError{err})
   311  	}
   312  	s.rset[key] = resp
   313  	return resp
   314  }
   315  
// writes is the list of ops for all pending writes
//
// It emits one op per write-set entry (puts AND deletes from Del; "puts"
// below means every wset key) plus range deletes for the DelAll prefixes.
// Pending keys and deleted prefixes are both sorted and walked as a merge of
// two sorted lists. A pending key that falls under a deleted prefix splits
// that prefix's range delete into pieces around it, so the delete ranges
// never contain a key that is also written in the same txn.
//
// The merge assumes no deleted prefix is a prefix of another (DelAll is
// meant to maintain this). Keys sharing a prefix are contiguous in sorted
// order, so all pending keys under prefixes[j] are consumed together.
//
// Capacity bound: each pending key yields its own op plus at most one
// preceding delete piece, and each prefix yields one final delete piece.
func (s *stm) writes() []v3.Op {
	// prefixes shares its backing array with s.deletedPrefixes, so the
	// in-place sort of s.deletedPrefixes below also sorts prefixes.
	prefixes := s.deletedPrefixes
	puts := make([]string, 0, len(s.wset))
	for key := range s.wset {
		puts = append(puts, key)
	}
	sort.Strings(puts)
	sort.Strings(s.deletedPrefixes)

	writes := make([]v3.Op, 0, 2*len(s.wset)+len(s.deletedPrefixes))
	i := 0 // index into puts
	j := 0 // index into prefixes
	for i < len(puts) && j < len(prefixes) {
		if puts[i] < prefixes[j] {
			// This is a standalone put, nothing fancy here
			writes = append(writes, s.wset[puts[i]].op)
			i++
		} else {
			// There may be puts within a deleted range, but we can't have two
			// overlapping writes - break up the deleted range into multiple deletes.
			// Each piece is [start, puts[i]); the next piece starts at
			// puts[i]+"\x00", the smallest key strictly greater than puts[i].
			// NOTE(review): if a pending key equals prefixes[j], the first piece
			// has start == end (an empty range); confirm etcd accepts that.
			start := prefixes[j]
			for i < len(puts) && strings.HasPrefix(puts[i], prefixes[j]) {
				writes = append(writes, v3.OpDelete(start, v3.WithRange(puts[i])))
				writes = append(writes, s.wset[puts[i]].op)
				start = puts[i] + "\x00"
				i++
			}
			// Final piece: from just after the last pending key to the end of
			// the prefix's key range.
			writes = append(writes, v3.OpDelete(start, v3.WithRange(v3.GetPrefixRangeEnd(prefixes[j]))))
			j++
		}
	}
	// Leftovers: pending keys after the last prefix, then prefixes with no
	// pending keys under them (plain prefix deletes).
	for i < len(puts) {
		writes = append(writes, s.wset[puts[i]].op)
		i++
	}
	for j < len(prefixes) {
		writes = append(writes, v3.OpDelete(prefixes[j], v3.WithPrefix()))
		j++
	}
	return writes
}
   358  
   359  func (s *stm) reset() {
   360  	s.rset = make(map[string]*v3.GetResponse)
   361  	s.wset = make(map[string]stmPut)
   362  	s.deletedPrefixes = []string{}
   363  	s.ttlset = make(map[string]int64)
   364  	s.newLeases = make(map[int64]v3.LeaseID)
   365  }
   366  
// stmSerializable is an stm whose reads within one attempt are pinned to the
// revision observed by the attempt's first read (see its fetch). When a
// commit fails, the values read back by the txn's Else branch are kept in
// prefetch. The next attempt's reads consume them instead of re-reading from
// etcd.
type stmSerializable struct {
	stm
	prefetch map[string]*v3.GetResponse
}
   371  
   372  func (s *stmSerializable) Get(key string) (string, error) {
   373  	s.Lock()
   374  	defer s.Unlock()
   375  	if wv, ok := s.wset[key]; ok {
   376  		return wv.val, nil
   377  	}
   378  	if s.isKeyRangeDeleted(key) {
   379  		return "", ErrNotFound{Key: key}
   380  	}
   381  	return respToValue(key, s.fetch(key))
   382  }
   383  
   384  func (s *stmSerializable) fetch(key string) *v3.GetResponse {
   385  	firstRead := len(s.rset) == 0
   386  	if resp, ok := s.prefetch[key]; ok {
   387  		delete(s.prefetch, key)
   388  		s.rset[key] = resp
   389  	}
   390  	resp := s.stm.fetch(key)
   391  	if firstRead {
   392  		// txn's base revision is defined by the first read
   393  		s.getOpts = []v3.OpOption{
   394  			v3.WithRev(resp.Header.Revision),
   395  			v3.WithSerializable(),
   396  		}
   397  	}
   398  	return resp
   399  }
   400  
   401  func (s *stmSerializable) Rev(key string) int64 {
   402  	s.Get(key)
   403  	return s.stm.Rev(key)
   404  }
   405  
   406  func (s *stmSerializable) gets() ([]string, []v3.Op) {
   407  	keys := make([]string, 0, len(s.rset))
   408  	ops := make([]v3.Op, 0, len(s.rset))
   409  	for k := range s.rset {
   410  		keys = append(keys, k)
   411  		ops = append(ops, v3.OpGet(k))
   412  	}
   413  	return keys, ops
   414  }
   415  
   416  func (s *stmSerializable) commit() *v3.TxnResponse {
   417  	span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd/Txn")
   418  	defer tracing.FinishAnySpan(span)
   419  	if span != nil {
   420  		keys := make([]byte, 0, 512)
   421  		for k := range s.wset {
   422  			keys = append(append(keys, ','), k...)
   423  		}
   424  		span.SetTag("updated-keys", string(bytes.TrimLeft(keys, ",")))
   425  	}
   426  
   427  	keys, getops := s.gets()
   428  	cmps := s.cmps()
   429  	writes := s.writes()
   430  	txn := s.client.Txn(ctx).If(cmps...).Then(writes...)
   431  	// use Else to prefetch keys in case of conflict to save a round trip
   432  	txnresp, err := txn.Else(getops...).Commit()
   433  	if errors.Is(err, rpctypes.ErrTooManyOps) {
   434  		panic(stmError{
   435  			errors.Errorf(
   436  				"%v (%d comparisons, %d writes: hint: set --max-txn-ops on the "+
   437  					"ETCD cluster to at least the largest of those values)",
   438  				err, len(cmps), len(writes)),
   439  		})
   440  	} else if err != nil {
   441  		panic(stmError{err})
   442  	}
   443  
   444  	tracing.TagAnySpan(span, "applied-at-revision", txnresp.Header.Revision)
   445  	if txnresp.Succeeded {
   446  		return txnresp
   447  	}
   448  	// load prefetch with Else data
   449  	for i := range keys {
   450  		resp := txnresp.Responses[i].GetResponseRange()
   451  		s.rset[keys[i]] = (*v3.GetResponse)(resp)
   452  	}
   453  	s.prefetch = s.rset
   454  	s.getOpts = nil
   455  	return nil
   456  }
   457  
   458  func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp {
   459  	if len(r.Kvs) != 0 {
   460  		return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision)
   461  	}
   462  	return v3.Compare(v3.ModRevision(k), "=", 0)
   463  }
   464  
   465  func respToValue(key string, resp *v3.GetResponse) (string, error) {
   466  	if len(resp.Kvs) == 0 {
   467  		return "", ErrNotFound{Key: key}
   468  	}
   469  	return string(resp.Kvs[0].Value), nil
   470  }
   471  
   472  // fetchTTL contains the essential implementation of TTL().
   473  //
   474  // Note that 'iface' should either be the receiver 's' or a containing
   475  // 'stmSerializeable'--the only reason 'iface' is passed as a separate argument
   476  // is because fetchTTL calls iface.fetch(), and the implementation of 'fetch' is
   477  // different for stm and stmSerializeable. Passing the interface ensures the
   478  // correct version of fetch() is called
   479  func (s *stm) fetchTTL(iface STM, key string) (int64, error) {
   480  	// check wset cache
   481  	if wv, ok := s.wset[key]; ok {
   482  		return wv.ttl, nil
   483  	}
   484  	if s.isKeyRangeDeleted(key) {
   485  		return 0, ErrNotFound{Key: key}
   486  	}
   487  
   488  	// Read ttl through s.ttlset cache
   489  	if ttl, ok := s.ttlset[key]; ok {
   490  		return ttl, nil
   491  	}
   492  
   493  	// Read kv and lease ID, and cache new TTL
   494  	getResp := iface.fetch(key) // call correct implementation of fetch()
   495  	if len(getResp.Kvs) == 0 {
   496  		return 0, ErrNotFound{Key: key}
   497  	}
   498  	leaseID := v3.LeaseID(getResp.Kvs[0].Lease)
   499  	if leaseID == 0 {
   500  		s.ttlset[key] = 0 // 0 is default value, but now 'ok' will be true on check
   501  		return 0, nil
   502  	}
   503  	span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd.stm/TimeToLive", "key", key)
   504  	defer tracing.FinishAnySpan(span)
   505  	leaseResp, err := s.client.TimeToLive(ctx, leaseID)
   506  	if err != nil {
   507  		panic(stmError{err})
   508  	}
   509  	s.ttlset[key] = leaseResp.TTL
   510  	for _, key := range leaseResp.Keys {
   511  		s.ttlset[string(key)] = leaseResp.TTL
   512  	}
   513  	return leaseResp.TTL, nil
   514  }
   515  
   516  func (s *stm) TTL(key string) (int64, error) {
   517  	s.Lock()
   518  	defer s.Unlock()
   519  	return s.fetchTTL(s, key)
   520  }
   521  
   522  func (s *stmSerializable) TTL(key string) (int64, error) {
   523  	s.Lock()
   524  	defer s.Unlock()
   525  	return s.fetchTTL(s, key)
   526  }