github.com/Axway/agent-sdk@v1.1.101/pkg/util/dialer_test.go (about)

     1  package util
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"testing"
    12  
    13  	"github.com/stretchr/testify/assert"
    14  )
    15  
    16  type mockTCPServer struct {
    17  	listener net.Listener
    18  }
    19  
    20  func newMockTCPServer() (*mockTCPServer, error) {
    21  	l, err := net.Listen("tcp", "localhost:0")
    22  	if err != nil {
    23  		return nil, err
    24  	}
    25  
    26  	server := &mockTCPServer{
    27  		listener: l,
    28  	}
    29  	return server, nil
    30  }
    31  
    32  func (s *mockTCPServer) getAddr() string {
    33  	if s.listener != nil {
    34  		return s.listener.Addr().(*net.TCPAddr).String()
    35  	}
    36  	return ""
    37  }
    38  
    39  func (s *mockTCPServer) getIP() string {
    40  	if s.listener != nil {
    41  		return s.listener.Addr().(*net.TCPAddr).IP.String()
    42  	}
    43  	return ""
    44  }
    45  
    46  func (s *mockTCPServer) getPort() int {
    47  	if s.listener != nil {
    48  		return s.listener.Addr().(*net.TCPAddr).Port
    49  	}
    50  	return 0
    51  }
    52  
    53  func (s *mockTCPServer) close() {
    54  	if s.listener != nil {
    55  		s.listener.Close()
    56  		s.listener = nil
    57  	}
    58  }
    59  
    60  type mocHTTPServer struct {
    61  	responseStatus  int
    62  	proxyAuth       []string
    63  	server          *httptest.Server
    64  	requestReceived bool
    65  }
    66  
    67  func (m *mocHTTPServer) handleReq(resp http.ResponseWriter, req *http.Request) {
    68  	m.requestReceived = true
    69  	proxyAuth, ok := req.Header["Proxy-Authorization"]
    70  	if ok {
    71  		m.proxyAuth = proxyAuth
    72  		resp.WriteHeader(m.responseStatus)
    73  	}
    74  	resp.WriteHeader(m.responseStatus)
    75  }
    76  
    77  func newMockHTTPServer() *mocHTTPServer {
    78  	mockServer := &mocHTTPServer{}
    79  	mockServer.server = httptest.NewServer(http.HandlerFunc(mockServer.handleReq))
    80  	return mockServer
    81  }
    82  
    83  func TestProxyDial(t *testing.T) {
    84  	proxyURL, _ := url.Parse("http://localhost:8888")
    85  	dialer := NewDialer(proxyURL, nil)
    86  	conn, err := dialer.DialContext(context.Background(), "tcp", "testtarget")
    87  	assert.Nil(t, conn)
    88  	assert.NotNil(t, err)
    89  
    90  	proxyServer := newMockHTTPServer()
    91  	proxyURL, _ = url.Parse(proxyServer.server.URL)
    92  	dialer = NewDialer(proxyURL, nil)
    93  	proxyServer.responseStatus = 200
    94  	conn, err = dialer.DialContext(context.Background(), "tcp", "testtarget")
    95  	assert.NotNil(t, conn)
    96  	assert.Nil(t, err)
    97  	assert.Nil(t, proxyServer.proxyAuth)
    98  
    99  	proxyServer.responseStatus = 407
   100  	conn, err = dialer.DialContext(context.Background(), "tcp", "testtarget")
   101  	assert.Nil(t, conn)
   102  	assert.NotNil(t, err)
   103  	assert.Nil(t, proxyServer.proxyAuth)
   104  
   105  	proxyServer.responseStatus = 200
   106  	proxyAuthURL, _ := url.Parse("http://foo:bar@" + proxyURL.Host)
   107  	dialer = NewDialer(proxyAuthURL, nil)
   108  	conn, err = dialer.DialContext(context.Background(), "tcp", "testtarget")
   109  	assert.NotNil(t, conn)
   110  	assert.Nil(t, err)
   111  	assert.NotNil(t, proxyServer.proxyAuth)
   112  	assert.Equal(t, proxyServer.proxyAuth[0], "Basic "+base64.StdEncoding.EncodeToString([]byte("foo:bar")))
   113  }
   114  
   115  func TestSingleEntryDial(t *testing.T) {
   116  	targetServer, _ := newMockTCPServer()
   117  	defer targetServer.close()
   118  	singleEntryServer, _ := newMockTCPServer()
   119  	defer singleEntryServer.close()
   120  
   121  	// No proxy, no single entry, validate connection directly to target server
   122  	targetServerURL, _ := url.Parse(fmt.Sprintf("https://%s", targetServer.getAddr()))
   123  	singleHostMapping := map[string]string{}
   124  	dialer := NewDialer(nil, singleHostMapping)
   125  	conn, err := dialer.Dial("tcp", targetServerURL.Host)
   126  	assert.NotNil(t, conn)
   127  	assert.Nil(t, err)
   128  
   129  	assert.Equal(t, targetServer.getIP(), conn.RemoteAddr().(*net.TCPAddr).IP.String())
   130  	assert.Equal(t, targetServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port)
   131  	assert.NotEqual(t, singleEntryServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port)
   132  
   133  	// No proxy, single entry configured to match target, validate connection to single entry
   134  	singleHostMapping = map[string]string{
   135  		targetServer.getAddr(): singleEntryServer.getAddr(),
   136  	}
   137  	dialer = NewDialer(nil, singleHostMapping)
   138  	conn, err = dialer.Dial("tcp", targetServerURL.Host)
   139  	assert.NotNil(t, conn)
   140  	assert.Nil(t, err)
   141  
   142  	assert.Equal(t, targetServer.getIP(), conn.RemoteAddr().(*net.TCPAddr).IP.String())
   143  	assert.NotEqual(t, targetServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port)
   144  	assert.Equal(t, singleEntryServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port)
   145  
   146  	// Proxy configured, single entry configured to match target, validate connection to proxy
   147  	proxyServer := newMockHTTPServer()
   148  	proxyURL, _ := url.Parse(proxyServer.server.URL)
   149  	dialer = NewDialer(proxyURL, singleHostMapping)
   150  	proxyServer.responseStatus = 200
   151  	conn, err = dialer.Dial("tcp", targetServerURL.Host)
   152  	assert.NotNil(t, conn)
   153  	assert.Nil(t, err)
   154  
   155  	assert.Equal(t, targetServer.getIP(), conn.RemoteAddr().(*net.TCPAddr).IP.String())
   156  	assert.NotEqual(t, targetServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port)
   157  	assert.NotEqual(t, singleEntryServer.getPort(), conn.RemoteAddr().(*net.TCPAddr).Port)
   158  	assert.Equal(t, ParsePort(proxyURL), conn.RemoteAddr().(*net.TCPAddr).Port)
   159  	assert.Equal(t, true, proxyServer.requestReceived)
   160  
   161  	// Invalid proxy configured
   162  	proxyURL, _ = url.Parse("socks5://test:test@localhost:0")
   163  	dialer = NewDialer(proxyURL, singleHostMapping)
   164  	conn, err = dialer.Dial("tcp", targetServerURL.Host)
   165  	assert.Nil(t, conn)
   166  	assert.NotNil(t, err)
   167  
   168  	// Invalid proxy scheme
   169  	proxyURL, _ = url.Parse("noscheme://localhost:0")
   170  	dialer = NewDialer(proxyURL, singleHostMapping)
   171  	conn, err = dialer.Dial("tcp", targetServerURL.Host)
   172  	assert.Nil(t, conn)
   173  	assert.NotNil(t, err)
   174  }