github.com/Axway/agent-sdk@v1.1.101/pkg/api/client.go (about)

     1  package api
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"os"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/Axway/agent-sdk/pkg/config"
    17  	"github.com/Axway/agent-sdk/pkg/util"
    18  	"github.com/Axway/agent-sdk/pkg/util/log"
    19  	"github.com/google/uuid"
    20  )
    21  
    22  // HTTP const definitions
    23  const (
    24  	GET    string = http.MethodGet
    25  	POST   string = http.MethodPost
    26  	PUT    string = http.MethodPut
    27  	PATCH  string = http.MethodPatch
    28  	DELETE string = http.MethodDelete
    29  
    30  	defaultTimeout     = time.Second * 60
    31  	responseBufferSize = 2048
    32  )
    33  
// Request is the request object used when communicating to an API.
//
// QueryParams are URL-encoded and appended to URL. When FormData is non-nil it
// is sent as an application/x-www-form-urlencoded body and Body is not used;
// otherwise Body is sent as the request body. Headers are set on the outgoing
// request; when they contain no User-Agent, the configured agent user agent is
// used instead.
type Request struct {
	Method      string
	URL         string
	QueryParams map[string]string
	Headers     map[string]string
	Body        []byte
	FormData    map[string]string
}
    43  
// Response is the response object given back when communicating to an API.
// Any HTTP status is reported through Code; it is not turned into an error.
type Response struct {
	Code    int                 // HTTP status code
	Body    []byte              // complete response body, read in full
	Headers map[string][]string // response headers as received
}
    50  
// Client sends a Request and returns the resulting Response. NewClient returns
// this package's HTTP implementation.
type Client interface {
	// Send issues the request. In this package's implementation an error is
	// returned only when the request cannot be prepared or sent, or its body
	// cannot be read; non-2xx statuses are returned as a Response.
	Send(request Request) (*Response, error)
}
    55  
// httpClient is the Client implementation returned by NewClient. It is created
// by newClient, adjusted by ClientOpt options, and finished by initialize.
type httpClient struct {
	// Client is embedded, but newClient never assigns it; the Send method
	// declared on *httpClient is the one callers actually invoke.
	Client
	logger             log.FieldLogger
	httpClient         *http.Client      // built by initialize via createClient
	timeout            time.Duration     // http.Client timeout and per-read cancellation window
	dialer             util.Dialer       // only set when a proxy or single entry URL is configured
	singleEntryHostMap map[string]string // util.ParseAddr(filtered URL) -> util.ParseAddr(single entry URL)
	singleURL          string            // single entry URL copied from the agent config by WithSingleURL
}
    65  
// configAgent holds the process-wide agent settings recorded by
// SetConfigAgent, which writes them while holding cfgAgentMutex.
type configAgent struct {
	singleURL         string   // single entry URL picked up by WithSingleURL
	singleEntryFilter []string // URLs whose addresses are mapped to singleURL; accumulated across SetConfigAgent calls
	userAgent         string   // User-Agent used when a Request supplies none
}
    71  
    72  var cfgAgent *configAgent
    73  var cfgAgentMutex *sync.Mutex
    74  
    75  func init() {
    76  	cfgAgent = &configAgent{}
    77  	cfgAgentMutex = &sync.Mutex{}
    78  }
    79  
    80  // SetConfigAgent -
    81  func SetConfigAgent(userAgent, singleURL string, singleEntryFilter []string) {
    82  	cfgAgentMutex.Lock()
    83  	defer cfgAgentMutex.Unlock()
    84  	cfgAgent.userAgent = userAgent
    85  	cfgAgent.singleURL = singleURL
    86  	if cfgAgent.singleEntryFilter != nil {
    87  		cfgAgent.singleEntryFilter = append(cfgAgent.singleEntryFilter, singleEntryFilter...)
    88  	} else {
    89  		cfgAgent.singleEntryFilter = singleEntryFilter
    90  	}
    91  }
    92  
