github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/payload/persistence/fakes_test.go (about) 1 // Copyright 2015 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package persistence 5 6 import ( 7 "github.com/juju/errors" 8 gitjujutesting "github.com/juju/testing" 9 jc "github.com/juju/testing/checkers" 10 jujutxn "github.com/juju/txn" 11 gc "gopkg.in/check.v1" 12 "gopkg.in/mgo.v2/bson" 13 "gopkg.in/mgo.v2/txn" 14 ) 15 16 type fakeStatePersistence struct { 17 *gitjujutesting.Stub 18 19 docs map[string]*payloadDoc 20 ops [][]txn.Op 21 } 22 23 func (sp *fakeStatePersistence) SetDocs(docs ...*payloadDoc) { 24 if sp.docs == nil { 25 sp.docs = make(map[string]*payloadDoc) 26 } 27 for _, doc := range docs { 28 sp.docs[doc.DocID] = doc 29 } 30 } 31 32 func (sp fakeStatePersistence) CheckOps(c *gc.C, expected [][]txn.Op) { 33 if len(sp.ops) != len(expected) { 34 c.Check(sp.ops, jc.DeepEquals, expected) 35 return 36 } 37 38 for i, ops := range sp.ops { 39 c.Logf(" -- txn attempt %d --\n", i) 40 expectedRun := expected[i] 41 if len(ops) != len(expectedRun) { 42 c.Check(ops, jc.DeepEquals, expectedRun) 43 continue 44 } 45 for j, op := range ops { 46 c.Logf(" <op %d>\n", j) 47 expectedOp := expectedRun[j] 48 if expectedOp.Insert != nil { 49 if doc, ok := expectedOp.Insert.(*PayloadDoc); ok { 50 expectedOp.Insert = doc.convert() 51 } 52 } else if expectedOp.Update != nil { 53 if doc, ok := expectedOp.Update.(*PayloadDoc); ok { 54 expectedOp.Update = doc.convert() 55 } 56 } 57 c.Check(op, jc.DeepEquals, expectedOp) 58 } 59 } 60 } 61 62 func (sp fakeStatePersistence) CheckNoOps(c *gc.C) { 63 c.Check(sp.ops, gc.HasLen, 0) 64 } 65 66 func (sp fakeStatePersistence) One(collName, id string, doc interface{}) error { 67 sp.AddCall("One", collName, id, doc) 68 if err := sp.NextErr(); err != nil { 69 return errors.Trace(err) 70 } 71 72 if len(sp.docs) == 0 { 73 return errors.NotFoundf(id) 74 } 75 found, ok := sp.docs[id] 76 if !ok { 77 return errors.NotFoundf(id) 78 } 79 actual := doc.(*payloadDoc) 80 *actual = *found 81 return nil 82 } 83 84 func (sp fakeStatePersistence) All(collName string, query, docs interface{}) error { 85 sp.AddCall("All", collName, query, docs) 86 if err := sp.NextErr(); err != nil { 87 return errors.Trace(err) 88 } 89 90 var ids []string 91 elems := query.(bson.D) 92 if len(elems) < 1 { 93 err := errors.Errorf("bad query %v", query) 94 panic(err) 95 } 96 switch elems[0].Name { 97 case "_id": 98 if len(elems) != 1 { 99 err := errors.Errorf("bad query %v", query) 100 panic(err) 101 } 102 elems = elems[0].Value.(bson.D) 103 if len(elems) != 1 || elems[0].Name != "$in" { 104 err := errors.Errorf("bad query %v", query) 105 panic(err) 106 } 107 ids = elems[0].Value.([]string) 108 case "unitid": 109 for id := range sp.docs { 110 ids = append(ids, id) 111 } 112 default: 113 panic(query) 114 } 115 116 var found []payloadDoc 117 for _, id := range ids { 118 doc, ok := sp.docs[id] 119 if !ok { 120 continue 121 } 122 found = append(found, *doc) 123 } 124 actual := docs.(*[]payloadDoc) 125 *actual = found 126 return nil 127 } 128 129 func (sp *fakeStatePersistence) Run(transactions jujutxn.TransactionSource) error { 130 sp.AddCall("Run", transactions) 131 132 // See transactionRunner.Run in github.com/juju/txn. 133 for i := 0; ; i++ { 134 const nrRetries = 3 135 if i >= nrRetries { 136 return jujutxn.ErrExcessiveContention 137 } 138 139 // Get the ops. 140 ops, err := transactions(i) 141 if err == jujutxn.ErrTransientFailure { 142 continue 143 } 144 if err == jujutxn.ErrNoOperations { 145 break 146 } 147 if err != nil { 148 return err 149 } 150 151 // "run" the ops. 152 sp.ops = append(sp.ops, ops) 153 if err := sp.NextErr(); err == nil { 154 return nil 155 } else if err != txn.ErrAborted { 156 return err 157 } 158 } 159 return nil 160 }