github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/worker/raft/rafttransport/dialer.go (about) 1 // Copyright 2018 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package rafttransport 5 6 import ( 7 "bufio" 8 "io/ioutil" 9 "net" 10 "net/http" 11 "net/url" 12 "time" 13 14 "github.com/hashicorp/raft" 15 "github.com/juju/errors" 16 17 "github.com/juju/juju/api" 18 ) 19 20 // Dialer is a type that can be used for dialling a connection 21 // connecting to a raft endpoint using the configured path, and 22 // upgrading to a raft connection. 23 type Dialer struct { 24 // APIInfo is used for authentication. 25 APIInfo *api.Info 26 27 // DialRaw returns a connection to the HTTP server 28 // that is serving the raft endpoint. 29 DialRaw func(raft.ServerAddress, time.Duration) (net.Conn, error) 30 31 // Path is the path of the raft HTTP endpoint. 32 Path string 33 } 34 35 // Dial dials a new raft network connection to the controller agent 36 // with the tag identified by the given address. 37 // 38 // Based on code from https://github.com/CanonicalLtd/raft-http. 39 func (d *Dialer) Dial(addr raft.ServerAddress, timeout time.Duration) (net.Conn, error) { 40 request := &http.Request{ 41 Method: "GET", 42 URL: &url.URL{Path: d.Path}, 43 Proto: "HTTP/1.1", 44 ProtoMajor: 1, 45 ProtoMinor: 1, 46 Header: make(http.Header), 47 } 48 request.Header.Set("Upgrade", "raft") 49 if err := api.AuthHTTPRequest(request, d.APIInfo); err != nil { 50 return nil, errors.Trace(err) 51 } 52 53 logger.Infof("dialing %s", addr) 54 conn, err := d.DialRaw(addr, timeout) 55 if err != nil { 56 return nil, errors.Annotate(err, "dial failed") 57 } 58 59 if err := request.Write(conn); err != nil { 60 return nil, errors.Annotate(err, "sending HTTP request failed") 61 } 62 63 response, err := http.ReadResponse(bufio.NewReader(conn), request) 64 if err != nil { 65 return nil, errors.Annotate(err, "failed to read response") 66 } 67 defer response.Body.Close() 68 if response.StatusCode != http.StatusSwitchingProtocols { 69 if body, err := ioutil.ReadAll(response.Body); err == nil && len(body) != 0 { 70 logger.Tracef("response: %s", body) 71 } 72 return nil, errors.Errorf( 73 "expected status code %d, got %d", 74 http.StatusSwitchingProtocols, 75 response.StatusCode, 76 ) 77 } 78 if response.Header.Get("Upgrade") != "raft" { 79 return nil, errors.New("missing or unexpected Upgrade header in response") 80 } 81 return conn, err 82 }