github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/worker/raft/rafttransport/worker.go (about)

     1  // Copyright 2018 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package rafttransport
     5  
     6  import (
     7  	"context"
     8  	"crypto/tls"
     9  	"log"
    10  	"net"
    11  	"net/http"
    12  	"time"
    13  
    14  	"github.com/hashicorp/raft"
    15  	"github.com/juju/clock"
    16  	"github.com/juju/errors"
    17  	"github.com/juju/loggo"
    18  	"github.com/juju/pubsub"
    19  	"github.com/juju/replicaset"
    20  	"gopkg.in/juju/worker.v1"
    21  	"gopkg.in/juju/worker.v1/catacomb"
    22  
    23  	"github.com/juju/juju/api"
    24  	"github.com/juju/juju/apiserver/apiserverhttp"
    25  	"github.com/juju/juju/apiserver/httpcontext"
    26  	"github.com/juju/juju/worker/raft/raftutil"
    27  )
    28  
    29  var (
    30  	logger = loggo.GetLogger("juju.worker.raft.rafttransport")
    31  )
    32  
    33  const (
    34  	maxPoolSize = replicaset.MaxPeers
    35  )
    36  
// Config is the configuration required for running an apiserver-based
// raft transport worker. It is passed by value to NewWorker, which
// validates it with Validate before use.
type Config struct {
	// APIInfo contains the information, excluding addresses,
	// required to connect to an API server. NewWorker requires it
	// to report exactly one API port, and uses its ModelTag as the
	// implied model for the raft HTTP endpoint.
	APIInfo *api.Info

	// Authenticator is the HTTP request authenticator to use for
	// the raft endpoint.
	Authenticator httpcontext.Authenticator

	// DialConn is the function to use for dialing connections to
	// other API servers. It is called with TLSConfig for every
	// outgoing dial request.
	DialConn DialConnFunc

	// Hub is the central hub to which the worker will subscribe
	// for notification of local address changes. It is handed to
	// the stream layer created by NewWorker.
	Hub *pubsub.StructuredHub

	// Mux is the API server HTTP mux into which the handler will
	// be installed. The handler is registered for GET requests on
	// Path and removed again when the worker stops.
	Mux *apiserverhttp.Mux

	// Path is the path of the raft HTTP endpoint.
	Path string

	// LocalID is the raft.ServerID of the agent running this worker.
	LocalID raft.ServerID

	// Timeout, if non-zero, is the timeout to apply to transport
	// operations. See raft.NetworkTransportConfig.Timeout for more
	// details.
	Timeout time.Duration

	// TLSConfig is the TLS configuration to use for making
	// connections to API servers. It is passed unchanged to DialConn.
	TLSConfig *tls.Config

	// Clock is used for timing out the Addr getter - if the
	// peergrouper isn't publishing good API addresses in a timely
	// fashion it's better to fail and log than to hang indefinitely.
	// It is handed to the stream layer created by NewWorker.
	Clock clock.Clock
}
    80  
