github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/worker/multiwatcher/testbacking/backing.go (about)

     1  // Copyright 2019 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package testbacking
     5  
     6  import (
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/juju/errors"
    12  
    13  	"github.com/juju/juju/core/multiwatcher"
    14  	"github.com/juju/juju/state/watcher"
    15  	"github.com/juju/juju/testing"
    16  )
    17  
    18  // Backing is a test state AllWatcherBacking
    19  type Backing struct {
    20  	mu       sync.Mutex
    21  	fetchErr error
    22  	entities map[multiwatcher.EntityID]multiwatcher.EntityInfo
    23  	watchc   chan<- watcher.Change
    24  	txnRevno int64
    25  }
    26  
    27  // New returns a new test backing.
    28  func New(initial []multiwatcher.EntityInfo) *Backing {
    29  	b := &Backing{
    30  		entities: make(map[multiwatcher.EntityID]multiwatcher.EntityInfo),
    31  	}
    32  	for _, info := range initial {
    33  		b.entities[info.EntityID()] = info
    34  	}
    35  	return b
    36  }
    37  
    38  // Changed process the change event from the state base watcher.
    39  func (b *Backing) Changed(store multiwatcher.Store, change watcher.Change) error {
    40  	modelUUID, changeID, ok := SplitDocID(change.Id.(string))
    41  	if !ok {
    42  		return errors.Errorf("unexpected id format: %v", change.Id)
    43  	}
    44  	id := multiwatcher.EntityID{
    45  		Kind:      change.C,
    46  		ModelUUID: modelUUID,
    47  		ID:        changeID,
    48  	}
    49  	info, err := b.fetch(id)
    50  	if errors.IsNotFound(err) {
    51  		store.Remove(id)
    52  		return nil
    53  	}
    54  	if err != nil {
    55  		return err
    56  	}
    57  	store.Update(info)
    58  	return nil
    59  }
    60  
    61  func (b *Backing) fetch(id multiwatcher.EntityID) (multiwatcher.EntityInfo, error) {
    62  	b.mu.Lock()
    63  	defer b.mu.Unlock()
    64  	if b.fetchErr != nil {
    65  		return nil, b.fetchErr
    66  	}
    67  	if info, ok := b.entities[id]; ok {
    68  		return info, nil
    69  	}
    70  	return nil, errors.NotFoundf("%s.%s", id.Kind, id.ID)
    71  }
    72  
    73  // Watch sets up the channel for the events.
    74  func (b *Backing) Watch(c chan<- watcher.Change) {
    75  	b.mu.Lock()
    76  	defer b.mu.Unlock()
    77  	if b.watchc != nil {
    78  		panic("test backing can only watch once")
    79  	}
    80  	b.watchc = c
    81  }
    82  
    83  // Unwatch clears the channel for the events.
    84  func (b *Backing) Unwatch(c chan<- watcher.Change) {
    85  	b.mu.Lock()
    86  	defer b.mu.Unlock()
    87  	if c != b.watchc {
    88  		panic("unwatching wrong channel")
    89  	}
    90  	b.watchc = nil
    91  }
    92  
    93  // GetAll does the initial population of the store.
    94  func (b *Backing) GetAll(store multiwatcher.Store) error {
    95  	b.mu.Lock()
    96  	defer b.mu.Unlock()
    97  	for _, info := range b.entities {
    98  		store.Update(info)
    99  	}
   100  	return nil
   101  }
   102  
   103  // UpdateEntity allows the test to push an update.
   104  func (b *Backing) UpdateEntity(info multiwatcher.EntityInfo) {
   105  	b.mu.Lock()
   106  	id := info.EntityID()
   107  	b.entities[id] = info
   108  	b.txnRevno++
   109  	change := watcher.Change{
   110  		C:     id.Kind,
   111  		Id:    EnsureModelUUID(id.ModelUUID, id.ID),
   112  		Revno: b.txnRevno, // This is actually ignored, but fill it in anyway.
   113  	}
   114  	listener := b.watchc
   115  	b.mu.Unlock()
   116  	if b.watchc != nil {
   117  		select {
   118  		case listener <- change:
   119  		case <-time.After(testing.LongWait):
   120  			panic("watcher isn't reading off channel")
   121  
   122  		}
   123  	}
   124  }
   125  
   126  // SetFetchError queues up an error to return on the next fetch.
   127  func (b *Backing) SetFetchError(err error) {
   128  	b.mu.Lock()
   129  	defer b.mu.Unlock()
   130  	b.fetchErr = err
   131  }
   132  
   133  // DeleteEntity allows the test to push a delete through the test.
   134  func (b *Backing) DeleteEntity(id multiwatcher.EntityID) {
   135  	b.mu.Lock()
   136  	delete(b.entities, id)
   137  	change := watcher.Change{
   138  		C:     id.Kind,
   139  		Id:    EnsureModelUUID(id.ModelUUID, id.ID),
   140  		Revno: -1,
   141  	}
   142  	b.txnRevno++
   143  	listener := b.watchc
   144  	b.mu.Unlock()
   145  	if b.watchc != nil {
   146  		select {
   147  		case listener <- change:
   148  		case <-time.After(testing.LongWait):
   149  			panic("watcher isn't reading off channel")
   150  		}
   151  	}
   152  }
   153  
   154  // EnsureModelUUID is exported as it is used in other _test files.
   155  func EnsureModelUUID(modelUUID, id string) string {
   156  	prefix := modelUUID + ":"
   157  	if strings.HasPrefix(id, prefix) {
   158  		return id
   159  	}
   160  	return prefix + id
   161  }
   162  
   163  // SplitDocID is exported as it is used in other _test files.
   164  func SplitDocID(id string) (string, string, bool) {
   165  	parts := strings.SplitN(id, ":", 2)
   166  	if len(parts) != 2 {
   167  		return "", "", false
   168  	}
   169  	return parts[0], parts[1], true
   170  }