// ClientOpt is a functional option applied by NewClient to the client before
// its transport is created by initialize.
type ClientOpt func(*httpClient)
    94  
    95  func WithTimeout(timeout time.Duration) func(*httpClient) {
    96  	return func(h *httpClient) {
    97  		h.timeout = timeout
    98  	}
    99  }
   100  
   101  func WithSingleURL() func(*httpClient) {
   102  	return func(h *httpClient) {
   103  		h.singleURL = ""
   104  		if cfgAgent != nil {
   105  			h.singleURL = cfgAgent.singleURL
   106  			if h.singleURL != "" {
   107  				h.singleEntryHostMap = initializeSingleEntryMapping(h.singleURL, cfgAgent.singleEntryFilter)
   108  			}
   109  		}
   110  	}
   111  }
   112  
   113  // NewClient - creates a new HTTP client
   114  func NewClient(tlsCfg config.TLSConfig, proxyURL string, options ...ClientOpt) Client {
   115  	timeout := getTimeoutFromEnvironment()
   116  	client := newClient(timeout)
   117  
   118  	for _, o := range options {
   119  		o(client)
   120  	}
   121  
   122  	client.initialize(tlsCfg, proxyURL)
   123  	return client
   124  }
   125  
   126  // NewClientWithTimeout - creates a new HTTP client, with a timeout
   127  func NewClientWithTimeout(tlsCfg config.TLSConfig, proxyURL string, timeout time.Duration) Client {
   128  	log.DeprecationWarningReplace("NewClientWithTimeout", "NewClient and WithTimeout optional func")
   129  	return NewClient(tlsCfg, proxyURL, WithTimeout(timeout))
   130  }
   131  
   132  // NewSingleEntryClient - creates a new HTTP client for single entry point with a timeout
   133  func NewSingleEntryClient(tlsCfg config.TLSConfig, proxyURL string, timeout time.Duration) Client {
   134  	log.DeprecationWarningReplace("NewSingleEntryClient", "NewClient and WithSingleURL optional func")
   135  	return NewClient(tlsCfg, proxyURL, WithTimeout(timeout), WithSingleURL())
   136  }
   137  
   138  func newClient(timeout time.Duration) *httpClient {
   139  	return &httpClient{
   140  		timeout: timeout,
   141  		logger: log.NewFieldLogger().
   142  			WithComponent("httpClient").
   143  			WithPackage("sdk.api"),
   144  	}
   145  }
   146  
   147  func initializeSingleEntryMapping(singleEntryURL string, singleEntryFilter []string) map[string]string {
   148  	hostMapping := make(map[string]string)
   149  	entryURL, err := url.Parse(singleEntryURL)
   150  	if err == nil {
   151  		for _, filteredURL := range singleEntryFilter {
   152  			svcURL, err := url.Parse(filteredURL)
   153  			if err == nil {
   154  				hostMapping[util.ParseAddr(svcURL)] = util.ParseAddr(entryURL)
   155  			}
   156  		}
   157  	}
   158  	return hostMapping
   159  }
   160  
   161  func parseProxyURL(proxyURL string) *url.URL {
   162  	if proxyURL != "" {
   163  		pURL, err := url.Parse(proxyURL)
   164  		if err == nil {
   165  			return pURL
   166  		}
   167  		log.Errorf("Error parsing proxyURL from config; creating a non-proxy client: %s", err.Error())
   168  	}
   169  	return nil
   170  }
   171  
   172  func (c *httpClient) initialize(tlsCfg config.TLSConfig, proxyURL string) {
   173  	c.httpClient = c.createClient(tlsCfg)
   174  	if c.singleURL == "" && proxyURL == "" {
   175  		return
   176  	}
   177  
   178  	c.dialer = util.NewDialer(parseProxyURL(proxyURL), c.singleEntryHostMap)
   179  	c.httpClient.Transport.(*http.Transport).DialContext = c.httpDialer
   180  }
   181  
   182  func (c *httpClient) createClient(tlsCfg config.TLSConfig) *http.Client {
   183  	if tlsCfg != nil {
   184  		return c.createHTTPSClient(tlsCfg)
   185  	}
   186  	return c.createHTTPClient()
   187  }
   188  
   189  func (c *httpClient) createHTTPClient() *http.Client {
   190  	httpClient := &http.Client{
   191  		Transport: &http.Transport{},
   192  		Timeout:   c.timeout,
   193  	}
   194  	return httpClient
   195  }
   196  
   197  func (c *httpClient) createHTTPSClient(tlsCfg config.TLSConfig) *http.Client {
   198  	httpClient := &http.Client{
   199  		Transport: &http.Transport{
   200  			TLSClientConfig: tlsCfg.BuildTLSConfig(),
   201  		},
   202  		Timeout: c.timeout,
   203  	}
   204  	return httpClient
   205  }
   206  
   207  func (c *httpClient) httpDialer(ctx context.Context, network, addr string) (net.Conn, error) {
   208  	return c.dialer.DialContext(ctx, network, addr)
   209  }
   210  
   211  func getTimeoutFromEnvironment() time.Duration {
   212  	cfgHTTPClientTimeout := os.Getenv("HTTP_CLIENT_TIMEOUT")
   213  	if cfgHTTPClientTimeout == "" {
   214  		return defaultTimeout
   215  	}
   216  	timeout, err := time.ParseDuration(cfgHTTPClientTimeout)
   217  	if err != nil {
   218  		log.Tracef("Unable to parse the HTTP_CLIENT_TIMEOUT value, using the default http client timeout")
   219  		return defaultTimeout
   220  	}
   221  	return timeout
   222  }
   223  
   224  func (c *httpClient) getURLEncodedQueryParams(queryParams map[string]string) string {
   225  	params := url.Values{}
   226  	for key, value := range queryParams {
   227  		params.Add(key, value)
   228  	}
   229  	return params.Encode()
   230  }
   231  
   232  func (c *httpClient) prepareAPIRequest(ctx context.Context, request Request) (*http.Request, error) {
   233  	requestURL := request.URL
   234  	if len(request.QueryParams) != 0 {
   235  		requestURL += "?" + c.getURLEncodedQueryParams(request.QueryParams)
   236  	}
   237  	var req *http.Request
   238  	var err error
   239  	if request.FormData != nil {
   240  		formData := make(url.Values)
   241  		for k, v := range request.FormData {
   242  			formData.Add(k, v)
   243  		}
   244  
   245  		req, err = http.NewRequestWithContext(ctx, request.Method, requestURL, strings.NewReader(formData.Encode()))
   246  		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   247  	} else {
   248  		req, err = http.NewRequestWithContext(ctx, request.Method, requestURL, bytes.NewBuffer(request.Body))
   249  	}
   250  
   251  	if err != nil {
   252  		return req, err
   253  	}
   254  	hasUserAgentHeader := false
   255  	for key, value := range request.Headers {
   256  		req.Header.Set(key, value)
   257  		if strings.ToLower(key) == "user-agent" {
   258  			hasUserAgentHeader = true
   259  		}
   260  	}
   261  	if !hasUserAgentHeader {
   262  		cfgAgentMutex.Lock()
   263  		defer cfgAgentMutex.Unlock()
   264  		req.Header.Set("User-Agent", cfgAgent.userAgent)
   265  	}
   266  	return req, err
   267  }
   268  
   269  func (c *httpClient) prepareAPIResponse(res *http.Response, timer *time.Timer) (*Response, error) {
   270  	var err error
   271  	var responeBuffer bytes.Buffer
   272  	writer := bufio.NewWriter(&responeBuffer)
   273  	for {
   274  		// Reset the timeout timer for reading the response
   275  		timer.Reset(c.timeout)
   276  		_, err = io.CopyN(writer, res.Body, responseBufferSize)
   277  		if err != nil {
   278  			if err == io.EOF {
   279  				err = nil
   280  			}
   281  			break
   282  		}
   283  	}
   284  
   285  	if err != nil {
   286  		return nil, err
   287  	}
   288  
   289  	response := Response{
   290  		Code:    res.StatusCode,
   291  		Body:    responeBuffer.Bytes(),
   292  		Headers: res.Header,
   293  	}
   294  	return &response, err
   295  }
   296  
   297  // Send - send the http request and returns the API Response
   298  func (c *httpClient) Send(request Request) (*Response, error) {
   299  	startTime := time.Now()
   300  	ctx := context.Background()
   301  	cancelCtx, cancel := context.WithCancel(ctx)
   302  	defer cancel()
   303  
   304  	req, err := c.prepareAPIRequest(cancelCtx, request)
   305  	if err != nil {
   306  		log.Errorf("Error preparing api request: %s", err.Error())
   307  		return nil, err
   308  	}
   309  	reqID := uuid.New().String()
   310  
   311  	// Logging for the HTTP request
   312  	statusCode := 0
   313  	receivedData := int64(0)
   314  	defer func() {
   315  		duration := time.Since(startTime)
   316  		targetURL := req.URL.String()
   317  		if c.dialer != nil {
   318  			svcHost := util.ParseAddr(req.URL)
   319  			if entryHost, ok := c.singleEntryHostMap[svcHost]; ok {
   320  				targetURL = req.URL.Scheme + "://" + entryHost + req.URL.Path
   321  			}
   322  		}
   323  
   324  		logger := c.logger.
   325  			WithField("id", reqID).
   326  			WithField("method", req.Method).
   327  			WithField("status", statusCode).
   328  			WithField("duration(ms)", duration.Milliseconds()).
   329  			WithField("url", targetURL)
   330  
   331  		if req.ContentLength > 0 {
   332  			logger = logger.WithField("sent(bytes)", req.ContentLength)
   333  		}
   334  
   335  		if receivedData > 0 {
   336  			logger = logger.WithField("received(bytes)", receivedData)
   337  		}
   338  
   339  		if err != nil {
   340  			logger.WithError(err).
   341  				Trace("request failed")
   342  		} else {
   343  			logger.Trace("request succeeded")
   344  		}
   345  	}()
   346  
   347  	// Start the timer to manage the timeout
   348  	timer := time.AfterFunc(c.timeout, func() {
   349  		cancel()
   350  	})
   351  
   352  	// Prevent reuse of the tcp connection to the same host
   353  	req.Close = true
   354  
   355  	if log.IsHTTPLogTraceEnabled() {
   356  		req = log.NewRequestWithTraceContext(reqID, req)
   357  	}
   358  	res, err := c.httpClient.Do(req)
   359  	if err != nil {
   360  		return nil, err
   361  	}
   362  	defer res.Body.Close()
   363  
   364  	statusCode = res.StatusCode
   365  	receivedData = res.ContentLength
   366  	parseResponse, err := c.prepareAPIResponse(res, timer)
   367  
   368  	return parseResponse, err
   369  }