github.com/Axway/agent-sdk@v1.1.101/pkg/util/dialer_test.go (about) 1 package util 2 3 import ( 4 "context" 5 "encoding/base64" 6 "fmt" 7 "net" 8 "net/http" 9 "net/http/httptest" 10 "net/url" 11 "testing" 12 13 "github.com/stretchr/testify/assert" 14 ) 15 16 type mockTCPServer struct { 17 listener net.Listener 18 } 19 20 func newMockTCPServer() (*mockTCPServer, error) { 21 l, err := net.Listen("tcp", "localhost:0") 22 if err != nil { 23 return nil, err 24 } 25 26 server := &mockTCPServer{ 27 listener: l, 28 } 29 return server, nil 30 } 31 32 func (s *mockTCPServer) getAddr() string { 33 if s.listener != nil { 34 return s.listener.Addr().(*net.TCPAddr).String() 35 } 36 return "" 37 } 38 39 func (s *mockTCPServer) getIP() string { 40 if s.listener != nil { 41 return s.listener.Addr().(*net.TCPAddr).IP.String() 42 } 43 return "" 44 } 45 46 func (s *mockTCPServer) getPort() int { 47 if s.listener != nil { 48 return s.listener.Addr().(*net.TCPAddr).Port 49 } 50 return 0 51 } 52 53 func (s *mockTCPServer) close() { 54 if s.listener != nil { 55 s.listener.Close() 56 s.listener = nil 57 } 58 } 59 60 type mocHTTPServer struct { 61 responseStatus int 62 proxyAuth []string 63 server *httptest.Server 64 requestReceived bool 65 } 66 67 func (m *mocHTTPServer) handleReq(resp http.ResponseWriter, req *http.Request) { 68 m.requestReceived = true 69 proxyAuth, ok := req.Header["Proxy-Authorization"] 70 if ok { 71 m.proxyAuth = proxyAuth 72 resp.WriteHeader(m.responseStatus) 73 } 74 resp.WriteHeader(m.responseStatus) 75 } 76 77 func newMockHTTPServer() *mocHTTPServer { 78 mockServer := &mocHTTPServer{} 79 mockServer.server = httptest.NewServer(http.HandlerFunc(mockServer.handleReq)) 80 return mockServer 81 } 82 83 func TestProxyDial(t *testing.T) { 84 proxyURL, _ := url.Parse("http://localhost:8888") 85 dialer := NewDialer(proxyURL, nil) 86 conn, err := dialer.DialContext(context.Background(), "tcp", "testtarget") 87 assert.Nil(t, conn) 88 assert.NotNil(t, err) 89 90 proxyServer := newMockHTTPServer() 91 proxyURL, _ = url.Parse(proxyServer.server.URL) 92 dialer = NewDialer(proxyURL, nil) 93 proxyServer.responseStatus = 200 94 conn, err = dialer.DialContext(context.Background(), "tcp", "testtarget") 95 assert.NotNil(t, conn) 96 assert.Nil(t, err) 97 assert.Nil(t, proxyServer.proxyAuth) 98 99 proxyServer.responseStatus = 407 100 conn, err = dialer.DialContext(context.Background(), "tcp", "testtarget") 101 assert.Nil(t, conn) 102 assert.NotNil(t, err) 103 assert.Nil(t, proxyServer.proxyAuth) 104 105 proxyServer.responseStatus = 200 106 proxyAuthURL, _ := url.Parse("http://foo:bar@" + proxyURL.Host) 107 dialer = NewDialer(proxyAuthURL, nil) 108 conn, err = dialer.DialContext(context.Background(), "tcp", "testtarget") 109 assert.NotNil(t, conn) 110 assert.Nil(t, err) 111 assert.NotNil(t, proxyServer.proxyAuth) 112 assert.Equal(t, proxyServer.proxyAuth[0], "Basic "+base64.StdEncoding.EncodeToString([]byte("foo:bar"))) 113 } 114 115 func TestSingleEntryDial(t *testing.T) { 116 targetServer, _ := newMockTCPServer() 117 defer targetServer.close() 118 singleEntryServer, _ := newMockTCPServer() 119 defer singleEntryServer.close() 120 121 // No proxy, no single entry, validate connection directly to target server 122 targetServerURL, _ := url.Parse(fmt.Sprintf("https://%s", targetServer.getAddr())) 123 singleHostMapping := map[string]string{} 124 dialer := NewDialer(nil, singleHostMapping) 125 conn, err := dialer.Dial("tcp", targetServerURL.Host) 126 assert.NotNil(t, conn) 127 assert.Nil(t, err) 128 129 assert.Equal(t, targetServer.getIP(), conn.RemoteAddr().(*net.TCPAddr).IP.String()) 130 assert.Equal(t, targetServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port) 131 assert.NotEqual(t, singleEntryServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port) 132 133 // No proxy, single entry configured to match target, validate connection to single entry 134 singleHostMapping = map[string]string{ 135 targetServer.getAddr(): singleEntryServer.getAddr(), 136 } 137 dialer = NewDialer(nil, singleHostMapping) 138 conn, err = dialer.Dial("tcp", targetServerURL.Host) 139 assert.NotNil(t, conn) 140 assert.Nil(t, err) 141 142 assert.Equal(t, targetServer.getIP(), conn.RemoteAddr().(*net.TCPAddr).IP.String()) 143 assert.NotEqual(t, targetServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port) 144 assert.Equal(t, singleEntryServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port) 145 146 // Proxy configured, single entry configured to match target, validate connection to proxy 147 proxyServer := newMockHTTPServer() 148 proxyURL, _ := url.Parse(proxyServer.server.URL) 149 dialer = NewDialer(proxyURL, singleHostMapping) 150 proxyServer.responseStatus = 200 151 conn, err = dialer.Dial("tcp", targetServerURL.Host) 152 assert.NotNil(t, conn) 153 assert.Nil(t, err) 154 155 assert.Equal(t, targetServer.getIP(), conn.RemoteAddr().(*net.TCPAddr).IP.String()) 156 assert.NotEqual(t, targetServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port) 157 assert.NotEqual(t, singleEntryServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port) 158 assert.Equal(t, ParsePort(proxyURL), conn.RemoteAddr().(*net.TCPAddr).Port) 159 assert.Equal(t, true, proxyServer.requestReceived) 160 161 // Invalid proxy configured 162 proxyURL, _ = url.Parse("socks5://test:test@localhost:0") 163 dialer = NewDialer(proxyURL, singleHostMapping) 164 conn, err = dialer.Dial("tcp", targetServerURL.Host) 165 assert.Nil(t, conn) 166 assert.NotNil(t, err) 167 168 // Invalid proxy scheme 169 proxyURL, _ = url.Parse("noscheme://localhost:0") 170 dialer = NewDialer(proxyURL, singleHostMapping) 171 conn, err = dialer.Dial("tcp", targetServerURL.Host) 172 assert.Nil(t, conn) 173 assert.NotNil(t, err) 174 }