github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/worker/raft/rafttransport/streamlayer.go (about)

     1  // Copyright 2018 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package rafttransport
     5  
     6  import (
     7  	"net"
     8  	"time"
     9  
    10  	"github.com/hashicorp/raft"
    11  	"github.com/juju/clock"
    12  	"github.com/juju/errors"
    13  	"github.com/juju/pubsub"
    14  	"gopkg.in/tomb.v2"
    15  
    16  	"github.com/juju/juju/pubsub/apiserver"
    17  )
    18  
    19  const (
    20  	// AddrTimeout is how long we'll wait for a good address to be
    21  	// sent before timing out in the Addr call - this is better than
    22  	// hanging indefinitely.
    23  	AddrTimeout = 1 * time.Minute
    24  )
    25  
    26  var (
    27  	// ErrAddressTimeout is used as the death reason when this transport dies because no good API address has been sent.
    28  	ErrAddressTimeout = errors.New("timed out waiting for API address")
    29  )
    30  
    31  func newStreamLayer(
    32  	localID raft.ServerID,
    33  	hub *pubsub.StructuredHub,
    34  	connections <-chan net.Conn,
    35  	clk clock.Clock,
    36  	dialer *Dialer,
    37  ) (*streamLayer, error) {
    38  	l := &streamLayer{
    39  		localID:     localID,
    40  		hub:         hub,
    41  		connections: connections,
    42  		dialer:      dialer,
    43  
    44  		addr:        make(chan net.Addr),
    45  		addrChanges: make(chan string),
    46  		clock:       clk,
    47  	}
    48  	// Watch for apiserver details changes, sending them
    49  	// down the "addrChanges" channel. The worker loop
    50  	// picks those up and makes the address available to
    51  	// the "Addr()" method.
    52  	unsubscribe, err := hub.Subscribe(apiserver.DetailsTopic, l.apiserverDetailsChanged)
    53  	if err != nil {
    54  		return nil, errors.Trace(err)
    55  	}
    56  
    57  	// Ask for the current details to be sent.
    58  	req := apiserver.DetailsRequest{
    59  		Requester: "raft-transport-stream-layer",
    60  		LocalOnly: true,
    61  	}
    62  	if _, err := hub.Publish(apiserver.DetailsRequestTopic, req); err != nil {
    63  		return nil, errors.Trace(err)
    64  	}
    65  
    66  	l.tomb.Go(func() error {
    67  		defer unsubscribe()
    68  		return l.loop()
    69  	})
    70  	return l, nil
    71  }
    72  
    73  // streamLayer represents the connection between raft nodes.
    74  //
    75  // Partially based on code from https://github.com/CanonicalLtd/raft-http.
    76  type streamLayer struct {
    77  	tomb        tomb.Tomb
    78  	localID     raft.ServerID
    79  	hub         *pubsub.StructuredHub
    80  	connections <-chan net.Conn
    81  	dialer      *Dialer
    82  	addr        chan net.Addr
    83  	addrChanges chan string
    84  	clock       clock.Clock
    85  }
    86  
    87  // Kill implements worker.Worker.
    88  func (l *streamLayer) Kill() {
    89  	l.tomb.Kill(nil)
    90  }
    91  
    92  // Wait implements worker.Worker.
    93  func (l *streamLayer) Wait() error {
    94  	return l.tomb.Wait()
    95  }
    96  
    97  // Accept waits for the next connection.
    98  func (l *streamLayer) Accept() (net.Conn, error) {
    99  	select {
   100  	case <-l.tomb.Dying():
   101  		return nil, errors.New("transport closed")
   102  	case conn := <-l.connections:
   103  		return conn, nil
   104  	}
   105  }
   106  
   107  // Close closes the layer.
   108  func (l *streamLayer) Close() error {
   109  	l.tomb.Kill(nil)
   110  	return l.tomb.Wait()
   111  }
   112  
   113  var invalidAddr = tcpAddr("address.invalid:0")
   114  
   115  // Addr returns the local address for the layer.
   116  func (l *streamLayer) Addr() net.Addr {
   117  	select {
   118  	case <-l.tomb.Dying():
   119  		return invalidAddr
   120  	case <-l.clock.After(AddrTimeout):
   121  		logger.Errorf("streamLayer.Addr timed out waiting for API address")
   122  		// Stop this (and parent) worker.
   123  		l.tomb.Kill(ErrAddressTimeout)
   124  		return invalidAddr
   125  	case addr := <-l.addr:
   126  		return addr
   127  	}
   128  }
   129  
   130  // Dial creates a new network connection.
   131  func (l *streamLayer) Dial(addr raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
   132  	return l.dialer.Dial(addr, timeout)
   133  }
   134  
   135  func (l *streamLayer) loop() error {
   136  	// Wait for the internal address of this agent,
   137  	// and then send it out on l.addr whenever possible.
   138  	var addr tcpAddr
   139  	var out chan<- net.Addr
   140  	for {
   141  		select {
   142  		case <-l.tomb.Dying():
   143  			return tomb.ErrDying
   144  		case newAddr := <-l.addrChanges:
   145  			if newAddr == "" || newAddr == string(addr) {
   146  				continue
   147  			}
   148  			addr = tcpAddr(newAddr)
   149  			out = l.addr
   150  		case out <- addr:
   151  		}
   152  	}
   153  }
   154  
   155  func (l *streamLayer) apiserverDetailsChanged(topic string, details apiserver.Details, err error) {
   156  	if err != nil {
   157  		l.tomb.Kill(err)
   158  		return
   159  	}
   160  	var addr string
   161  	for _, server := range details.Servers {
   162  		if raft.ServerID(server.ID) != l.localID {
   163  			continue
   164  		}
   165  		addr = server.InternalAddress
   166  		break
   167  	}
   168  	select {
   169  	case l.addrChanges <- addr:
   170  	case <-l.tomb.Dying():
   171  	}
   172  }
   173  
   174  // tcpAddr is an implementation of net.Addr which simply
   175  // returns the address reported via pubsub. This avoids
   176  // having to resolve the address just to get back the
   177  // string representation of the address, which is all that
   178  // the address is used for.
   179  type tcpAddr string
   180  
   181  // Network is part of the net.Addr interface.
   182  func (a tcpAddr) Network() string {
   183  	return "tcp"
   184  }
   185  
   186  // String is part of the net.Addr interface.
   187  func (a tcpAddr) String() string {
   188  	return string(a)
   189  }