// Source: github.com/Axway/agent-sdk@v1.1.101/pkg/api/client_test.go

     1  package api
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"net/url"
    10  	"os"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/Axway/agent-sdk/pkg/config"
    15  	"github.com/Axway/agent-sdk/pkg/util"
    16  	"github.com/stretchr/testify/assert"
    17  )
    18  
    19  type mocHTTPServer struct {
    20  	reqMethod    string
    21  	reqBody      []byte
    22  	reqUserAgent string
    23  
    24  	expectedHeader   map[string]string
    25  	processedHeaders bool
    26  
    27  	expectedQueryParams  map[string]string
    28  	processedQueryParams bool
    29  
    30  	respBody []byte
    31  	respCode int
    32  	server   *httptest.Server
    33  }
    34  
    35  func (m *mocHTTPServer) reset() {
    36  	m.reqMethod = ""
    37  	m.reqBody = nil
    38  	m.reqUserAgent = ""
    39  	m.expectedHeader = nil
    40  	m.processedHeaders = false
    41  	m.expectedQueryParams = nil
    42  	m.processedQueryParams = false
    43  	m.respBody = nil
    44  	m.respCode = 0
    45  }
    46  
    47  func (m *mocHTTPServer) handleReq(resp http.ResponseWriter, req *http.Request) {
    48  	m.reqMethod = req.Method
    49  	m.reqUserAgent = req.Header.Get("user-agent")
    50  	m.processHeaders(req)
    51  	m.processQueryParams(req)
    52  	m.readReqBody(req)
    53  	m.writeResponse(resp)
    54  }
    55  
    56  func (m *mocHTTPServer) processHeaders(req *http.Request) {
    57  	m.processedHeaders = true
    58  	for header, val := range m.expectedHeader {
    59  		reqHdrVal := req.Header.Get(header)
    60  		if val != reqHdrVal {
    61  			m.processedHeaders = false
    62  			return
    63  		}
    64  	}
    65  }
    66  
    67  func (m *mocHTTPServer) processQueryParams(req *http.Request) {
    68  	m.processedQueryParams = true
    69  	for param, val := range m.expectedQueryParams {
    70  		reqParamVal := req.URL.Query().Get(param)
    71  		if val != reqParamVal {
    72  			m.processedQueryParams = false
    73  			return
    74  		}
    75  	}
    76  }
    77  
    78  func (m *mocHTTPServer) readReqBody(req *http.Request) {
    79  	var reqBuffer bytes.Buffer
    80  	writer := bufio.NewWriter(&reqBuffer)
    81  	_, _ = io.CopyN(writer, req.Body, 1024)
    82  	m.reqBody = reqBuffer.Bytes()
    83  }
    84  
    85  func (m *mocHTTPServer) writeResponse(resp http.ResponseWriter) {
    86  	if m.respBody != nil && len(m.respBody) > 0 {
    87  		resp.Write(m.respBody)
    88  	} else {
    89  		resp.WriteHeader(m.respCode)
    90  	}
    91  }
    92  
    93  func newMockHTTPServer() *mocHTTPServer {
    94  	mockServer := &mocHTTPServer{}
    95  	mockServer.server = httptest.NewServer(http.HandlerFunc(mockServer.handleReq))
    96  	return mockServer
    97  }
    98  
    99  func TestNewClient(t *testing.T) {
   100  	tests := []struct {
   101  		name                string
   102  		tls                 config.TLSConfig
   103  		proxyURL            string
   104  		timeout             string
   105  		agentDialerExists   bool
   106  		expectedTimeout     time.Duration
   107  		expectedProxyScheme string
   108  	}{
   109  		{
   110  			name:                "insecure-no-proxy-default-timeout",
   111  			tls:                 nil,
   112  			proxyURL:            "",
   113  			timeout:             "",
   114  			agentDialerExists:   false,
   115  			expectedTimeout:     defaultTimeout,
   116  			expectedProxyScheme: "",
   117  		},
   118  		{
   119  			name:                "insecure-no-proxy-custom-timeout",
   120  			tls:                 nil,
   121  			proxyURL:            "",
   122  			timeout:             "30s",
   123  			agentDialerExists:   false,
   124  			expectedTimeout:     30 * time.Second,
   125  			expectedProxyScheme: "",
   126  		},
   127  		{
   128  			name:                "insecure-http-proxy-custom-timeout",
   129  			tls:                 nil,
   130  			proxyURL:            "http://localhost:8080",
   131  			timeout:             "30s",
   132  			agentDialerExists:   true,
   133  			expectedTimeout:     30 * time.Second,
   134  			expectedProxyScheme: "http",
   135  		},
   136  		{
   137  			name:                "secure-http-proxy-custom-timeout",
   138  			tls:                 config.NewTLSConfig(),
   139  			proxyURL:            "http://localhost:8080",
   140  			timeout:             "30s",
   141  			agentDialerExists:   true,
   142  			expectedTimeout:     30 * time.Second,
   143  			expectedProxyScheme: "http",
   144  		},
   145  	}
   146  	for _, tc := range tests {
   147  		t.Run(tc.name, func(t *testing.T) {
   148  			os.Setenv("HTTP_CLIENT_TIMEOUT", tc.timeout)
   149  			c := NewClient(tc.tls, tc.proxyURL)
   150  			hc, ok := c.(*httpClient)
   151  			assert.True(t, ok)
   152  			assert.Equal(t, hc.timeout, tc.expectedTimeout)
   153  			httpTransport := hc.httpClient.Transport.(*http.Transport)
   154  			assert.NotNil(t, httpTransport)
   155  			if tc.tls != nil {
   156  				assert.NotNil(t, httpTransport.TLSClientConfig)
   157  			} else {
   158  				assert.Nil(t, httpTransport.TLSClientConfig)
   159  			}
   160  			if tc.agentDialerExists {
   161  				assert.NotNil(t, hc.dialer)
   162  				assert.Equal(t, tc.expectedProxyScheme, hc.dialer.GetProxyScheme())
   163  			} else {
   164  				assert.Nil(t, hc.dialer)
   165  			}
   166  		})
   167  	}
   168  }
   169  
   170  func TestNewSingleEntryClient(t *testing.T) {
   171  	tests := []struct {
   172  		name               string
   173  		tls                config.TLSConfig
   174  		proxyURL           string
   175  		singleURL          string
   176  		singleEntryFilter  []string
   177  		expectedClientType string
   178  	}{
   179  		{
   180  			name:     "no-single-entry",
   181  			tls:      nil,
   182  			proxyURL: "",
   183  		},
   184  		{
   185  			name:              "insecure-no-proxy-default-timeout",
   186  			tls:               nil,
   187  			proxyURL:          "",
   188  			singleURL:         "http://test",
   189  			singleEntryFilter: []string{"http://abc"},
   190  		},
   191  	}
   192  
   193  	for _, tc := range tests {
   194  		t.Run(tc.name, func(t *testing.T) {
   195  			SetConfigAgent("", tc.singleURL, tc.singleEntryFilter)
   196  			c := NewSingleEntryClient(tc.tls, tc.proxyURL, defaultTimeout)
   197  			hc, ok := c.(*httpClient)
   198  			assert.True(t, ok)
   199  			httpTransport := hc.httpClient.Transport.(*http.Transport)
   200  			assert.NotNil(t, httpTransport)
   201  			if cfgAgent.singleURL != "" {
   202  				assert.NotNil(t, hc.dialer)
   203  				assert.Equal(t, len(hc.singleEntryHostMap), len(cfgAgent.singleEntryFilter))
   204  				singleEntryURL, _ := url.Parse(cfgAgent.singleURL)
   205  				singleEntryAddr := util.ParseAddr(singleEntryURL)
   206  				for _, filterURL := range cfgAgent.singleEntryFilter {
   207  					u, _ := url.Parse(filterURL)
   208  					mappedAddr, ok := hc.singleEntryHostMap[util.ParseAddr(u)]
   209  					assert.True(t, ok)
   210  					assert.Equal(t, singleEntryAddr, mappedAddr)
   211  				}
   212  			} else {
   213  				assert.Nil(t, hc.dialer)
   214  			}
   215  		})
   216  	}
   217  }
   218  
   219  func TestSend(t *testing.T) {
   220  	config.AgentTypeName = "Test"
   221  	config.AgentVersion = "1.0"
   222  	config.SDKVersion = "1.0"
   223  	httpServer := newMockHTTPServer()
   224  
   225  	tests := []struct {
   226  		name              string
   227  		method            string
   228  		url               string
   229  		queryParam        map[string]string
   230  		header            map[string]string
   231  		body              []byte
   232  		respBody          []byte
   233  		respCode          int
   234  		expectedUserAgent string
   235  		isErr             bool
   236  		envName           string
   237  		agentName         string
   238  		isDocker          bool
   239  		isGRPC            bool
   240  	}{
   241  		{
   242  			name:   "invalid-url",
   243  			url:    "socks://invalid-url",
   244  			method: GET,
   245  			isErr:  true,
   246  		},
   247  		{
   248  			name:              "get-request-with-queryparam-header",
   249  			url:               "http://test",
   250  			method:            GET,
   251  			queryParam:        map[string]string{"param1": "value1"},
   252  			header:            map[string]string{"header1": "value1"},
   253  			respCode:          200,
   254  			respBody:          []byte{},
   255  			envName:           "env",
   256  			isDocker:          false,
   257  			isGRPC:            true,
   258  			agentName:         "agent",
   259  			expectedUserAgent: "Test/1.0 SDK/1.0 env agent binary reactive",
   260  		},
   261  		{
   262  			name:              "post-request-with-response",
   263  			url:               "http://test",
   264  			method:            POST,
   265  			body:              []byte("test-req"),
   266  			respCode:          200,
   267  			respBody:          []byte("test-resp"),
   268  			envName:           "env",
   269  			isDocker:          true,
   270  			agentName:         "agent",
   271  			expectedUserAgent: "Test/1.0 SDK/1.0 env agent docker",
   272  		},
   273  		{
   274  			name:              "override-user-agent",
   275  			url:               "http://test",
   276  			method:            GET,
   277  			header:            map[string]string{"user-agent": "test"},
   278  			respCode:          401,
   279  			respBody:          []byte{},
   280  			envName:           "env",
   281  			isDocker:          true,
   282  			agentName:         "agent",
   283  			expectedUserAgent: "test",
   284  		},
   285  	}
   286  	for _, tc := range tests {
   287  		t.Run(tc.name, func(t *testing.T) {
   288  			ua := util.FormatUserAgent(
   289  				config.AgentTypeName,
   290  				config.AgentVersion,
   291  				config.SDKVersion,
   292  				tc.envName,
   293  				tc.agentName,
   294  				tc.isDocker,
   295  				tc.isGRPC)
   296  			SetConfigAgent(ua, httpServer.server.URL, []string{"http://test"})
   297  			httpServer.reset()
   298  			httpServer.respBody = tc.respBody
   299  			httpServer.respCode = tc.respCode
   300  			httpServer.expectedHeader = tc.header
   301  			httpServer.expectedQueryParams = tc.queryParam
   302  
   303  			c := NewSingleEntryClient(nil, "", defaultTimeout)
   304  			assert.NotNil(t, c)
   305  
   306  			req := Request{
   307  				Method:      tc.method,
   308  				URL:         tc.url,
   309  				QueryParams: tc.queryParam,
   310  				Headers:     tc.header,
   311  				Body:        tc.body,
   312  			}
   313  
   314  			res, err := c.Send(req)
   315  			if tc.isErr {
   316  				assert.NotNil(t, err)
   317  			} else {
   318  				assert.Nil(t, err)
   319  				assert.NotNil(t, res)
   320  				assert.Equal(t, tc.respCode, res.Code)
   321  				assert.Equal(t, tc.respBody, res.Body)
   322  				assert.Equal(t, tc.method, httpServer.reqMethod)
   323  				assert.Equal(t, tc.expectedUserAgent, httpServer.reqUserAgent)
   324  				assert.True(t, httpServer.processedHeaders)
   325  				assert.True(t, httpServer.processedQueryParams)
   326  			}
   327  		})
   328  	}
   329  }