github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/worker/multiwatcher/testbacking/backing.go (about) 1 // Copyright 2019 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package testbacking 5 6 import ( 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/juju/errors" 12 13 "github.com/juju/juju/core/multiwatcher" 14 "github.com/juju/juju/state/watcher" 15 "github.com/juju/juju/testing" 16 ) 17 18 // Backing is a test state AllWatcherBacking 19 type Backing struct { 20 mu sync.Mutex 21 fetchErr error 22 entities map[multiwatcher.EntityID]multiwatcher.EntityInfo 23 watchc chan<- watcher.Change 24 txnRevno int64 25 } 26 27 // New returns a new test backing. 28 func New(initial []multiwatcher.EntityInfo) *Backing { 29 b := &Backing{ 30 entities: make(map[multiwatcher.EntityID]multiwatcher.EntityInfo), 31 } 32 for _, info := range initial { 33 b.entities[info.EntityID()] = info 34 } 35 return b 36 } 37 38 // Changed process the change event from the state base watcher. 39 func (b *Backing) Changed(store multiwatcher.Store, change watcher.Change) error { 40 modelUUID, changeID, ok := SplitDocID(change.Id.(string)) 41 if !ok { 42 return errors.Errorf("unexpected id format: %v", change.Id) 43 } 44 id := multiwatcher.EntityID{ 45 Kind: change.C, 46 ModelUUID: modelUUID, 47 ID: changeID, 48 } 49 info, err := b.fetch(id) 50 if errors.IsNotFound(err) { 51 store.Remove(id) 52 return nil 53 } 54 if err != nil { 55 return err 56 } 57 store.Update(info) 58 return nil 59 } 60 61 func (b *Backing) fetch(id multiwatcher.EntityID) (multiwatcher.EntityInfo, error) { 62 b.mu.Lock() 63 defer b.mu.Unlock() 64 if b.fetchErr != nil { 65 return nil, b.fetchErr 66 } 67 if info, ok := b.entities[id]; ok { 68 return info, nil 69 } 70 return nil, errors.NotFoundf("%s.%s", id.Kind, id.ID) 71 } 72 73 // Watch sets up the channel for the events. 74 func (b *Backing) Watch(c chan<- watcher.Change) { 75 b.mu.Lock() 76 defer b.mu.Unlock() 77 if b.watchc != nil { 78 panic("test backing can only watch once") 79 } 80 b.watchc = c 81 } 82 83 // Unwatch clears the channel for the events. 84 func (b *Backing) Unwatch(c chan<- watcher.Change) { 85 b.mu.Lock() 86 defer b.mu.Unlock() 87 if c != b.watchc { 88 panic("unwatching wrong channel") 89 } 90 b.watchc = nil 91 } 92 93 // GetAll does the initial population of the store. 94 func (b *Backing) GetAll(store multiwatcher.Store) error { 95 b.mu.Lock() 96 defer b.mu.Unlock() 97 for _, info := range b.entities { 98 store.Update(info) 99 } 100 return nil 101 } 102 103 // UpdateEntity allows the test to push an update. 104 func (b *Backing) UpdateEntity(info multiwatcher.EntityInfo) { 105 b.mu.Lock() 106 id := info.EntityID() 107 b.entities[id] = info 108 b.txnRevno++ 109 change := watcher.Change{ 110 C: id.Kind, 111 Id: EnsureModelUUID(id.ModelUUID, id.ID), 112 Revno: b.txnRevno, // This is actually ignored, but fill it in anyway. 113 } 114 listener := b.watchc 115 b.mu.Unlock() 116 if b.watchc != nil { 117 select { 118 case listener <- change: 119 case <-time.After(testing.LongWait): 120 panic("watcher isn't reading off channel") 121 122 } 123 } 124 } 125 126 // SetFetchError queues up an error to return on the next fetch. 127 func (b *Backing) SetFetchError(err error) { 128 b.mu.Lock() 129 defer b.mu.Unlock() 130 b.fetchErr = err 131 } 132 133 // DeleteEntity allows the test to push a delete through the test. 134 func (b *Backing) DeleteEntity(id multiwatcher.EntityID) { 135 b.mu.Lock() 136 delete(b.entities, id) 137 change := watcher.Change{ 138 C: id.Kind, 139 Id: EnsureModelUUID(id.ModelUUID, id.ID), 140 Revno: -1, 141 } 142 b.txnRevno++ 143 listener := b.watchc 144 b.mu.Unlock() 145 if b.watchc != nil { 146 select { 147 case listener <- change: 148 case <-time.After(testing.LongWait): 149 panic("watcher isn't reading off channel") 150 } 151 } 152 } 153 154 // EnsureModelUUID is exported as it is used in other _test files. 155 func EnsureModelUUID(modelUUID, id string) string { 156 prefix := modelUUID + ":" 157 if strings.HasPrefix(id, prefix) { 158 return id 159 } 160 return prefix + id 161 } 162 163 // SplitDocID is exported as it is used in other _test files. 164 func SplitDocID(id string) (string, string, bool) { 165 parts := strings.SplitN(id, ":", 2) 166 if len(parts) != 2 { 167 return "", "", false 168 } 169 return parts[0], parts[1], true 170 }