go.etcd.io/etcd@v3.3.27+incompatible/clientv3/concurrency/stm.go (about) 1 // Copyright 2016 The etcd Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package concurrency 16 17 import ( 18 "context" 19 "math" 20 21 v3 "github.com/coreos/etcd/clientv3" 22 ) 23 24 // STM is an interface for software transactional memory. 25 type STM interface { 26 // Get returns the value for a key and inserts the key in the txn's read set. 27 // If Get fails, it aborts the transaction with an error, never returning. 28 Get(key ...string) string 29 // Put adds a value for a key to the write set. 30 Put(key, val string, opts ...v3.OpOption) 31 // Rev returns the revision of a key in the read set. 32 Rev(key string) int64 33 // Del deletes a key. 34 Del(key string) 35 36 // commit attempts to apply the txn's changes to the server. 37 commit() *v3.TxnResponse 38 reset() 39 } 40 41 // Isolation is an enumeration of transactional isolation levels which 42 // describes how transactions should interfere and conflict. 43 type Isolation int 44 45 const ( 46 // SerializableSnapshot provides serializable isolation and also checks 47 // for write conflicts. 48 SerializableSnapshot Isolation = iota 49 // Serializable reads within the same transaction attempt return data 50 // from the at the revision of the first read. 51 Serializable 52 // RepeatableReads reads within the same transaction attempt always 53 // return the same data. 54 RepeatableReads 55 // ReadCommitted reads keys from any committed revision. 56 ReadCommitted 57 ) 58 59 // stmError safely passes STM errors through panic to the STM error channel. 60 type stmError struct{ err error } 61 62 type stmOptions struct { 63 iso Isolation 64 ctx context.Context 65 prefetch []string 66 } 67 68 type stmOption func(*stmOptions) 69 70 // WithIsolation specifies the transaction isolation level. 71 func WithIsolation(lvl Isolation) stmOption { 72 return func(so *stmOptions) { so.iso = lvl } 73 } 74 75 // WithAbortContext specifies the context for permanently aborting the transaction. 76 func WithAbortContext(ctx context.Context) stmOption { 77 return func(so *stmOptions) { so.ctx = ctx } 78 } 79 80 // WithPrefetch is a hint to prefetch a list of keys before trying to apply. 81 // If an STM transaction will unconditionally fetch a set of keys, prefetching 82 // those keys will save the round-trip cost from requesting each key one by one 83 // with Get(). 84 func WithPrefetch(keys ...string) stmOption { 85 return func(so *stmOptions) { so.prefetch = append(so.prefetch, keys...) } 86 } 87 88 // NewSTM initiates a new STM instance, using serializable snapshot isolation by default. 89 func NewSTM(c *v3.Client, apply func(STM) error, so ...stmOption) (*v3.TxnResponse, error) { 90 opts := &stmOptions{ctx: c.Ctx()} 91 for _, f := range so { 92 f(opts) 93 } 94 if len(opts.prefetch) != 0 { 95 f := apply 96 apply = func(s STM) error { 97 s.Get(opts.prefetch...) 98 return f(s) 99 } 100 } 101 return runSTM(mkSTM(c, opts), apply) 102 } 103 104 func mkSTM(c *v3.Client, opts *stmOptions) STM { 105 switch opts.iso { 106 case SerializableSnapshot: 107 s := &stmSerializable{ 108 stm: stm{client: c, ctx: opts.ctx}, 109 prefetch: make(map[string]*v3.GetResponse), 110 } 111 s.conflicts = func() []v3.Cmp { 112 return append(s.rset.cmps(), s.wset.cmps(s.rset.first()+1)...) 113 } 114 return s 115 case Serializable: 116 s := &stmSerializable{ 117 stm: stm{client: c, ctx: opts.ctx}, 118 prefetch: make(map[string]*v3.GetResponse), 119 } 120 s.conflicts = func() []v3.Cmp { return s.rset.cmps() } 121 return s 122 case RepeatableReads: 123 s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} 124 s.conflicts = func() []v3.Cmp { return s.rset.cmps() } 125 return s 126 case ReadCommitted: 127 s := &stm{client: c, ctx: opts.ctx, getOpts: []v3.OpOption{v3.WithSerializable()}} 128 s.conflicts = func() []v3.Cmp { return nil } 129 return s 130 default: 131 panic("unsupported stm") 132 } 133 } 134 135 type stmResponse struct { 136 resp *v3.TxnResponse 137 err error 138 } 139 140 func runSTM(s STM, apply func(STM) error) (*v3.TxnResponse, error) { 141 outc := make(chan stmResponse, 1) 142 go func() { 143 defer func() { 144 if r := recover(); r != nil { 145 e, ok := r.(stmError) 146 if !ok { 147 // client apply panicked 148 panic(r) 149 } 150 outc <- stmResponse{nil, e.err} 151 } 152 }() 153 var out stmResponse 154 for { 155 s.reset() 156 if out.err = apply(s); out.err != nil { 157 break 158 } 159 if out.resp = s.commit(); out.resp != nil { 160 break 161 } 162 } 163 outc <- out 164 }() 165 r := <-outc 166 return r.resp, r.err 167 } 168 169 // stm implements repeatable-read software transactional memory over etcd 170 type stm struct { 171 client *v3.Client 172 ctx context.Context 173 // rset holds read key values and revisions 174 rset readSet 175 // wset holds overwritten keys and their values 176 wset writeSet 177 // getOpts are the opts used for gets 178 getOpts []v3.OpOption 179 // conflicts computes the current conflicts on the txn 180 conflicts func() []v3.Cmp 181 } 182 183 type stmPut struct { 184 val string 185 op v3.Op 186 } 187 188 type readSet map[string]*v3.GetResponse 189 190 func (rs readSet) add(keys []string, txnresp *v3.TxnResponse) { 191 for i, resp := range txnresp.Responses { 192 rs[keys[i]] = (*v3.GetResponse)(resp.GetResponseRange()) 193 } 194 } 195 196 // first returns the store revision from the first fetch 197 func (rs readSet) first() int64 { 198 ret := int64(math.MaxInt64 - 1) 199 for _, resp := range rs { 200 if rev := resp.Header.Revision; rev < ret { 201 ret = rev 202 } 203 } 204 return ret 205 } 206 207 // cmps guards the txn from updates to read set 208 func (rs readSet) cmps() []v3.Cmp { 209 cmps := make([]v3.Cmp, 0, len(rs)) 210 for k, rk := range rs { 211 cmps = append(cmps, isKeyCurrent(k, rk)) 212 } 213 return cmps 214 } 215 216 type writeSet map[string]stmPut 217 218 func (ws writeSet) get(keys ...string) *stmPut { 219 for _, key := range keys { 220 if wv, ok := ws[key]; ok { 221 return &wv 222 } 223 } 224 return nil 225 } 226 227 // cmps returns a cmp list testing no writes have happened past rev 228 func (ws writeSet) cmps(rev int64) []v3.Cmp { 229 cmps := make([]v3.Cmp, 0, len(ws)) 230 for key := range ws { 231 cmps = append(cmps, v3.Compare(v3.ModRevision(key), "<", rev)) 232 } 233 return cmps 234 } 235 236 // puts is the list of ops for all pending writes 237 func (ws writeSet) puts() []v3.Op { 238 puts := make([]v3.Op, 0, len(ws)) 239 for _, v := range ws { 240 puts = append(puts, v.op) 241 } 242 return puts 243 } 244 245 func (s *stm) Get(keys ...string) string { 246 if wv := s.wset.get(keys...); wv != nil { 247 return wv.val 248 } 249 return respToValue(s.fetch(keys...)) 250 } 251 252 func (s *stm) Put(key, val string, opts ...v3.OpOption) { 253 s.wset[key] = stmPut{val, v3.OpPut(key, val, opts...)} 254 } 255 256 func (s *stm) Del(key string) { s.wset[key] = stmPut{"", v3.OpDelete(key)} } 257 258 func (s *stm) Rev(key string) int64 { 259 if resp := s.fetch(key); resp != nil && len(resp.Kvs) != 0 { 260 return resp.Kvs[0].ModRevision 261 } 262 return 0 263 } 264 265 func (s *stm) commit() *v3.TxnResponse { 266 txnresp, err := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...).Commit() 267 if err != nil { 268 panic(stmError{err}) 269 } 270 if txnresp.Succeeded { 271 return txnresp 272 } 273 return nil 274 } 275 276 func (s *stm) fetch(keys ...string) *v3.GetResponse { 277 if len(keys) == 0 { 278 return nil 279 } 280 ops := make([]v3.Op, len(keys)) 281 for i, key := range keys { 282 if resp, ok := s.rset[key]; ok { 283 return resp 284 } 285 ops[i] = v3.OpGet(key, s.getOpts...) 286 } 287 txnresp, err := s.client.Txn(s.ctx).Then(ops...).Commit() 288 if err != nil { 289 panic(stmError{err}) 290 } 291 s.rset.add(keys, txnresp) 292 return (*v3.GetResponse)(txnresp.Responses[0].GetResponseRange()) 293 } 294 295 func (s *stm) reset() { 296 s.rset = make(map[string]*v3.GetResponse) 297 s.wset = make(map[string]stmPut) 298 } 299 300 type stmSerializable struct { 301 stm 302 prefetch map[string]*v3.GetResponse 303 } 304 305 func (s *stmSerializable) Get(keys ...string) string { 306 if wv := s.wset.get(keys...); wv != nil { 307 return wv.val 308 } 309 firstRead := len(s.rset) == 0 310 for _, key := range keys { 311 if resp, ok := s.prefetch[key]; ok { 312 delete(s.prefetch, key) 313 s.rset[key] = resp 314 } 315 } 316 resp := s.stm.fetch(keys...) 317 if firstRead { 318 // txn's base revision is defined by the first read 319 s.getOpts = []v3.OpOption{ 320 v3.WithRev(resp.Header.Revision), 321 v3.WithSerializable(), 322 } 323 } 324 return respToValue(resp) 325 } 326 327 func (s *stmSerializable) Rev(key string) int64 { 328 s.Get(key) 329 return s.stm.Rev(key) 330 } 331 332 func (s *stmSerializable) gets() ([]string, []v3.Op) { 333 keys := make([]string, 0, len(s.rset)) 334 ops := make([]v3.Op, 0, len(s.rset)) 335 for k := range s.rset { 336 keys = append(keys, k) 337 ops = append(ops, v3.OpGet(k)) 338 } 339 return keys, ops 340 } 341 342 func (s *stmSerializable) commit() *v3.TxnResponse { 343 keys, getops := s.gets() 344 txn := s.client.Txn(s.ctx).If(s.conflicts()...).Then(s.wset.puts()...) 345 // use Else to prefetch keys in case of conflict to save a round trip 346 txnresp, err := txn.Else(getops...).Commit() 347 if err != nil { 348 panic(stmError{err}) 349 } 350 if txnresp.Succeeded { 351 return txnresp 352 } 353 // load prefetch with Else data 354 s.rset.add(keys, txnresp) 355 s.prefetch = s.rset 356 s.getOpts = nil 357 return nil 358 } 359 360 func isKeyCurrent(k string, r *v3.GetResponse) v3.Cmp { 361 if len(r.Kvs) != 0 { 362 return v3.Compare(v3.ModRevision(k), "=", r.Kvs[0].ModRevision) 363 } 364 return v3.Compare(v3.ModRevision(k), "=", 0) 365 } 366 367 func respToValue(resp *v3.GetResponse) string { 368 if resp == nil || len(resp.Kvs) == 0 { 369 return "" 370 } 371 return string(resp.Kvs[0].Value) 372 } 373 374 // NewSTMRepeatable is deprecated. 375 func NewSTMRepeatable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { 376 return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(RepeatableReads)) 377 } 378 379 // NewSTMSerializable is deprecated. 380 func NewSTMSerializable(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { 381 return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(Serializable)) 382 } 383 384 // NewSTMReadCommitted is deprecated. 385 func NewSTMReadCommitted(ctx context.Context, c *v3.Client, apply func(STM) error) (*v3.TxnResponse, error) { 386 return NewSTM(c, apply, WithAbortContext(ctx), WithIsolation(ReadCommitted)) 387 }