github.com/mwhudson/juju@v0.0.0-20160512215208-90ff01f3497f/payload/state/base_test.go (about) 1 // Copyright 2015 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package state_test 5 6 import ( 7 "fmt" 8 9 "github.com/juju/errors" 10 gitjujutesting "github.com/juju/testing" 11 jc "github.com/juju/testing/checkers" 12 "github.com/juju/utils" 13 gc "gopkg.in/check.v1" 14 "gopkg.in/juju/charm.v6-unstable" 15 16 "github.com/juju/juju/payload" 17 "github.com/juju/juju/testing" 18 ) 19 20 type basePayloadsSuite struct { 21 testing.BaseSuite 22 23 stub *gitjujutesting.Stub 24 persist *fakePayloadsPersistence 25 } 26 27 func (s *basePayloadsSuite) SetUpTest(c *gc.C) { 28 s.BaseSuite.SetUpTest(c) 29 30 s.stub = &gitjujutesting.Stub{} 31 s.persist = &fakePayloadsPersistence{Stub: s.stub} 32 } 33 34 func (s *basePayloadsSuite) newPayload(pType string, id string) payload.Payload { 35 name, rawID := payload.ParseID(id) 36 if rawID == "" { 37 rawID = fmt.Sprintf("%s-%s", name, utils.MustNewUUID()) 38 } 39 40 return payload.Payload{ 41 PayloadClass: charm.PayloadClass{ 42 Name: name, 43 Type: pType, 44 }, 45 Status: payload.StateRunning, 46 ID: rawID, 47 Unit: "a-service/0", 48 } 49 } 50 51 type fakePayloadsPersistence struct { 52 *gitjujutesting.Stub 53 payloads map[string]*payload.Payload 54 } 55 56 func (s *fakePayloadsPersistence) checkPayload(c *gc.C, id string, expected payload.Payload) { 57 pl, ok := s.payloads[id] 58 if !ok { 59 c.Errorf("payload %q not found", id) 60 } else { 61 c.Check(pl, jc.DeepEquals, &expected) 62 } 63 } 64 65 func (s *fakePayloadsPersistence) setPayload(id string, pl *payload.Payload) { 66 if s.payloads == nil { 67 s.payloads = make(map[string]*payload.Payload) 68 } 69 s.payloads[id] = pl 70 } 71 72 func (s *fakePayloadsPersistence) Track(id string, pl payload.Payload) (bool, error) { 73 s.AddCall("Track", id, pl) 74 if err := s.NextErr(); err != nil { 75 return false, errors.Trace(err) 76 } 77 78 if _, ok := s.payloads[id]; ok { 79 return false, nil 80 } 81 s.setPayload(id, &pl) 82 return true, nil 83 } 84 85 func (s *fakePayloadsPersistence) SetStatus(id, status string) (bool, error) { 86 s.AddCall("SetStatus", id, status) 87 if err := s.NextErr(); err != nil { 88 return false, errors.Trace(err) 89 } 90 91 pl, ok := s.payloads[id] 92 if !ok { 93 return false, nil 94 } 95 pl.Status = status 96 return true, nil 97 } 98 99 func (s *fakePayloadsPersistence) List(ids ...string) ([]payload.Payload, []string, error) { 100 s.AddCall("List", ids) 101 if err := s.NextErr(); err != nil { 102 return nil, nil, errors.Trace(err) 103 } 104 105 var payloads []payload.Payload 106 var missing []string 107 for _, id := range ids { 108 if pl, ok := s.payloads[id]; !ok { 109 missing = append(missing, id) 110 } else { 111 payloads = append(payloads, *pl) 112 } 113 } 114 return payloads, missing, nil 115 } 116 117 func (s *fakePayloadsPersistence) ListAll() ([]payload.Payload, error) { 118 s.AddCall("ListAll") 119 if err := s.NextErr(); err != nil { 120 return nil, errors.Trace(err) 121 } 122 123 var payloads []payload.Payload 124 for _, pl := range s.payloads { 125 payloads = append(payloads, *pl) 126 } 127 return payloads, nil 128 } 129 130 func (s *fakePayloadsPersistence) LookUp(name, rawID string) (string, error) { 131 s.AddCall("LookUp", name, rawID) 132 if err := s.NextErr(); err != nil { 133 return "", errors.Trace(err) 134 } 135 136 for id, pl := range s.payloads { 137 if pl.Name == name && pl.ID == rawID { 138 return id, nil 139 } 140 } 141 return "", errors.NotFoundf("doc ID") 142 } 143 144 func (s *fakePayloadsPersistence) Untrack(id string) (bool, error) { 145 s.AddCall("Untrack", id) 146 if err := s.NextErr(); err != nil { 147 return false, errors.Trace(err) 148 } 149 150 if _, ok := s.payloads[id]; !ok { 151 return false, nil 152 } 153 delete(s.payloads, id) 154 return true, nil 155 }