// DialConnFunc is the type of function used by the transport for
// dialing a TLS connection to another API server. The worker
// will send an HTTP request over the connection to upgrade it.
//
// ctx bounds the dial, addr is the network address to connect to and
// tlsConfig is the TLS configuration to use for the connection.
type DialConnFunc func(ctx context.Context, addr string, tlsConfig *tls.Config) (net.Conn, error)
    85  
    86  // Validate validates the raft worker configuration.
    87  func (config Config) Validate() error {
    88  	if config.APIInfo == nil {
    89  		return errors.NotValidf("nil APIInfo")
    90  	}
    91  	if config.Authenticator == nil {
    92  		return errors.NotValidf("nil Authenticator")
    93  	}
    94  	if config.DialConn == nil {
    95  		return errors.NotValidf("nil DialConn")
    96  	}
    97  	if config.Hub == nil {
    98  		return errors.NotValidf("nil Hub")
    99  	}
   100  	if config.Mux == nil {
   101  		return errors.NotValidf("nil Mux")
   102  	}
   103  	if config.Path == "" {
   104  		return errors.NotValidf("empty Path")
   105  	}
   106  	if config.LocalID == "" {
   107  		return errors.NotValidf("empty LocalID")
   108  	}
   109  	if config.TLSConfig == nil {
   110  		return errors.NotValidf("nil TLSConfig")
   111  	}
   112  	return nil
   113  }
   114  
   115  // NewWorker returns a new apiserver-based raft transport worker,
   116  // with the given configuration. The worker itself implements
   117  // raft.Transport.
   118  func NewWorker(config Config) (worker.Worker, error) {
   119  	if err := config.Validate(); err != nil {
   120  		return nil, errors.Trace(err)
   121  	}
   122  
   123  	apiPorts := config.APIInfo.Ports()
   124  	if n := len(apiPorts); n != 1 {
   125  		return nil, errors.Errorf("api.Info has %d unique ports, expected 1", n)
   126  	}
   127  
   128  	w := &Worker{
   129  		config:       config,
   130  		connections:  make(chan net.Conn),
   131  		dialRequests: make(chan dialRequest),
   132  		apiPort:      apiPorts[0],
   133  	}
   134  
   135  	const logPrefix = "[transport] "
   136  	logWriter := &raftutil.LoggoWriter{logger, loggo.DEBUG}
   137  	logLogger := log.New(logWriter, logPrefix, 0)
   138  	stream, err := newStreamLayer(config.LocalID, config.Hub, w.connections, config.Clock, &Dialer{
   139  		APIInfo: config.APIInfo,
   140  		DialRaw: w.dialRaw,
   141  		Path:    config.Path,
   142  	})
   143  	if err != nil {
   144  		return nil, errors.Trace(err)
   145  	}
   146  	transport := raft.NewNetworkTransportWithConfig(&raft.NetworkTransportConfig{
   147  		Logger:  logLogger,
   148  		MaxPool: maxPoolSize,
   149  		Stream:  stream,
   150  		Timeout: config.Timeout,
   151  	})
   152  	w.Transport = transport
   153  
   154  	var h http.Handler = NewHandler(w.connections, w.catacomb.Dying())
   155  	h = &httpcontext.BasicAuthHandler{
   156  		Handler:       h,
   157  		Authenticator: w.config.Authenticator,
   158  		Authorizer:    httpcontext.AuthorizerFunc(controllerAuthorizer),
   159  	}
   160  	h = &httpcontext.ImpliedModelHandler{
   161  		Handler:   h,
   162  		ModelUUID: w.config.APIInfo.ModelTag.Id(),
   163  	}
   164  
   165  	w.config.Mux.AddHandler("GET", w.config.Path, h)
   166  
   167  	if err := catacomb.Invoke(catacomb.Plan{
   168  		Site: &w.catacomb,
   169  		Work: func() error {
   170  			defer transport.Close()
   171  			defer w.config.Mux.RemoveHandler("GET", w.config.Path)
   172  			return w.loop()
   173  		},
   174  		Init: []worker.Worker{stream},
   175  	}); err != nil {
   176  		transport.Close()
   177  		w.config.Mux.RemoveHandler("GET", w.config.Path)
   178  		return nil, errors.Trace(err)
   179  	}
   180  	return w, nil
   181  }
   182  
// Worker is a worker that manages a raft.Transport.
//
// The embedded raft.Transport is set by NewWorker to the network
// transport it creates, so a *Worker can be used wherever a
// raft.Transport is required.
type Worker struct {
	raft.Transport

	// catacomb tracks the lifetime of the worker and of the stream
	// layer worker registered with it in NewWorker.
	catacomb catacomb.Catacomb
	// config is the validated configuration given to NewWorker.
	config Config
	// connections is shared by the HTTP handler and the stream layer;
	// both are constructed with it in NewWorker.
	connections chan net.Conn
	// dialRequests carries outgoing dial requests from dialRaw to loop.
	dialRequests chan dialRequest
	// tlsConfig is not referenced by the code in this file; dialing
	// uses config.TLSConfig.
	// NOTE(review): looks unused - confirm before removing.
	tlsConfig *tls.Config
	// apiPort is the single port reported by config.APIInfo.Ports().
	apiPort int
}
   194  
// dialRequest is a request, sent by dialRaw to the worker loop, to dial
// a connection to another API server.
type dialRequest struct {
	// ctx bounds the dial; once it is done the result is no longer
	// wanted.
	ctx context.Context
	// address is the address passed to the configured DialConn function.
	address string
	// result is where the outcome of the dial is delivered.
	result chan<- dialResult
}
   200  
