github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/worker/apiservercertwatcher/manifold.go (about)

     1  // Copyright 2017 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package apiservercertwatcher
     5  
     6  import (
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/juju/errors"
    13  	"github.com/juju/loggo"
    14  	"github.com/juju/utils/voyeur"
    15  	"gopkg.in/juju/worker.v1"
    16  	"gopkg.in/juju/worker.v1/dependency"
    17  	"gopkg.in/tomb.v2"
    18  
    19  	"github.com/juju/juju/agent"
    20  )
    21  
    22  var logger = loggo.GetLogger("juju.worker.apiservercertwatcher")
    23  
    24  type ManifoldConfig struct {
    25  	AgentName          string
    26  	AgentConfigChanged *voyeur.Value
    27  }
    28  
    29  // Manifold returns a dependency.Manifold which wraps an agent's
    30  // voyeur.Value which is set whenever the agent config is
    31  // changed. The manifold will not bounce when the certificates
    32  // change.
    33  //
    34  // The worker will watch for API server certificate changes,
    35  // and make the current value available via the manifold's Output.
    36  // The Output expects a pointer to a function of type:
    37  //    func() *tls.Certificate
    38  //
    39  // The resulting tls.Certificate's Leaf field will be set, to
    40  // ensure we only parse the certificate once. This allows the
    41  // consumer to obtain the associated DNS names.
    42  //
    43  // The manifold is intended to be a dependency for the apiserver.
    44  func Manifold(config ManifoldConfig) dependency.Manifold {
    45  	return dependency.Manifold{
    46  		Inputs: []string{config.AgentName},
    47  		Start: func(context dependency.Context) (worker.Worker, error) {
    48  			if config.AgentConfigChanged == nil {
    49  				return nil, errors.NotValidf("nil AgentConfigChanged")
    50  			}
    51  
    52  			var a agent.Agent
    53  			if err := context.Get(config.AgentName, &a); err != nil {
    54  				return nil, err
    55  			}
    56  
    57  			w := &apiserverCertWatcher{
    58  				agent:              a,
    59  				agentConfigChanged: config.AgentConfigChanged,
    60  			}
    61  			if err := w.update(); err != nil {
    62  				return nil, errors.Annotate(err, "parsing initial certificate")
    63  			}
    64  
    65  			w.tomb.Go(w.loop)
    66  			return w, nil
    67  		},
    68  		Output: outputFunc,
    69  	}
    70  }
    71  
    72  func outputFunc(in worker.Worker, out interface{}) error {
    73  	inWorker, _ := in.(*apiserverCertWatcher)
    74  	if inWorker == nil {
    75  		return errors.Errorf("in should be a %T; got a %T", inWorker, in)
    76  	}
    77  	outPointer, ok := out.(*func() *tls.Certificate)
    78  	if !ok {
    79  		return errors.Errorf("out should be %T; got %T", outPointer, out)
    80  	}
    81  	*outPointer = inWorker.getCurrent
    82  	return nil
    83  }
    84  
    85  type apiserverCertWatcher struct {
    86  	tomb               tomb.Tomb
    87  	agent              agent.Agent
    88  	agentConfigChanged *voyeur.Value
    89  
    90  	mu         sync.Mutex
    91  	currentRaw string
    92  	current    *tls.Certificate
    93  }
    94  
    95  func (w *apiserverCertWatcher) loop() error {
    96  	watch := w.agentConfigChanged.Watch()
    97  	defer watch.Close()
    98  	done := make(chan struct{})
    99  	defer close(done)
   100  
   101  	// TODO(axw) - this is pretty awful. There should be a
   102  	// NotifyWatcher for voyeur.Value. Note also that this code is
   103  	// repeated elsewhere.
   104  	watchCh := make(chan bool)
   105  	go func() {
   106  		defer close(watchCh)
   107  		for watch.Next() {
   108  			select {
   109  			case <-done:
   110  				return
   111  			case watchCh <- true:
   112  			}
   113  		}
   114  	}()
   115  
   116  	for {
   117  		// Always unconditionally check for a change first, in case
   118  		// there was a change between the start func and the call
   119  		// to Watch.
   120  		if err := w.update(); err != nil {
   121  			// We don't bounce the worker on bad certificate data.
   122  			logger.Errorf("cannot update certificate: %v", err)
   123  		}
   124  		select {
   125  		case <-w.tomb.Dying():
   126  			return tomb.ErrDying
   127  		case _, ok := <-watchCh:
   128  			if !ok {
   129  				return errors.New("config changed value closed")
   130  			}
   131  		}
   132  	}
   133  }
   134  
   135  // Kill implements worker.Worker.
   136  func (w *apiserverCertWatcher) Kill() {
   137  	w.tomb.Kill(nil)
   138  }
   139  
   140  // Wait implements worker.Worker.
   141  func (w *apiserverCertWatcher) Wait() error {
   142  	return w.tomb.Wait()
   143  }
   144  
   145  func (w *apiserverCertWatcher) update() error {
   146  	//logger.Errorf("cannot update certificate: %v", err)
   147  	config := w.agent.CurrentConfig()
   148  	info, ok := config.StateServingInfo()
   149  	if !ok {
   150  		return errors.New("no state serving info in agent config")
   151  	}
   152  	if info.Cert == "" {
   153  		return errors.New("certificate is empty")
   154  	}
   155  	if info.PrivateKey == "" {
   156  		return errors.New("private key is empty")
   157  	}
   158  	if info.Cert == w.currentRaw {
   159  		// No change.
   160  		return nil
   161  	}
   162  
   163  	tlsCert, err := tls.X509KeyPair([]byte(info.Cert), []byte(info.PrivateKey))
   164  	if err != nil {
   165  		return errors.Annotatef(err, "cannot create new TLS certificate")
   166  	}
   167  	x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
   168  	if err != nil {
   169  		return errors.Annotatef(err, "parsing x509 cert")
   170  	}
   171  	tlsCert.Leaf = x509Cert
   172  
   173  	w.currentRaw = info.Cert
   174  	w.mu.Lock()
   175  	w.current = &tlsCert
   176  	w.mu.Unlock()
   177  
   178  	var addr []string
   179  	for _, ip := range x509Cert.IPAddresses {
   180  		addr = append(addr, ip.String())
   181  	}
   182  	logger.Infof("new certificate addresses: %v", strings.Join(addr, ", "))
   183  	return nil
   184  }
   185  
   186  func (w *apiserverCertWatcher) getCurrent() *tls.Certificate {
   187  	w.mu.Lock()
   188  	defer w.mu.Unlock()
   189  	return w.current
   190  }