github.com/pachyderm/pachyderm@v1.13.4/src/server/pkg/collection/transaction.go (about) 1 package collection 2 3 // Copyright 2016 The etcd Authors 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 // We copy this code from etcd because the etcd implementation of STM does 18 // not have the DelAll method, which we need. 19 20 import ( 21 "bytes" 22 "sort" 23 "strings" 24 "sync" 25 26 v3 "github.com/coreos/etcd/clientv3" 27 "github.com/coreos/etcd/etcdserver/api/v3rpc/rpctypes" 28 "github.com/pachyderm/pachyderm/src/client/pkg/errors" 29 "github.com/pachyderm/pachyderm/src/client/pkg/tracing" 30 "golang.org/x/net/context" 31 ) 32 33 // STM is an interface for software transactional memory. 34 type STM interface { 35 // Get returns the value for a key and inserts the key in the txn's read set. 36 // If Get fails, it aborts the transaction with an error, never returning. 37 Get(key string) (string, error) 38 // Put adds a value for a key to the write set. 39 Put(key, val string, ttl int64, ptr uintptr) error 40 // Rev returns the revision of a key in the read set. 41 Rev(key string) int64 42 // Del deletes a key. 43 Del(key string) 44 // TTL returns the remaining time to live for 'key', or 0 if 'key' has no TTL 45 TTL(key string) (int64, error) 46 // DelAll deletes all keys with the given prefix 47 // Note that the current implementation of DelAll is incomplete. 48 // To use DelAll safely, do not issue any Get/Put operations after 49 // DelAll is called. 50 DelAll(key string) 51 Context() context.Context 52 // SetSafePutCheck sets the bit pattern to check if a put is safe. 53 SetSafePutCheck(key string, ptr uintptr) 54 // IsSafePut checks against the bit pattern for a key to see if it is safe to put. 55 IsSafePut(key string, ptr uintptr) bool 56 57 // commit attempts to apply the txn's changes to the server. 58 commit() *v3.TxnResponse 59 reset() 60 fetch(key string) *v3.GetResponse 61 } 62 63 // stmError safely passes STM errors through panic to the STM error channel. 64 type stmError struct{ err error } 65 66 // NewSTM intiates a new STM operation. It uses a serializable model. 67 func NewSTM(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { 68 return newSTMSerializable(ctx, c, apply, false) 69 } 70 71 // NewDryrunSTM intiates a new STM operation, but the final commit is skipped. 72 // It uses a serializable model. 73 func NewDryrunSTM(ctx context.Context, c *v3.Client, apply func(STM) error) error { 74 _, err := newSTMSerializable(ctx, c, apply, true) 75 return err 76 } 77 78 // newSTMSerializable initiates a new serialized transaction; reads within the 79 // same transaction attempt to return data from the revision of the first read. 80 func newSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error, dryrun bool) (*v3.TxnResponse, error) { 81 s := &stmSerializable{ 82 stm: stm{client: c, ctx: ctx}, 83 prefetch: make(map[string]*v3.GetResponse), 84 } 85 return runSTM(s, apply, dryrun) 86 } 87 88 type stmResponse struct { 89 resp *v3.TxnResponse 90 err error 91 } 92 93 func runSTM(s STM, apply func(STM) error, dryrun bool) (*v3.TxnResponse, error) { 94 outc := make(chan stmResponse, 1) 95 go func() { 96 defer func() { 97 if r := recover(); r != nil { 98 e, ok := r.(stmError) 99 if !ok { 100 // client apply panicked 101 panic(r) 102 } 103 outc <- stmResponse{nil, e.err} 104 } 105 }() 106 var out stmResponse 107 for { 108 s.reset() 109 if out.err = apply(s); out.err != nil { 110 break 111 } 112 if dryrun { 113 break 114 } else if out.resp = s.commit(); out.resp != nil { 115 break 116 } 117 } 118 outc <- out 119 }() 120 r := <-outc 121 return r.resp, r.err 122 } 123 124 // stm implements repeatable-read software transactional memory over etcd 125 type stm struct { 126 client *v3.Client 127 ctx context.Context 128 // rset holds read key values and revisions 129 rset map[string]*v3.GetResponse 130 // wset holds overwritten keys and their values 131 wset map[string]stmPut 132 // deletedPrefixes holds the set of prefixes that have been deleted 133 deletedPrefixes []string 134 // getOpts are the opts used for gets. Includes revision of first read for 135 // stmSerializable 136 getOpts []v3.OpOption 137 // ttlset is a cache from key to lease TTL. It's similar to rset in that it 138 // caches leases that have already been read, but each may contain keys not in 139 // the other (ttlset in particular caches the TTL of all keys associated with 140 // a lease after reading that lease, even if the other keys haven't been read) 141 ttlset map[string]int64 142 // newLeases is a map from TTL to lease ID; it caches new leases used for this 143 // write. We de-dupe leases by TTL (values written with the same TTL get the 144 // same lease) so that kvs in a collection and its indexes all share a lease. 145 // It's similar to wset for TTLs. 146 newLeases map[int64]v3.LeaseID 147 // mutex for concurrent access 148 sync.Mutex 149 } 150 151 type stmPut struct { 152 val string 153 ttl int64 154 op v3.Op 155 safePutPtr uintptr 156 } 157 158 func (s *stm) Context() context.Context { 159 return s.ctx 160 } 161 162 func (s *stm) Get(key string) (string, error) { 163 s.Lock() 164 defer s.Unlock() 165 if wv, ok := s.wset[key]; ok { 166 return wv.val, nil 167 } 168 if s.isKeyRangeDeleted(key) { 169 return "", ErrNotFound{Key: key} 170 } 171 return respToValue(key, s.fetch(key)) 172 } 173 174 func (s *stm) SetSafePutCheck(key string, ptr uintptr) { 175 s.Lock() 176 defer s.Unlock() 177 if wv, ok := s.wset[key]; ok { 178 wv.safePutPtr = ptr 179 s.wset[key] = wv 180 } 181 } 182 183 func (s *stm) IsSafePut(key string, ptr uintptr) bool { 184 s.Lock() 185 defer s.Unlock() 186 if _, ok := s.wset[key]; ok && s.wset[key].safePutPtr != 0 && ptr != s.wset[key].safePutPtr { 187 return false 188 } 189 return true 190 } 191 192 func (s *stm) isKeyRangeDeleted(key string) bool { 193 for _, prefix := range s.deletedPrefixes { 194 if strings.HasPrefix(key, prefix) { 195 return true 196 } 197 } 198 return false 199 } 200 201 func (s *stm) Put(key, val string, ttl int64, ptr uintptr) error { 202 s.Lock() 203 defer s.Unlock() 204 var options []v3.OpOption 205 if ttl > 0 { 206 lease, ok := s.newLeases[ttl] 207 if !ok { 208 span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd/GrantLease") 209 defer tracing.FinishAnySpan(span) 210 leaseResp, err := s.client.Grant(ctx, ttl) 211 if err != nil { 212 return errors.Wrapf(err, "error granting lease") 213 } 214 lease = leaseResp.ID 215 s.newLeases[ttl] = lease 216 s.ttlset[key] = ttl // cache key->ttl, in case it's read later 217 } 218 options = append(options, v3.WithLease(lease)) 219 } 220 s.wset[key] = stmPut{val, ttl, v3.OpPut(key, val, options...), ptr} 221 return nil 222 } 223 224 func (s *stm) Del(key string) { 225 s.Lock() 226 defer s.Unlock() 227 s.wset[key] = stmPut{"", 0, v3.OpDelete(key), 0} 228 } 229 230 func (s *stm) DelAll(prefix string) { 231 s.Lock() 232 defer s.Unlock() 233 // Remove any eclipsed deletes then add the new delete 234 isEclipsed := false 235 i := 0 236 for _, deletedPrefix := range s.deletedPrefixes { 237 if strings.HasPrefix(prefix, deletedPrefix) { 238 isEclipsed = true 239 } 240 if !strings.HasPrefix(deletedPrefix, prefix) { 241 s.deletedPrefixes[i] = deletedPrefix 242 i++ 243 } 244 } 245 s.deletedPrefixes = s.deletedPrefixes[:i] 246 247 // If the new DelAll prefix is eclipsed by an already-deleted prefix, don't 248 // add it to the set, but still clean up any eclipsed writes. 249 if !isEclipsed { 250 s.deletedPrefixes = append(s.deletedPrefixes, prefix) 251 } 252 253 for k := range s.wset { 254 if strings.HasPrefix(k, prefix) { 255 delete(s.wset, k) 256 } 257 } 258 } 259 260 func (s *stm) Rev(key string) int64 { 261 s.Lock() 262 defer s.Unlock() 263 if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 { 264 return resp.Kvs[0].ModRevision 265 } 266 return 0 267 } 268 269 func (s *stm) commit() *v3.TxnResponse { 270 span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd/Txn") 271 defer tracing.FinishAnySpan(span) 272 273 cmps := s.cmps() 274 writes := s.writes() 275 txnresp, err := s.client.Txn(ctx).If(cmps...).Then(writes...).Commit() 276 if errors.Is(err, rpctypes.ErrTooManyOps) { 277 panic(stmError{ 278 errors.Errorf( 279 "%v (%d comparisons, %d writes: hint: set --max-txn-ops on the "+ 280 "ETCD cluster to at least the largest of those values)", 281 err, len(cmps), len(writes)), 282 }) 283 } else if err != nil { 284 panic(stmError{err}) 285 } 286 if txnresp.Succeeded { 287 return txnresp 288 } 289 return nil 290 } 291 292 // cmps guards the txn from updates to read set 293 func (s *stm) cmps() []v3.Cmp { 294 cmps := make([]v3.Cmp, 0, len(s.rset)) 295 for k, rk := range s.rset { 296 cmps = append(cmps, isKeyCurrent(k, rk)) 297 } 298 return cmps 299 } 300 301 func (s *stm) fetch(key string) *v3.GetResponse { 302 if resp, ok := s.rset[key]; ok { 303 return resp 304 } 305 306 span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd.stm/Get", "key", key) 307 defer tracing.FinishAnySpan(span) 308 resp, err := s.client.Get(ctx, key, s.getOpts...) 309 if err != nil { 310 panic(stmError{err}) 311 } 312 s.rset[key] = resp 313 return resp 314 } 315 316 // writes is the list of ops for all pending writes 317 func (s *stm) writes() []v3.Op { 318 prefixes := s.deletedPrefixes 319 puts := make([]string, 0, len(s.wset)) 320 for key := range s.wset { 321 puts = append(puts, key) 322 } 323 sort.Strings(puts) 324 sort.Strings(s.deletedPrefixes) 325 326 writes := make([]v3.Op, 0, 2*len(s.wset)+len(s.deletedPrefixes)) 327 i := 0 // index into puts 328 j := 0 // index into prefixes 329 for i < len(puts) && j < len(prefixes) { 330 if puts[i] < prefixes[j] { 331 // This is a standalone put, nothing fancy here 332 writes = append(writes, s.wset[puts[i]].op) 333 i++ 334 } else { 335 // There may be puts within a deleted range, but we can't have two 336 // overlapping writes - break up the deleted range into multiple deletes. 337 start := prefixes[j] 338 for i < len(puts) && strings.HasPrefix(puts[i], prefixes[j]) { 339 writes = append(writes, v3.OpDelete(start, v3.WithRange(puts[i]))) 340 writes = append(writes, s.wset[puts[i]].op) 341 start = puts[i] + "\x00" 342 i++ 343 } 344 writes = append(writes, v3.OpDelete(start, v3.WithRange(v3.GetPrefixRangeEnd(prefixes[j])))) 345 j++ 346 } 347 } 348 for i < len(puts) { 349 writes = append(writes, s.wset[puts[i]].op) 350 i++ 351 } 352 for j < len(prefixes) { 353 writes = append(writes, v3.OpDelete(prefixes[j], v3.WithPrefix())) 354 j++ 355 } 356 return writes 357 } 358 359 func (s *stm) reset() { 360 s.rset = make(map[string]*v3.GetResponse) 361 s.wset = make(map[string]stmPut) 362 s.deletedPrefixes = []string{} 363 s.ttlset = make(map[string]int64) 364 s.newLeases = make(map[int64]v3.LeaseID) 365 } 366 367 type stmSerializable struct { 368 stm 369 prefetch map[string]*v3.GetResponse 370 } 371 372 func (s *stmSerializable) Get(key string) (string, error) { 373 s.Lock() 374 defer s.Unlock() 375 if wv, ok := s.wset[key]; ok { 376 return wv.val, nil 377 } 378 if s.isKeyRangeDeleted(key) { 379 return "", ErrNotFound{Key: key} 380 } 381 return respToValue(key, s.fetch(key)) 382 } 383 384 func (s *stmSerializable) fetch(key string) *v3.GetResponse { 385 firstRead := len(s.rset) == 0 386 if resp, ok := s.prefetch[key]; ok { 387 delete(s.prefetch, key) 388 s.rset[key] = resp 389 } 390 resp := s.stm.fetch(key) 391 if firstRead { 392 // txn's base revision is defined by the first read 393 s.getOpts = []v3.OpOption{ 394 v3.WithRev(resp.Header.Revision), 395 v3.WithSerializable(), 396 } 397 } 398 return resp 399 } 400 401 func (s *stmSerializable) Rev(key string) int64 { 402 s.Get(key) 403 return s.stm.Rev(key) 404 } 405 406 func (s *stmSerializable) gets() ([]string, []v3.Op) { 407 keys := make([]string, 0, len(s.rset)) 408 ops := make([]v3.Op, 0, len(s.rset)) 409 for k := range s.rset { 410 keys = append(keys, k) 411 ops = append(ops, v3.OpGet(k)) 412 } 413 return keys, ops 414 } 415 416 func (s *stmSerializable) commit() *v3.TxnResponse { 417 span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd/Txn") 418 defer tracing.FinishAnySpan(span) 419 if span != nil { 420 keys := make([]byte, 0, 512) 421 for k := range s.wset { 422 keys = append(append(keys, ','), k...) 423 } 424 span.SetTag("updated-keys", string(bytes.TrimLeft(keys, ","))) 425 } 426 427 keys, getops := s.gets() 428 cmps := s.cmps() 429 writes := s.writes() 430 txn := s.client.Txn(ctx).If(cmps...).Then(writes...) 431 // use Else to prefetch keys in case of conflict to save a round trip 432 txnresp, err := txn.Else(getops...).Commit() 433 if errors.Is(err, rpctypes.ErrTooManyOps) { 434 panic(stmError{ 435 errors.Errorf( 436 "%v (%d comparisons, %d writes: hint: set --max-txn-ops on the "+ 437 "ETCD cluster to at least the largest of those values)", 438 err, len(cmps), len(writes)), 439 }) 440 } else if err != nil { 441 panic(stmError{err}) 442 } 443 444 tracing.TagAnySpan(span, "applied-at-revision", txnresp.Header.Revision) 445 if txnresp.Succeeded { 446 return txnresp 447 } 448 // load prefetch with Else data 449 for i := range keys { 450 resp := txnresp.Responses[i].GetResponseRange() 451 s.rset[keys[i]] = (*v3.GetResponse)(resp) 452 } 453 s.prefetch = s.rset 454 s.getOpts = nil 455 return nil 456 } 457 458 func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp { 459 if len(r.Kvs) != 0 { 460 return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision) 461 } 462 return v3.Compare(v3.ModRevision(k), "=", 0) 463 } 464 465 func respToValue(key string, resp *v3.GetResponse) (string, error) { 466 if len(resp.Kvs) == 0 { 467 return "", ErrNotFound{Key: key} 468 } 469 return string(resp.Kvs[0].Value), nil 470 } 471 472 // fetchTTL contains the essential implementation of TTL(). 473 // 474 // Note that 'iface' should either be the receiver 's' or a containing 475 // 'stmSerializeable'--the only reason 'iface' is passed as a separate argument 476 // is because fetchTTL calls iface.fetch(), and the implementation of 'fetch' is 477 // different for stm and stmSerializeable. Passing the interface ensures the 478 // correct version of fetch() is called 479 func (s *stm) fetchTTL(iface STM, key string) (int64, error) { 480 // check wset cache 481 if wv, ok := s.wset[key]; ok { 482 return wv.ttl, nil 483 } 484 if s.isKeyRangeDeleted(key) { 485 return 0, ErrNotFound{Key: key} 486 } 487 488 // Read ttl through s.ttlset cache 489 if ttl, ok := s.ttlset[key]; ok { 490 return ttl, nil 491 } 492 493 // Read kv and lease ID, and cache new TTL 494 getResp := iface.fetch(key) // call correct implementation of fetch() 495 if len(getResp.Kvs) == 0 { 496 return 0, ErrNotFound{Key: key} 497 } 498 leaseID := v3.LeaseID(getResp.Kvs[0].Lease) 499 if leaseID == 0 { 500 s.ttlset[key] = 0 // 0 is default value, but now 'ok' will be true on check 501 return 0, nil 502 } 503 span, ctx := tracing.AddSpanToAnyExisting(s.ctx, "/etcd.stm/TimeToLive", "key", key) 504 defer tracing.FinishAnySpan(span) 505 leaseResp, err := s.client.TimeToLive(ctx, leaseID) 506 if err != nil { 507 panic(stmError{err}) 508 } 509 s.ttlset[key] = leaseResp.TTL 510 for _, key := range leaseResp.Keys { 511 s.ttlset[string(key)] = leaseResp.TTL 512 } 513 return leaseResp.TTL, nil 514 } 515 516 func (s *stm) TTL(key string) (int64, error) { 517 s.Lock() 518 defer s.Unlock() 519 return s.fetchTTL(s, key) 520 } 521 522 func (s *stmSerializable) TTL(key string) (int64, error) { 523 s.Lock() 524 defer s.Unlock() 525 return s.fetchTTL(s, key) 526 }