github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/payload/persistence/env_test.go (about) 1 // Copyright 2015 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package persistence 5 6 import ( 7 "reflect" 8 9 "github.com/juju/errors" 10 "github.com/juju/testing" 11 jc "github.com/juju/testing/checkers" 12 gc "gopkg.in/check.v1" 13 "gopkg.in/juju/charm.v6-unstable" 14 15 "github.com/juju/juju/payload" 16 ) 17 18 var _ = gc.Suite(&envPersistenceSuite{}) 19 20 type envPersistenceSuite struct { 21 BaseSuite 22 23 base *stubEnvPersistenceBase 24 } 25 26 func (s *envPersistenceSuite) SetUpTest(c *gc.C) { 27 s.BaseSuite.SetUpTest(c) 28 29 s.base = &stubEnvPersistenceBase{ 30 PersistenceBase: s.State, 31 stub: s.Stub, 32 } 33 } 34 35 func (s *envPersistenceSuite) newPayload(name string) payload.FullPayloadInfo { 36 return payload.FullPayloadInfo{ 37 Payload: payload.Payload{ 38 PayloadClass: charm.PayloadClass{ 39 Name: name, 40 Type: "docker", 41 }, 42 ID: "id" + name, 43 Status: payload.StateRunning, 44 Labels: []string{"a-tag"}, 45 Unit: "a-service/0", 46 }, 47 Machine: "1", 48 } 49 } 50 51 func (s *envPersistenceSuite) TestListAllOkay(c *gc.C) { 52 s.base.setUnits("0") 53 s.base.setUnits("1", "a-service/0", "a-service/1") 54 s.base.setUnits("2", "a-service/2") 55 p1 := s.newPayload("spam") 56 p2 := s.newPayload("eggs") 57 s.base.setPayloads(p1, p2) 58 59 persist := NewEnvPersistence(s.base) 60 persist.newUnitPersist = s.base.newUnitPersistence 61 62 payloads, err := persist.ListAll() 63 c.Assert(err, jc.ErrorIsNil) 64 65 checkPayloads(c, payloads, p1, p2) 66 s.Stub.CheckCallNames(c, 67 "Machines", 68 69 "MachineUnits", 70 71 "MachineUnits", 72 "newUnitPersistence", 73 "ListAll", 74 "newUnitPersistence", 75 "ListAll", 76 77 "MachineUnits", 78 "newUnitPersistence", 79 "ListAll", 80 ) 81 } 82 83 func (s *envPersistenceSuite) TestListAllEmpty(c *gc.C) { 84 s.base.setUnits("0") 85 s.base.setUnits("1", "a-service/0", "a-service/1") 86 persist := NewEnvPersistence(s.base) 87 persist.newUnitPersist = s.base.newUnitPersistence 88 89 payloads, err := persist.ListAll() 90 c.Assert(err, jc.ErrorIsNil) 91 92 c.Check(payloads, gc.HasLen, 0) 93 s.Stub.CheckCallNames(c, 94 "Machines", 95 96 "MachineUnits", 97 98 "MachineUnits", 99 "newUnitPersistence", 100 "ListAll", 101 "newUnitPersistence", 102 "ListAll", 103 ) 104 } 105 106 func (s *envPersistenceSuite) TestListAllFailed(c *gc.C) { 107 failure := errors.Errorf("<failed!>") 108 s.Stub.SetErrors(failure) 109 110 persist := NewEnvPersistence(s.base) 111 persist.newUnitPersist = s.base.newUnitPersistence 112 113 _, err := persist.ListAll() 114 115 c.Check(errors.Cause(err), gc.Equals, failure) 116 } 117 118 // TODO(ericsnow) Factor this out to a testing package. 119 120 func checkPayloads(c *gc.C, payloads []payload.FullPayloadInfo, expectedList ...payload.FullPayloadInfo) { 121 remainder := make([]payload.FullPayloadInfo, len(payloads)) 122 copy(remainder, payloads) 123 var noMatch []payload.FullPayloadInfo 124 for _, expected := range expectedList { 125 found := false 126 for i, payload := range remainder { 127 if reflect.DeepEqual(payload, expected) { 128 remainder = append(remainder[:i], remainder[i+1:]...) 129 found = true 130 break 131 } 132 } 133 if !found { 134 noMatch = append(noMatch, expected) 135 } 136 } 137 138 ok1 := c.Check(noMatch, gc.HasLen, 0) 139 ok2 := c.Check(remainder, gc.HasLen, 0) 140 if !ok1 || !ok2 { 141 c.Logf("<<<<<<<<\nexpected:") 142 for _, payload := range expectedList { 143 c.Logf("%#v", payload) 144 } 145 c.Logf("--------\ngot:") 146 for _, payload := range payloads { 147 c.Logf("%#v", payload) 148 } 149 c.Logf(">>>>>>>>") 150 } 151 } 152 153 type stubEnvPersistenceBase struct { 154 PersistenceBase 155 stub *testing.Stub 156 machines []string 157 units map[string]map[string]bool 158 unitPersists map[string]*stubUnitPersistence 159 } 160 161 func (s *stubEnvPersistenceBase) setPayloads(payloads ...payload.FullPayloadInfo) { 162 if s.unitPersists == nil && len(payloads) > 0 { 163 s.unitPersists = make(map[string]*stubUnitPersistence) 164 } 165 166 for _, pl := range payloads { 167 s.setUnits(pl.Machine, pl.Unit) 168 169 unitPayloads := s.unitPersists[pl.Unit] 170 if unitPayloads == nil { 171 unitPayloads = &stubUnitPersistence{stub: s.stub} 172 s.unitPersists[pl.Unit] = unitPayloads 173 } 174 175 unitPayloads.setPayloads(pl.Payload) 176 } 177 } 178 179 func (s *stubEnvPersistenceBase) setUnits(machine string, units ...string) { 180 if s.units == nil { 181 s.units = make(map[string]map[string]bool) 182 } 183 if _, ok := s.units[machine]; !ok { 184 s.machines = append(s.machines, machine) 185 s.units[machine] = make(map[string]bool) 186 } 187 188 for _, unit := range units { 189 s.units[machine][unit] = true 190 } 191 } 192 193 func (s *stubEnvPersistenceBase) newUnitPersistence(base PersistenceBase, unit string) unitPersistence { 194 s.stub.AddCall("newUnitPersistence", base, unit) 195 s.stub.NextErr() // pop one off 196 197 persist, ok := s.unitPersists[unit] 198 if !ok { 199 if s.unitPersists == nil { 200 s.unitPersists = make(map[string]*stubUnitPersistence) 201 } 202 persist = &stubUnitPersistence{stub: s.stub} 203 s.unitPersists[unit] = persist 204 } 205 return persist 206 } 207 208 func (s *stubEnvPersistenceBase) Machines() ([]string, error) { 209 s.stub.AddCall("Machines") 210 if err := s.stub.NextErr(); err != nil { 211 return nil, errors.Trace(err) 212 } 213 214 var names []string 215 for _, name := range s.machines { 216 names = append(names, name) 217 } 218 return names, nil 219 } 220 221 func (s *stubEnvPersistenceBase) MachineUnits(machine string) ([]string, error) { 222 s.stub.AddCall("MachineUnits", machine) 223 if err := s.stub.NextErr(); err != nil { 224 return nil, errors.Trace(err) 225 } 226 227 var units []string 228 for unit := range s.units[machine] { 229 units = append(units, unit) 230 } 231 return units, nil 232 } 233 234 type stubUnitPersistence struct { 235 stub *testing.Stub 236 237 payloads []payload.Payload 238 } 239 240 func (s *stubUnitPersistence) setPayloads(payloads ...payload.Payload) { 241 s.payloads = append(s.payloads, payloads...) 242 } 243 244 func (s *stubUnitPersistence) ListAll() ([]payload.Payload, error) { 245 s.stub.AddCall("ListAll") 246 if err := s.stub.NextErr(); err != nil { 247 return nil, errors.Trace(err) 248 } 249 250 return s.payloads, nil 251 }