github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/worker/raft/rafttransport/handler_test.go (about) 1 // Copyright 2018 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package rafttransport_test 5 6 import ( 7 "crypto/tls" 8 "net" 9 "net/http" 10 "net/http/httptest" 11 "net/url" 12 "time" 13 14 "github.com/hashicorp/raft" 15 "github.com/juju/testing" 16 jc "github.com/juju/testing/checkers" 17 gc "gopkg.in/check.v1" 18 19 "github.com/juju/juju/api" 20 coretesting "github.com/juju/juju/testing" 21 "github.com/juju/juju/worker/raft/rafttransport" 22 ) 23 24 type HandlerSuite struct { 25 testing.IsolationSuite 26 connections chan net.Conn 27 handler *rafttransport.Handler 28 server *httptest.Server 29 } 30 31 var _ = gc.Suite(&HandlerSuite{}) 32 33 func (s *HandlerSuite) SetUpTest(c *gc.C) { 34 s.IsolationSuite.SetUpTest(c) 35 s.connections = make(chan net.Conn) 36 s.handler = rafttransport.NewHandler(s.connections, nil) 37 s.server = httptest.NewTLSServer(s.handler) 38 s.AddCleanup(func(c *gc.C) { 39 s.server.Close() 40 }) 41 } 42 43 func (s *HandlerSuite) TestHandler(c *gc.C) { 44 u, err := url.Parse(s.server.URL) 45 c.Assert(err, jc.ErrorIsNil) 46 dialRaw := func(addr raft.ServerAddress, timeout time.Duration) (net.Conn, error) { 47 tlsConfig := s.server.Client().Transport.(*http.Transport).TLSClientConfig 48 return tls.Dial("tcp", u.Host, tlsConfig) 49 } 50 dialer := rafttransport.Dialer{ 51 APIInfo: &api.Info{}, 52 Path: "/raft", 53 DialRaw: dialRaw, 54 } 55 clientConn, err := dialer.Dial("", 0) 56 c.Assert(err, jc.ErrorIsNil) 57 defer clientConn.Close() 58 59 var serverConn net.Conn 60 select { 61 case conn := <-s.connections: 62 serverConn = conn 63 case <-time.After(coretesting.LongWait): 64 c.Fatal("timed out waiting for server connection") 65 } 66 defer serverConn.Close() 67 68 payload := "hello, server!" 69 n, err := clientConn.Write([]byte(payload)) 70 c.Assert(err, jc.ErrorIsNil) 71 72 read := make([]byte, n) 73 n, err = serverConn.Read(read) 74 c.Assert(err, jc.ErrorIsNil) 75 c.Assert(n, gc.Equals, len(payload)) 76 c.Assert(string(read), gc.Equals, payload) 77 }