github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/worker/raft/rafttransport/streamlayer.go (about) 1 // Copyright 2018 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package rafttransport 5 6 import ( 7 "net" 8 "time" 9 10 "github.com/hashicorp/raft" 11 "github.com/juju/clock" 12 "github.com/juju/errors" 13 "github.com/juju/pubsub" 14 "gopkg.in/tomb.v2" 15 16 "github.com/juju/juju/pubsub/apiserver" 17 ) 18 19 const ( 20 // AddrTimeout is how long we'll wait for a good address to be 21 // sent before timing out in the Addr call - this is better than 22 // hanging indefinitely. 23 AddrTimeout = 1 * time.Minute 24 ) 25 26 var ( 27 // ErrAddressTimeout is used as the death reason when this transport dies because no good API address has been sent. 28 ErrAddressTimeout = errors.New("timed out waiting for API address") 29 ) 30 31 func newStreamLayer( 32 localID raft.ServerID, 33 hub *pubsub.StructuredHub, 34 connections <-chan net.Conn, 35 clk clock.Clock, 36 dialer *Dialer, 37 ) (*streamLayer, error) { 38 l := &streamLayer{ 39 localID: localID, 40 hub: hub, 41 connections: connections, 42 dialer: dialer, 43 44 addr: make(chan net.Addr), 45 addrChanges: make(chan string), 46 clock: clk, 47 } 48 // Watch for apiserver details changes, sending them 49 // down the "addrChanges" channel. The worker loop 50 // picks those up and makes the address available to 51 // the "Addr()" method. 52 unsubscribe, err := hub.Subscribe(apiserver.DetailsTopic, l.apiserverDetailsChanged) 53 if err != nil { 54 return nil, errors.Trace(err) 55 } 56 57 // Ask for the current details to be sent. 58 req := apiserver.DetailsRequest{ 59 Requester: "raft-transport-stream-layer", 60 LocalOnly: true, 61 } 62 if _, err := hub.Publish(apiserver.DetailsRequestTopic, req); err != nil { 63 return nil, errors.Trace(err) 64 } 65 66 l.tomb.Go(func() error { 67 defer unsubscribe() 68 return l.loop() 69 }) 70 return l, nil 71 } 72 73 // streamLayer represents the connection between raft nodes. 74 // 75 // Partially based on code from https://github.com/CanonicalLtd/raft-http. 76 type streamLayer struct { 77 tomb tomb.Tomb 78 localID raft.ServerID 79 hub *pubsub.StructuredHub 80 connections <-chan net.Conn 81 dialer *Dialer 82 addr chan net.Addr 83 addrChanges chan string 84 clock clock.Clock 85 } 86 87 // Kill implements worker.Worker. 88 func (l *streamLayer) Kill() { 89 l.tomb.Kill(nil) 90 } 91 92 // Wait implements worker.Worker. 93 func (l *streamLayer) Wait() error { 94 return l.tomb.Wait() 95 } 96 97 // Accept waits for the next connection. 98 func (l *streamLayer) Accept() (net.Conn, error) { 99 select { 100 case <-l.tomb.Dying(): 101 return nil, errors.New("transport closed") 102 case conn := <-l.connections: 103 return conn, nil 104 } 105 } 106 107 // Close closes the layer. 108 func (l *streamLayer) Close() error { 109 l.tomb.Kill(nil) 110 return l.tomb.Wait() 111 } 112 113 var invalidAddr = tcpAddr("address.invalid:0") 114 115 // Addr returns the local address for the layer. 116 func (l *streamLayer) Addr() net.Addr { 117 select { 118 case <-l.tomb.Dying(): 119 return invalidAddr 120 case <-l.clock.After(AddrTimeout): 121 logger.Errorf("streamLayer.Addr timed out waiting for API address") 122 // Stop this (and parent) worker. 123 l.tomb.Kill(ErrAddressTimeout) 124 return invalidAddr 125 case addr := <-l.addr: 126 return addr 127 } 128 } 129 130 // Dial creates a new network connection. 131 func (l *streamLayer) Dial(addr raft.ServerAddress, timeout time.Duration) (net.Conn, error) { 132 return l.dialer.Dial(addr, timeout) 133 } 134 135 func (l *streamLayer) loop() error { 136 // Wait for the internal address of this agent, 137 // and then send it out on l.addr whenever possible. 138 var addr tcpAddr 139 var out chan<- net.Addr 140 for { 141 select { 142 case <-l.tomb.Dying(): 143 return tomb.ErrDying 144 case newAddr := <-l.addrChanges: 145 if newAddr == "" || newAddr == string(addr) { 146 continue 147 } 148 addr = tcpAddr(newAddr) 149 out = l.addr 150 case out <- addr: 151 } 152 } 153 } 154 155 func (l *streamLayer) apiserverDetailsChanged(topic string, details apiserver.Details, err error) { 156 if err != nil { 157 l.tomb.Kill(err) 158 return 159 } 160 var addr string 161 for _, server := range details.Servers { 162 if raft.ServerID(server.ID) != l.localID { 163 continue 164 } 165 addr = server.InternalAddress 166 break 167 } 168 select { 169 case l.addrChanges <- addr: 170 case <-l.tomb.Dying(): 171 } 172 } 173 174 // tcpAddr is an implementation of net.Addr which simply 175 // returns the address reported via pubsub. This avoids 176 // having to resolve the address just to get back the 177 // string representation of the address, which is all that 178 // the address is used for. 179 type tcpAddr string 180 181 // Network is part of the net.Addr interface. 182 func (a tcpAddr) Network() string { 183 return "tcp" 184 } 185 186 // String is part of the net.Addr interface. 187 func (a tcpAddr) String() string { 188 return string(a) 189 }