// dialResult is the outcome of a dialRequest: the values returned by
// the configured DialConn function.
type dialResult struct {
	// conn is the dialed connection.
	conn net.Conn
	// err is the error returned by the dial, if any.
	err error
}
   205  
   206  // Kill is part of the worker.Worker interface.
   207  func (w *Worker) Kill() {
   208  	w.catacomb.Kill(nil)
   209  }
   210  
   211  // Wait is part of the worker.Worker interface.
   212  func (w *Worker) Wait() error {
   213  	return w.catacomb.Wait()
   214  }
   215  
   216  // dialRaw dials a new TLS connection to the controller identified
   217  // by the given address. The address is expected to be the stringified
   218  // tag of a controller machine agent. The resulting connection is
   219  // appropriate for use as Dialer.DialRaw.
   220  func (w *Worker) dialRaw(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
   221  	// Give precedence to the worker dying.
   222  	select {
   223  	case <-w.catacomb.Dying():
   224  		return nil, w.errDialWorkerStopped()
   225  	default:
   226  	}
   227  
   228  	ctx := context.Background()
   229  	if timeout != 0 {
   230  		var cancel context.CancelFunc
   231  		ctx, cancel = context.WithTimeout(ctx, timeout)
   232  		defer cancel()
   233  	}
   234  
   235  	resultCh := make(chan dialResult)
   236  	req := dialRequest{
   237  		ctx:     ctx,
   238  		address: string(address),
   239  		result:  resultCh,
   240  	}
   241  	select {
   242  	case <-w.catacomb.Dying():
   243  		return nil, w.errDialWorkerStopped()
   244  	case <-ctx.Done():
   245  		return nil, dialRequestTimeoutError{}
   246  	case w.dialRequests <- req:
   247  	}
   248  
   249  	select {
   250  	case res := <-resultCh:
   251  		return res.conn, res.err
   252  	case <-ctx.Done():
   253  		return nil, dialRequestTimeoutError{}
   254  	case <-w.catacomb.Dying():
   255  		return nil, w.errDialWorkerStopped()
   256  	}
   257  }
   258  
   259  func (w *Worker) errDialWorkerStopped() error {
   260  	err := w.catacomb.Err()
   261  	if err != nil && err != w.catacomb.ErrDying() {
   262  		return dialWorkerStoppedError{err}
   263  	}
   264  	return dialWorkerStoppedError{
   265  		errors.New("worker stopped"),
   266  	}
   267  }
   268  
   269  func (w *Worker) loop() error {
   270  	for {
   271  		select {
   272  		case <-w.catacomb.Dying():
   273  			return w.catacomb.ErrDying()
   274  		case req := <-w.dialRequests:
   275  			go w.handleDial(req)
   276  		}
   277  	}
   278  }
   279  
   280  func (w *Worker) handleDial(req dialRequest) {
   281  	conn, err := w.config.DialConn(req.ctx, req.address, w.config.TLSConfig)
   282  	select {
   283  	case req.result <- dialResult{conn, err}:
   284  		return
   285  	case <-req.ctx.Done():
   286  	case <-w.catacomb.Dying():
   287  	}
   288  	if err == nil {
   289  		// result wasn't delivered, close connection
   290  		conn.Close()
   291  	}
   292  }
   293  
   294  // DialConn dials a TLS connection to the API server with the
   295  // given address, using the given TLS configuration. This will
   296  // be used for requesting the raft endpoint, upgrading to a
   297  // raw connection for inter-node raft communications.
   298  //
   299  // TODO: this function needs to be made proxy-aware.
   300  func DialConn(ctx context.Context, addr string, tlsConfig *tls.Config) (net.Conn, error) {
   301  	dialer := &net.Dialer{}
   302  	if deadline, ok := ctx.Deadline(); ok {
   303  		dialer.Deadline = deadline
   304  	}
   305  
   306  	ctx, cancel := context.WithCancel(ctx)
   307  	defer cancel()
   308  
   309  	canceled := make(chan struct{})
   310  	go func() {
   311  		<-ctx.Done()
   312  		if ctx.Err() == context.Canceled {
   313  			close(canceled)
   314  		}
   315  	}()
   316  	dialer.Cancel = canceled
   317  
   318  	return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig)
   319  }
   320  
   321  func controllerAuthorizer(authInfo httpcontext.AuthInfo) error {
   322  	if authInfo.Controller {
   323  		return nil
   324  	}
   325  	return errors.New("controller agents only")
   326  }