github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/state/statetest/persistence_stubs.go (about) 1 // Copyright 2015 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package statetest 5 6 import ( 7 "reflect" 8 9 "github.com/juju/errors" 10 "github.com/juju/testing" 11 jujutxn "github.com/juju/txn" 12 "gopkg.in/mgo.v2/txn" 13 ) 14 15 type StubPersistence struct { 16 *testing.Stub 17 18 RunFunc func(jujutxn.TransactionSource) error 19 20 ReturnAll interface{} // homegenous(?) list of doc struct (not pointers) 21 ReturnOne interface{} // a doc struct (not a pointer) 22 23 ReturnApplicationExistsOps []txn.Op 24 ReturnIncCharmModifiedVersionOps []txn.Op 25 } 26 27 func NewStubPersistence(stub *testing.Stub) *StubPersistence { 28 s := &StubPersistence{ 29 Stub: stub, 30 } 31 s.RunFunc = s.run 32 return s 33 } 34 35 func (s *StubPersistence) One(collName, id string, doc interface{}) error { 36 s.AddCall("One", collName, id, doc) 37 if err := s.NextErr(); err != nil { 38 return errors.Trace(err) 39 } 40 41 if reflect.TypeOf(s.ReturnOne) == nil { 42 return errors.NotFoundf("resource") 43 } 44 ptr := reflect.ValueOf(doc) 45 newVal := reflect.ValueOf(s.ReturnOne) 46 ptr.Elem().Set(newVal) 47 return nil 48 } 49 50 func (s *StubPersistence) All(collName string, query, docs interface{}) error { 51 s.AddCall("All", collName, query, docs) 52 if err := s.NextErr(); err != nil { 53 return errors.Trace(err) 54 } 55 56 ptr := reflect.ValueOf(docs) 57 if reflect.TypeOf(s.ReturnAll) == nil { 58 ptr.Elem().SetLen(0) 59 } else { 60 newVal := reflect.ValueOf(s.ReturnAll) 61 ptr.Elem().Set(newVal) 62 } 63 return nil 64 } 65 66 func (s *StubPersistence) Run(buildTxn jujutxn.TransactionSource) error { 67 s.AddCall("Run", buildTxn) 68 if err := s.NextErr(); err != nil { 69 return errors.Trace(err) 70 } 71 72 if err := s.run(buildTxn); err != nil { 73 return errors.Trace(err) 74 } 75 76 return nil 77 } 78 79 // See github.com/juju/txn.transactionRunner.Run. 80 func (s *StubPersistence) run(buildTxn jujutxn.TransactionSource) error { 81 for i := 0; ; i++ { 82 ops, err := buildTxn(i) 83 if errors.Cause(err) == jujutxn.ErrTransientFailure { 84 continue 85 } 86 if errors.Cause(err) == jujutxn.ErrNoOperations { 87 return nil 88 } 89 if err != nil { 90 return err 91 } 92 93 err = s.RunTransaction(ops) 94 if errors.Cause(err) == txn.ErrAborted { 95 continue 96 } 97 return err 98 } 99 } 100 101 func (s *StubPersistence) RunTransaction(ops []txn.Op) error { 102 s.AddCall("RunTransaction", ops) 103 if err := s.NextErr(); err != nil { 104 return errors.Trace(err) 105 } 106 107 return nil 108 } 109 110 func (s *StubPersistence) ApplicationExistsOps(applicationID string) []txn.Op { 111 s.AddCall("ApplicationExistsOps", applicationID) 112 // pop off an error so num errors == num calls, even though this call 113 // doesn't actually use the error. 114 s.NextErr() 115 116 return s.ReturnApplicationExistsOps 117 } 118 119 func (s *StubPersistence) IncCharmModifiedVersionOps(applicationID string) []txn.Op { 120 s.AddCall("IncCharmModifiedVersionOps", applicationID) 121 // pop off an error so num errors == num calls, even though this call 122 // doesn't actually use the error. 123 s.NextErr() 124 125 return s.ReturnIncCharmModifiedVersionOps 126 }