github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/state/sshhostkeys.go (about)

     1  // Copyright 2016 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package state
     5  
     6  import (
     7  	"sort"
     8  
     9  	"github.com/juju/errors"
    10  	"github.com/juju/mgo/v3"
    11  	"github.com/juju/mgo/v3/bson"
    12  	"github.com/juju/mgo/v3/txn"
    13  	"github.com/juju/names/v5"
    14  	jujutxn "github.com/juju/txn/v3"
    15  )
    16  
    17  // SSHHostKeys holds the public SSH host keys for an entity (almost
    18  // certainly a machine).
    19  //
    20  // Each host key occupies one line (one element here) and uses the same
    21  // format as the SSH authorized_keys and ssh_host_key*.pub files.
    22  type SSHHostKeys []string
    23  
    24  // sshHostKeysDoc represents the MongoDB document that stores the SSH
    25  // host keys for an entity.
    26  //
    27  // The document id is not included as a field because it never needs to
    28  // be read and is only written indirectly, via the Id of a txn.Op.
    29  type sshHostKeysDoc struct {
    30  	Keys []string `bson:"keys"` // the entity's host keys, stored under "keys"
    31  }
    32  
    33  // GetSSHHostKeys retrieves the SSH host keys stored for an entity.
    34  // /
    35  // NOTE: Currently only machines are supported. This can be
    36  // generalised to take other tag types later, if and when we need it.
    37  func (st *State) GetSSHHostKeys(tag names.MachineTag) (SSHHostKeys, error) {
    38  	coll, closer := st.db().GetCollection(sshHostKeysC)
    39  	defer closer()
    40  
    41  	var doc sshHostKeysDoc
    42  	err := coll.FindId(machineGlobalKey(tag.Id())).One(&doc)
    43  	if err == mgo.ErrNotFound {
    44  		return nil, errors.NotFoundf("keys")
    45  	} else if err != nil {
    46  		return nil, errors.Annotate(err, "key lookup failed")
    47  	}
    48  	return SSHHostKeys(doc.Keys), nil
    49  }
    50  
    51  // keysEqual checks if the ssh host keys are the same between two sets.
    52  // we shouldn't care about the order of the keys.
    53  func keysEqual(a, b []string) bool {
    54  	if len(a) != len(b) {
    55  		return false
    56  	}
    57  	a = a[:]
    58  	b = b[:]
    59  	sort.Strings(a)
    60  	sort.Strings(b)
    61  	for i := range a {
    62  		if a[i] != b[i] {
    63  			return false
    64  		}
    65  	}
    66  	return true
    67  }
    68  
    69  // SetSSHHostKeys updates the stored SSH host keys for an entity.
    70  //
    71  // See the note for GetSSHHostKeys regarding supported entities.
    72  func (st *State) SetSSHHostKeys(tag names.MachineTag, keys SSHHostKeys) error {
    73  	coll, closer := st.db().GetCollection(sshHostKeysC)
    74  	defer closer()
    75  	id := machineGlobalKey(tag.Id())
    76  	doc := sshHostKeysDoc{
    77  		Keys: keys,
    78  	}
    79  	var dbDoc sshHostKeysDoc
    80  	buildTxn := func(attempt int) ([]txn.Op, error) {
    81  		err := coll.FindId(id).One(&dbDoc)
    82  		if err != nil {
    83  			if err == mgo.ErrNotFound {
    84  				return []txn.Op{{
    85  					C:      sshHostKeysC,
    86  					Id:     id,
    87  					Insert: doc,
    88  				}}, nil
    89  			}
    90  			return nil, err
    91  		}
    92  		if keysEqual(dbDoc.Keys, keys) {
    93  			return nil, jujutxn.ErrNoOperations
    94  		}
    95  		return []txn.Op{{
    96  			C:      sshHostKeysC,
    97  			Id:     id,
    98  			Update: bson.M{"$set": doc},
    99  		}}, nil
   100  	}
   101  
   102  	if err := st.db().Run(buildTxn); err != nil {
   103  		return errors.Annotate(err, "SSH host key update failed")
   104  	}
   105  	return nil
   106  }
   107  
   108  // removeSSHHostKeyOp returns the operation needed to remove the SSH
   109  // host key document associated with the given globalKey.
   110  func removeSSHHostKeyOp(globalKey string) txn.Op {
   111  	return txn.Op{
   112  		C:      sshHostKeysC,
   113  		Id:     globalKey,
   114  		Remove: true,
   115  	}
   116  }