github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/worker/uniter/relation/statemanager.go (about)

     1  // Copyright 2020 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package relation
     5  
     6  import (
     7  	"fmt"
     8  	"sync"
     9  
    10  	"github.com/juju/collections/set"
    11  	"github.com/juju/errors"
    12  	"github.com/juju/names/v5"
    13  	"github.com/kr/pretty"
    14  	"gopkg.in/yaml.v2"
    15  
    16  	"github.com/juju/juju/rpc/params"
    17  )
    18  
    19  // NewStateManager creates a new StateManager instance.
    20  func NewStateManager(rw UnitStateReadWriter, logger Logger) (StateManager, error) {
    21  	mgr := &stateManager{unitStateRW: rw, logger: logger}
    22  	return mgr, mgr.initialize()
    23  }
    24  
    25  type stateManager struct {
    26  	unitStateRW   UnitStateReadWriter
    27  	relationState map[int]State
    28  	logger        Logger
    29  	mu            sync.Mutex
    30  }
    31  
    32  // Relation returns a copy of the relation state for
    33  // the given id. Returns NotFound.
    34  func (m *stateManager) Relation(id int) (*State, error) {
    35  	m.mu.Lock()
    36  	defer m.mu.Unlock()
    37  	if s, ok := m.relationState[id]; ok {
    38  		return s.copy(), nil
    39  	}
    40  	return nil, errors.NotFoundf("relation %d", id)
    41  }
    42  
    43  // RemoveRelation removes the state for the given id from the
    44  // manager.  The change to the manager is only made when the
    45  // data is successfully saved.
    46  func (m *stateManager) RemoveRelation(id int, unitGetter UnitGetter, knownUnits map[string]bool) error {
    47  	m.mu.Lock()
    48  	defer m.mu.Unlock()
    49  	st, ok := m.relationState[id]
    50  	if !ok {
    51  		return errors.NotFoundf("relation %d", id)
    52  	}
    53  
    54  	// Check that the member unit exists - if not we ignore it.
    55  	// Cache the known member units so we only do any look up once per unit.
    56  	knownMembers := set.NewStrings()
    57  	for unitName := range st.Members {
    58  		unitExists, ok := knownUnits[unitName]
    59  		if !ok {
    60  			_, err := unitGetter.Unit(names.NewUnitTag(unitName))
    61  			if err != nil && !params.IsCodeNotFoundOrCodeUnauthorized(err) {
    62  				return errors.Trace(err)
    63  			}
    64  			unitExists = err == nil
    65  			knownUnits[unitName] = unitExists
    66  		}
    67  		if !unitExists {
    68  			m.logger.Warningf("unit %v in relation %d no longer exists", unitName, id)
    69  			continue
    70  		}
    71  		knownMembers.Add(unitName)
    72  	}
    73  	if knownMembers.Size() != 0 {
    74  		return errors.New(fmt.Sprintf("cannot remove persisted state, relation %d has members: %v", id, knownMembers.SortedValues()))
    75  	}
    76  	if err := m.remove(id); err != nil {
    77  		return err
    78  	}
    79  	delete(m.relationState, id)
    80  	return nil
    81  }
    82  
    83  // KnownIDs returns a slice of relation ids, known to the
    84  // state manager.
    85  func (m *stateManager) KnownIDs() []int {
    86  	m.mu.Lock()
    87  	defer m.mu.Unlock()
    88  	ids := make([]int, len(m.relationState))
    89  	// 0 is a valid id, and it's the initial value of an int
    90  	// ensure the only 0 is the slice should be there.
    91  	i := 0
    92  	for k := range m.relationState {
    93  		ids[i] = k
    94  		i += 1
    95  	}
    96  	return ids
    97  }
    98  
    99  // SetRelation persists the given state, overwriting the previous
   100  // state for a given id or creating state at a new id. The change to
   101  // the manager is only made when the data is successfully saved.
   102  func (m *stateManager) SetRelation(st *State) error {
   103  	m.mu.Lock()
   104  	defer m.mu.Unlock()
   105  	if err := m.write(st); err != nil {
   106  		return errors.Annotatef(err, "could not persist relation %d state", st.RelationId)
   107  	}
   108  	m.relationState[st.RelationId] = *st
   109  	return nil
   110  }
   111  
   112  // RelationFound returns true if the state manager has a
   113  // state for the given id.
   114  func (m *stateManager) RelationFound(id int) bool {
   115  	m.mu.Lock()
   116  	defer m.mu.Unlock()
   117  	_, ok := m.relationState[id]
   118  	return ok
   119  }
   120  
   121  // initialize loads the current state into the manager.
   122  func (m *stateManager) initialize() error {
   123  	unitState, err := m.unitStateRW.State()
   124  	if err != nil && !errors.IsNotFound(err) {
   125  		return errors.Trace(err)
   126  	}
   127  	m.relationState = make(map[int]State, len(unitState.RelationState))
   128  	if m.logger.IsTraceEnabled() {
   129  		m.logger.Tracef("initialising state manager: %# v", pretty.Formatter(unitState.RelationState))
   130  	}
   131  	for k, v := range unitState.RelationState {
   132  		var state State
   133  		if err = yaml.Unmarshal([]byte(v), &state); err != nil {
   134  			return errors.Annotatef(err, "cannot unmarshall relation %d state", k)
   135  		}
   136  		m.relationState[k] = state
   137  	}
   138  	return nil
   139  }
   140  
   141  func (m *stateManager) write(st *State) error {
   142  	newSt, err := m.stateToPersist()
   143  	if err != nil {
   144  		return errors.Trace(err)
   145  	}
   146  	str, err := st.YamlString()
   147  	if err != nil {
   148  		return errors.Trace(err)
   149  	}
   150  	newSt[st.RelationId] = str
   151  	return m.unitStateRW.SetState(params.SetUnitStateArg{RelationState: &newSt})
   152  }
   153  
   154  func (m *stateManager) remove(id int) error {
   155  	newSt, err := m.stateToPersist()
   156  	if err != nil {
   157  		return errors.Trace(err)
   158  	}
   159  	delete(newSt, id)
   160  	return m.unitStateRW.SetState(params.SetUnitStateArg{RelationState: &newSt})
   161  }
   162  
   163  // stateToPersist transforms the relationState of this manager
   164  // into a form used for UnitStateReadWriter SetState.
   165  func (m *stateManager) stateToPersist() (map[int]string, error) {
   166  	newSt := make(map[int]string, len(m.relationState))
   167  	for k, v := range m.relationState {
   168  		str, err := v.YamlString()
   169  		if err != nil {
   170  			return newSt, errors.Trace(err)
   171  		}
   172  		newSt[k] = str
   173  	}
   174  	return newSt, nil
   175  }