github.com/mattyw/juju@v0.0.0-20140610034352-732aecd63861/worker/peergrouper/mock_test.go (about) 1 // Copyright 2014 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package peergrouper 5 6 import ( 7 "encoding/json" 8 "fmt" 9 "net" 10 "path" 11 "reflect" 12 "strconv" 13 "sync" 14 15 "github.com/juju/errors" 16 "github.com/juju/utils/voyeur" 17 "launchpad.net/tomb" 18 19 "github.com/juju/juju/instance" 20 "github.com/juju/juju/replicaset" 21 "github.com/juju/juju/state" 22 "github.com/juju/juju/worker" 23 ) 24 25 // This file holds helper functions for mocking pieces of State and replicaset 26 // that we don't want to directly depend on in unit tests. 27 28 type fakeState struct { 29 mu sync.Mutex 30 machines map[string]*fakeMachine 31 stateServers voyeur.Value // of *state.StateServerInfo 32 session *fakeMongoSession 33 check func(st *fakeState) error 34 } 35 36 var ( 37 _ stateInterface = (*fakeState)(nil) 38 _ stateMachine = (*fakeMachine)(nil) 39 _ mongoSession = (*fakeMongoSession)(nil) 40 ) 41 42 type errorPattern struct { 43 pattern string 44 errFunc func() error 45 } 46 47 var ( 48 errorsMutex sync.Mutex 49 errorPatterns []errorPattern 50 ) 51 52 // setErrorFor causes the given error to be returned 53 // from any mock call that matches the given 54 // string, which may contain wildcards as 55 // in path.Match. 56 // 57 // The standard form for errors is: 58 // Type.Function <arg>... 59 // See individual functions for details. 60 func setErrorFor(what string, err error) { 61 setErrorFuncFor(what, func() error { 62 return err 63 }) 64 } 65 66 // setErrorFuncFor causes the given function 67 // to be invoked to return the error for the 68 // given pattern. 69 func setErrorFuncFor(what string, errFunc func() error) { 70 errorsMutex.Lock() 71 defer errorsMutex.Unlock() 72 errorPatterns = append(errorPatterns, errorPattern{ 73 pattern: what, 74 errFunc: errFunc, 75 }) 76 } 77 78 // errorFor concatenates the call name 79 // with all the args, space separated, 80 // and returns any error registered with 81 // setErrorFor that matches the resulting string. 82 func errorFor(name string, args ...interface{}) error { 83 errorsMutex.Lock() 84 s := name 85 for _, arg := range args { 86 s += " " + fmt.Sprint(arg) 87 } 88 f := func() error { return nil } 89 for _, pattern := range errorPatterns { 90 if ok, _ := path.Match(pattern.pattern, s); ok { 91 f = pattern.errFunc 92 break 93 } 94 } 95 errorsMutex.Unlock() 96 err := f() 97 logger.Errorf("errorFor %q -> %v", s, err) 98 return err 99 } 100 101 func resetErrors() { 102 errorsMutex.Lock() 103 defer errorsMutex.Unlock() 104 errorPatterns = errorPatterns[:0] 105 } 106 107 func newFakeState() *fakeState { 108 st := &fakeState{ 109 machines: make(map[string]*fakeMachine), 110 } 111 st.session = newFakeMongoSession(st) 112 st.stateServers.Set(&state.StateServerInfo{}) 113 return st 114 } 115 116 func (st *fakeState) MongoSession() mongoSession { 117 return st.session 118 } 119 120 func (st *fakeState) checkInvariants() { 121 if st.check == nil { 122 return 123 } 124 if err := st.check(st); err != nil { 125 // Force a panic, otherwise we can deadlock 126 // when called from within the worker. 127 go panic(err) 128 select {} 129 } 130 } 131 132 // checkInvariants checks that all the expected invariants 133 // in the state hold true. Currently we check that: 134 // - total number of votes is odd. 135 // - member voting status implies that machine has vote. 136 func checkInvariants(st *fakeState) error { 137 members := st.session.members.Get().([]replicaset.Member) 138 voteCount := 0 139 for _, m := range members { 140 votes := 1 141 if m.Votes != nil { 142 votes = *m.Votes 143 } 144 voteCount += votes 145 if id, ok := m.Tags[jujuMachineTag]; ok { 146 if votes > 0 { 147 m := st.machine(id) 148 if m == nil { 149 return fmt.Errorf("voting member with machine id %q has no associated Machine", id) 150 } 151 if !m.HasVote() { 152 return fmt.Errorf("machine %q should be marked as having the vote, but does not", id) 153 } 154 } 155 } 156 } 157 if voteCount%2 != 1 { 158 return fmt.Errorf("total vote count is not odd (got %d)", voteCount) 159 } 160 return nil 161 } 162 163 type invariantChecker interface { 164 checkInvariants() 165 } 166 167 // machine is similar to Machine except that 168 // it bypasses the error mocking machinery. 169 // It returns nil if there is no machine with the 170 // given id. 171 func (st *fakeState) machine(id string) *fakeMachine { 172 st.mu.Lock() 173 defer st.mu.Unlock() 174 return st.machines[id] 175 } 176 177 func (st *fakeState) Machine(id string) (stateMachine, error) { 178 if err := errorFor("State.Machine", id); err != nil { 179 return nil, err 180 } 181 if m := st.machine(id); m != nil { 182 return m, nil 183 } 184 return nil, errors.NotFoundf("machine %s", id) 185 } 186 187 func (st *fakeState) addMachine(id string, wantsVote bool) *fakeMachine { 188 st.mu.Lock() 189 defer st.mu.Unlock() 190 logger.Infof("fakeState.addMachine %q", id) 191 if st.machines[id] != nil { 192 panic(fmt.Errorf("id %q already used", id)) 193 } 194 m := &fakeMachine{ 195 checker: st, 196 doc: machineDoc{ 197 id: id, 198 wantsVote: wantsVote, 199 }, 200 } 201 st.machines[id] = m 202 m.val.Set(m.doc) 203 return m 204 } 205 206 func (st *fakeState) removeMachine(id string) { 207 st.mu.Lock() 208 defer st.mu.Unlock() 209 if st.machines[id] == nil { 210 panic(fmt.Errorf("removing non-existent machine %q", id)) 211 } 212 delete(st.machines, id) 213 } 214 215 func (st *fakeState) setStateServers(ids ...string) { 216 st.stateServers.Set(&state.StateServerInfo{ 217 MachineIds: ids, 218 }) 219 } 220 221 func (st *fakeState) StateServerInfo() (*state.StateServerInfo, error) { 222 if err := errorFor("State.StateServerInfo"); err != nil { 223 return nil, err 224 } 225 return deepCopy(st.stateServers.Get()).(*state.StateServerInfo), nil 226 } 227 228 func (st *fakeState) WatchStateServerInfo() state.NotifyWatcher { 229 return WatchValue(&st.stateServers) 230 } 231 232 type fakeMachine struct { 233 mu sync.Mutex 234 val voyeur.Value // of machineDoc 235 doc machineDoc 236 checker invariantChecker 237 } 238 239 type machineDoc struct { 240 id string 241 wantsVote bool 242 hasVote bool 243 instanceId instance.Id 244 mongoHostPorts []instance.HostPort 245 apiHostPorts []instance.HostPort 246 } 247 248 func (m *fakeMachine) Refresh() error { 249 if err := errorFor("Machine.Refresh", m.doc.id); err != nil { 250 return err 251 } 252 m.doc = m.val.Get().(machineDoc) 253 return nil 254 } 255 256 func (m *fakeMachine) GoString() string { 257 return fmt.Sprintf("&fakeMachine{%#v}", m.doc) 258 } 259 260 func (m *fakeMachine) Id() string { 261 return m.doc.id 262 } 263 264 func (m *fakeMachine) InstanceId() (instance.Id, error) { 265 if err := errorFor("Machine.InstanceId", m.doc.id); err != nil { 266 return "", err 267 } 268 return m.doc.instanceId, nil 269 } 270 271 func (m *fakeMachine) Watch() state.NotifyWatcher { 272 return WatchValue(&m.val) 273 } 274 275 func (m *fakeMachine) WantsVote() bool { 276 return m.doc.wantsVote 277 } 278 279 func (m *fakeMachine) HasVote() bool { 280 return m.doc.hasVote 281 } 282 283 func (m *fakeMachine) MongoHostPorts() []instance.HostPort { 284 return m.doc.mongoHostPorts 285 } 286 287 func (m *fakeMachine) APIHostPorts() []instance.HostPort { 288 return m.doc.apiHostPorts 289 } 290 291 // mutate atomically changes the machineDoc of 292 // the receiver by mutating it with the provided function. 293 func (m *fakeMachine) mutate(f func(*machineDoc)) { 294 m.mu.Lock() 295 doc := m.val.Get().(machineDoc) 296 f(&doc) 297 m.val.Set(doc) 298 f(&m.doc) 299 m.mu.Unlock() 300 m.checker.checkInvariants() 301 } 302 303 func (m *fakeMachine) setStateHostPort(hostPort string) { 304 var mongoHostPorts []instance.HostPort 305 if hostPort != "" { 306 host, portStr, err := net.SplitHostPort(hostPort) 307 if err != nil { 308 panic(err) 309 } 310 port, err := strconv.Atoi(portStr) 311 if err != nil { 312 panic(err) 313 } 314 mongoHostPorts = instance.AddressesWithPort(instance.NewAddresses(host), port) 315 mongoHostPorts[0].NetworkScope = instance.NetworkCloudLocal 316 } 317 318 m.mutate(func(doc *machineDoc) { 319 doc.mongoHostPorts = mongoHostPorts 320 }) 321 } 322 323 func (m *fakeMachine) setMongoHostPorts(hostPorts []instance.HostPort) { 324 m.mutate(func(doc *machineDoc) { 325 doc.mongoHostPorts = hostPorts 326 }) 327 } 328 329 func (m *fakeMachine) setAPIHostPorts(hostPorts []instance.HostPort) { 330 m.mutate(func(doc *machineDoc) { 331 doc.apiHostPorts = hostPorts 332 }) 333 } 334 335 func (m *fakeMachine) setInstanceId(instanceId instance.Id) { 336 m.mutate(func(doc *machineDoc) { 337 doc.instanceId = instanceId 338 }) 339 } 340 341 // SetHasVote implements stateMachine.SetHasVote. 342 func (m *fakeMachine) SetHasVote(hasVote bool) error { 343 if err := errorFor("Machine.SetHasVote", m.doc.id, hasVote); err != nil { 344 return err 345 } 346 m.mutate(func(doc *machineDoc) { 347 doc.hasVote = hasVote 348 }) 349 return nil 350 } 351 352 func (m *fakeMachine) setWantsVote(wantsVote bool) { 353 m.mutate(func(doc *machineDoc) { 354 doc.wantsVote = wantsVote 355 }) 356 } 357 358 type fakeMongoSession struct { 359 // If InstantlyReady is true, replica status of 360 // all members will be instantly reported as ready. 361 InstantlyReady bool 362 363 checker invariantChecker 364 members voyeur.Value // of []replicaset.Member 365 status voyeur.Value // of *replicaset.Status 366 } 367 368 // newFakeMongoSession returns a mock implementation of mongoSession. 369 func newFakeMongoSession(checker invariantChecker) *fakeMongoSession { 370 s := new(fakeMongoSession) 371 s.checker = checker 372 s.members.Set([]replicaset.Member(nil)) 373 s.status.Set(&replicaset.Status{}) 374 return s 375 } 376 377 // CurrentMembers implements mongoSession.CurrentMembers. 378 func (session *fakeMongoSession) CurrentMembers() ([]replicaset.Member, error) { 379 if err := errorFor("Session.CurrentMembers"); err != nil { 380 return nil, err 381 } 382 return deepCopy(session.members.Get()).([]replicaset.Member), nil 383 } 384 385 // CurrentStatus implements mongoSession.CurrentStatus. 386 func (session *fakeMongoSession) CurrentStatus() (*replicaset.Status, error) { 387 if err := errorFor("Session.CurrentStatus"); err != nil { 388 return nil, err 389 } 390 return deepCopy(session.status.Get()).(*replicaset.Status), nil 391 } 392 393 // setStatus sets the status of the current members of the session. 394 func (session *fakeMongoSession) setStatus(members []replicaset.MemberStatus) { 395 session.status.Set(deepCopy(&replicaset.Status{ 396 Members: members, 397 })) 398 } 399 400 // Set implements mongoSession.Set 401 func (session *fakeMongoSession) Set(members []replicaset.Member) error { 402 if err := errorFor("Session.Set"); err != nil { 403 logger.Infof("not setting replicaset members to %#v", members) 404 return err 405 } 406 logger.Infof("setting replicaset members to %#v", members) 407 session.members.Set(deepCopy(members)) 408 if session.InstantlyReady { 409 statuses := make([]replicaset.MemberStatus, len(members)) 410 for i, m := range members { 411 statuses[i] = replicaset.MemberStatus{ 412 Id: m.Id, 413 Address: m.Address, 414 Healthy: true, 415 State: replicaset.SecondaryState, 416 } 417 if i == 0 { 418 statuses[i].State = replicaset.PrimaryState 419 } 420 } 421 session.setStatus(statuses) 422 } 423 session.checker.checkInvariants() 424 return nil 425 } 426 427 // deepCopy makes a deep copy of any type by marshalling 428 // it as JSON, then unmarshalling it. 429 func deepCopy(x interface{}) interface{} { 430 v := reflect.ValueOf(x) 431 data, err := json.Marshal(x) 432 if err != nil { 433 panic(fmt.Errorf("cannot marshal %#v: %v", x, err)) 434 } 435 newv := reflect.New(v.Type()) 436 if err := json.Unmarshal(data, newv.Interface()); err != nil { 437 panic(fmt.Errorf("cannot unmarshal %q into %s", data, newv.Type())) 438 } 439 // sanity check 440 newx := newv.Elem().Interface() 441 if !reflect.DeepEqual(newx, x) { 442 panic(fmt.Errorf("value not deep-copied correctly")) 443 } 444 return newx 445 } 446 447 type notifier struct { 448 tomb tomb.Tomb 449 w *voyeur.Watcher 450 changes chan struct{} 451 } 452 453 // WatchValue returns a NotifyWatcher that triggers 454 // when the given value changes. Its Wait and Err methods 455 // never return a non-nil error. 456 func WatchValue(val *voyeur.Value) state.NotifyWatcher { 457 n := ¬ifier{ 458 w: val.Watch(), 459 changes: make(chan struct{}), 460 } 461 go n.loop() 462 return n 463 } 464 465 func (n *notifier) loop() { 466 defer n.tomb.Done() 467 for n.w.Next() { 468 select { 469 case n.changes <- struct{}{}: 470 case <-n.tomb.Dying(): 471 } 472 } 473 } 474 475 // Changes returns a channel that sends a value when the value changes. 476 // The value itself can be retrieved by calling the value's Get method. 477 func (n *notifier) Changes() <-chan struct{} { 478 return n.changes 479 } 480 481 // Kill stops the notifier but does not wait for it to finish. 482 func (n *notifier) Kill() { 483 n.tomb.Kill(nil) 484 n.w.Close() 485 } 486 487 func (n *notifier) Err() error { 488 return n.tomb.Err() 489 } 490 491 // Wait waits for the notifier to finish. It always returns nil. 492 func (n *notifier) Wait() error { 493 return n.tomb.Wait() 494 } 495 496 func (n *notifier) Stop() error { 497 return worker.Stop(n) 498 }