github.com/opensearch-project/opensearch-go/v2@v2.3.0/opensearchtransport/opensearchtransport.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  //
     3  // The OpenSearch Contributors require contributions made to
     4  // this file be licensed under the Apache-2.0 license or a
     5  // compatible open source license.
     6  //
     7  // Modifications Copyright OpenSearch Contributors. See
     8  // GitHub history for details.
     9  
    10  // Licensed to Elasticsearch B.V. under one or more contributor
    11  // license agreements. See the NOTICE file distributed with
    12  // this work for additional information regarding copyright
    13  // ownership. Elasticsearch B.V. licenses this file to you under
    14  // the Apache License, Version 2.0 (the "License"); you may
    15  // not use this file except in compliance with the License.
    16  // You may obtain a copy of the License at
    17  //
    18  //    http://www.apache.org/licenses/LICENSE-2.0
    19  //
    20  // Unless required by applicable law or agreed to in writing,
    21  // software distributed under the License is distributed on an
    22  // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    23  // KIND, either express or implied.  See the License for the
    24  // specific language governing permissions and limitations
    25  // under the License.
    26  
    27  package opensearchtransport
    28  
import (
	"bytes"
	"compress/gzip"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"os"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/opensearch-project/opensearch-go/v2/signer"

	"github.com/opensearch-project/opensearch-go/v2/internal/version"
)
    52  
    53  const (
    54  	// Version returns the package version as a string.
    55  	Version = version.Client
    56  
    57  	// esCompatHeader defines the env var for Compatibility header.
    58  	esCompatHeader = "ELASTIC_CLIENT_APIVERSIONING"
    59  )
    60  
    61  var (
    62  	userAgent           string
    63  	compatibilityHeader bool
    64  	reGoVersion         = regexp.MustCompile(`go(\d+\.\d+\..+)`)
    65  
    66  	defaultMaxRetries    = 3
    67  	defaultRetryOnStatus = [...]int{502, 503, 504}
    68  )
    69  
    70  func init() {
    71  	userAgent = initUserAgent()
    72  
    73  	compatHeaderEnv := os.Getenv(esCompatHeader)
    74  	compatibilityHeader, _ = strconv.ParseBool(compatHeaderEnv)
    75  }
    76  
    77  // Interface defines the interface for HTTP client.
    78  type Interface interface {
    79  	Perform(*http.Request) (*http.Response, error)
    80  }
    81  
    82  // Config represents the configuration of HTTP client.
    83  type Config struct {
    84  	URLs     []*url.URL
    85  	Username string
    86  	Password string
    87  
    88  	Header http.Header
    89  	CACert []byte
    90  
    91  	Signer signer.Signer
    92  
    93  	RetryOnStatus        []int
    94  	DisableRetry         bool
    95  	EnableRetryOnTimeout bool
    96  	MaxRetries           int
    97  	RetryBackoff         func(attempt int) time.Duration
    98  
    99  	CompressRequestBody bool
   100  
   101  	EnableMetrics     bool
   102  	EnableDebugLogger bool
   103  
   104  	DiscoverNodesInterval time.Duration
   105  
   106  	Transport http.RoundTripper
   107  	Logger    Logger
   108  	Selector  Selector
   109  
   110  	ConnectionPoolFunc func([]*Connection, Selector) ConnectionPool
   111  }
   112  
   113  // Client represents the HTTP client.
   114  type Client struct {
   115  	sync.Mutex
   116  
   117  	urls     []*url.URL
   118  	username string
   119  	password string
   120  	header   http.Header
   121  
   122  	signer signer.Signer
   123  
   124  	retryOnStatus         []int
   125  	disableRetry          bool
   126  	enableRetryOnTimeout  bool
   127  	maxRetries            int
   128  	retryBackoff          func(attempt int) time.Duration
   129  	discoverNodesInterval time.Duration
   130  	discoverNodesTimer    *time.Timer
   131  
   132  	compressRequestBody bool
   133  
   134  	metrics *metrics
   135  
   136  	transport http.RoundTripper
   137  	logger    Logger
   138  	selector  Selector
   139  	pool      ConnectionPool
   140  	poolFunc  func([]*Connection, Selector) ConnectionPool
   141  }
   142  
   143  // New creates new transport client.
   144  //
   145  // http.DefaultTransport will be used if no transport is passed in the configuration.
   146  func New(cfg Config) (*Client, error) {
   147  	if cfg.Transport == nil {
   148  		cfg.Transport = http.DefaultTransport
   149  	}
   150  
   151  	if cfg.CACert != nil {
   152  		httpTransport, ok := cfg.Transport.(*http.Transport)
   153  		if !ok {
   154  			return nil, fmt.Errorf("unable to set CA certificate for transport of type %T", cfg.Transport)
   155  		}
   156  
   157  		httpTransport = httpTransport.Clone()
   158  		httpTransport.TLSClientConfig.RootCAs = x509.NewCertPool()
   159  
   160  		if ok := httpTransport.TLSClientConfig.RootCAs.AppendCertsFromPEM(cfg.CACert); !ok {
   161  			return nil, errors.New("unable to add CA certificate")
   162  		}
   163  
   164  		cfg.Transport = httpTransport
   165  	}
   166  
   167  	if len(cfg.RetryOnStatus) == 0 {
   168  		cfg.RetryOnStatus = defaultRetryOnStatus[:]
   169  	}
   170  
   171  	if cfg.MaxRetries == 0 {
   172  		cfg.MaxRetries = defaultMaxRetries
   173  	}
   174  
   175  	var conns []*Connection
   176  	for _, u := range cfg.URLs {
   177  		conns = append(conns, &Connection{URL: u})
   178  	}
   179  
   180  	client := Client{
   181  		urls:     cfg.URLs,
   182  		username: cfg.Username,
   183  		password: cfg.Password,
   184  		header:   cfg.Header,
   185  
   186  		signer: cfg.Signer,
   187  
   188  		retryOnStatus:         cfg.RetryOnStatus,
   189  		disableRetry:          cfg.DisableRetry,
   190  		enableRetryOnTimeout:  cfg.EnableRetryOnTimeout,
   191  		maxRetries:            cfg.MaxRetries,
   192  		retryBackoff:          cfg.RetryBackoff,
   193  		discoverNodesInterval: cfg.DiscoverNodesInterval,
   194  
   195  		compressRequestBody: cfg.CompressRequestBody,
   196  
   197  		transport: cfg.Transport,
   198  		logger:    cfg.Logger,
   199  		selector:  cfg.Selector,
   200  		poolFunc:  cfg.ConnectionPoolFunc,
   201  	}
   202  
   203  	if client.poolFunc != nil {
   204  		client.pool = client.poolFunc(conns, client.selector)
   205  	} else {
   206  		client.pool, _ = NewConnectionPool(conns, client.selector)
   207  	}
   208  
   209  	if cfg.EnableDebugLogger {
   210  		debugLogger = &debuggingLogger{Output: os.Stdout}
   211  	}
   212  
   213  	if cfg.EnableMetrics {
   214  		client.metrics = &metrics{responses: make(map[int]int)}
   215  		// TODO(karmi): Type assertion to interface
   216  		if pool, ok := client.pool.(*singleConnectionPool); ok {
   217  			pool.metrics = client.metrics
   218  		}
   219  		if pool, ok := client.pool.(*statusConnectionPool); ok {
   220  			pool.metrics = client.metrics
   221  		}
   222  	}
   223  
   224  	if client.discoverNodesInterval > 0 {
   225  		time.AfterFunc(client.discoverNodesInterval, func() {
   226  			client.scheduleDiscoverNodes(client.discoverNodesInterval)
   227  		})
   228  	}
   229  
   230  	return &client, nil
   231  }
   232  
   233  // Perform executes the request and returns a response or error.
   234  func (c *Client) Perform(req *http.Request) (*http.Response, error) {
   235  	var (
   236  		res *http.Response
   237  		err error
   238  	)
   239  
   240  	// Compatibility Header
   241  	if compatibilityHeader {
   242  		if req.Body != nil {
   243  			req.Header.Set("Content-Type", "application/vnd.elasticsearch+json;compatible-with=7")
   244  		}
   245  		req.Header.Set("Accept", "application/vnd.elasticsearch+json;compatible-with=7")
   246  	}
   247  
   248  	// Record metrics, when enabled
   249  	if c.metrics != nil {
   250  		c.metrics.Lock()
   251  		c.metrics.requests++
   252  		c.metrics.Unlock()
   253  	}
   254  
   255  	// Update request
   256  	c.setReqUserAgent(req)
   257  	c.setReqGlobalHeader(req)
   258  
   259  	if req.Body != nil && req.Body != http.NoBody {
   260  		if c.compressRequestBody {
   261  			var buf bytes.Buffer
   262  			zw := gzip.NewWriter(&buf)
   263  			if _, err := io.Copy(zw, req.Body); err != nil {
   264  				return nil, fmt.Errorf("failed to compress request body: %s", err)
   265  			}
   266  			if err := zw.Close(); err != nil {
   267  				return nil, fmt.Errorf("failed to compress request body (during close): %s", err)
   268  			}
   269  
   270  			req.GetBody = func() (io.ReadCloser, error) {
   271  				r := buf
   272  				return ioutil.NopCloser(&r), nil
   273  			}
   274  			req.Body, _ = req.GetBody()
   275  
   276  			req.Header.Set("Content-Encoding", "gzip")
   277  			req.ContentLength = int64(buf.Len())
   278  
   279  		} else if req.GetBody == nil {
   280  			if !c.disableRetry || (c.logger != nil && c.logger.RequestBodyEnabled()) {
   281  				var buf bytes.Buffer
   282  				buf.ReadFrom(req.Body)
   283  
   284  				req.GetBody = func() (io.ReadCloser, error) {
   285  					r := buf
   286  					return ioutil.NopCloser(&r), nil
   287  				}
   288  				req.Body, _ = req.GetBody()
   289  			}
   290  		}
   291  	}
   292  
   293  	for i := 0; i <= c.maxRetries; i++ {
   294  		var (
   295  			conn            *Connection
   296  			shouldRetry     bool
   297  			shouldCloseBody bool
   298  		)
   299  
   300  		// Get connection from the pool
   301  		c.Lock()
   302  		conn, err = c.pool.Next()
   303  		c.Unlock()
   304  		if err != nil {
   305  			if c.logger != nil {
   306  				c.logRoundTrip(req, nil, err, time.Time{}, time.Duration(0))
   307  			}
   308  			return nil, fmt.Errorf("cannot get connection: %s", err)
   309  		}
   310  
   311  		// Update request
   312  		c.setReqURL(conn.URL, req)
   313  		c.setReqAuth(conn.URL, req)
   314  
   315  		if err = c.signRequest(req); err != nil {
   316  			return nil, fmt.Errorf("failed to sign request: %s", err)
   317  		}
   318  
   319  		if !c.disableRetry && i > 0 && req.Body != nil && req.Body != http.NoBody {
   320  			body, err := req.GetBody()
   321  			if err != nil {
   322  				return nil, fmt.Errorf("cannot get request body: %s", err)
   323  			}
   324  			req.Body = body
   325  		}
   326  
   327  		// Set up time measures and execute the request
   328  		start := time.Now().UTC()
   329  		res, err = c.transport.RoundTrip(req)
   330  		dur := time.Since(start)
   331  
   332  		// Log request and response
   333  		if c.logger != nil {
   334  			if c.logger.RequestBodyEnabled() && req.Body != nil && req.Body != http.NoBody {
   335  				req.Body, _ = req.GetBody()
   336  			}
   337  			c.logRoundTrip(req, res, err, start, dur)
   338  		}
   339  
   340  		if err != nil {
   341  			// Record metrics, when enabled
   342  			if c.metrics != nil {
   343  				c.metrics.Lock()
   344  				c.metrics.failures++
   345  				c.metrics.Unlock()
   346  			}
   347  
   348  			// Report the connection as unsuccessful
   349  			c.Lock()
   350  			c.pool.OnFailure(conn)
   351  			c.Unlock()
   352  
   353  			// Retry on EOF errors
   354  			if err == io.EOF {
   355  				shouldRetry = true
   356  			}
   357  
   358  			// Retry on network errors, but not on timeout errors, unless configured
   359  			if err, ok := err.(net.Error); ok {
   360  				if (!err.Timeout() || c.enableRetryOnTimeout) && !c.disableRetry {
   361  					shouldRetry = true
   362  				}
   363  			}
   364  		} else {
   365  			// Report the connection as succesfull
   366  			c.Lock()
   367  			c.pool.OnSuccess(conn)
   368  			c.Unlock()
   369  		}
   370  
   371  		if res != nil && c.metrics != nil {
   372  			c.metrics.Lock()
   373  			c.metrics.responses[res.StatusCode]++
   374  			c.metrics.Unlock()
   375  		}
   376  
   377  		// Retry on configured response statuses
   378  		if res != nil && !c.disableRetry {
   379  			for _, code := range c.retryOnStatus {
   380  				if res.StatusCode == code {
   381  					shouldRetry = true
   382  					shouldCloseBody = true
   383  				}
   384  			}
   385  		}
   386  
   387  		// Break if retry should not be performed
   388  		if !shouldRetry {
   389  			break
   390  		}
   391  
   392  		// Drain and close body when retrying after response
   393  		if shouldCloseBody && i < c.maxRetries {
   394  			if res.Body != nil {
   395  				io.Copy(ioutil.Discard, res.Body)
   396  				res.Body.Close()
   397  			}
   398  		}
   399  
   400  		// Delay the retry if a backoff function is configured
   401  		if c.retryBackoff != nil {
   402  			time.Sleep(c.retryBackoff(i + 1))
   403  		}
   404  	}
   405  	// Read, close and replace the http reponse body to close the connection
   406  	if res != nil && res.Body != nil {
   407  		body, err := io.ReadAll(res.Body)
   408  		res.Body.Close()
   409  		if err == nil {
   410  			res.Body = io.NopCloser(bytes.NewReader(body))
   411  		}
   412  	}
   413  
   414  	// TODO(karmi): Wrap error
   415  	return res, err
   416  }
   417  
   418  // URLs returns a list of transport URLs.
   419  func (c *Client) URLs() []*url.URL {
   420  	return c.pool.URLs()
   421  }
   422  
   423  func (c *Client) setReqURL(u *url.URL, req *http.Request) *http.Request {
   424  	req.URL.Scheme = u.Scheme
   425  	req.URL.Host = u.Host
   426  
   427  	if u.Path != "" {
   428  		var b strings.Builder
   429  		b.Grow(len(u.Path) + len(req.URL.Path))
   430  		b.WriteString(u.Path)
   431  		b.WriteString(req.URL.Path)
   432  		req.URL.Path = b.String()
   433  	}
   434  
   435  	return req
   436  }
   437  
   438  func (c *Client) setReqAuth(u *url.URL, req *http.Request) *http.Request {
   439  	if _, ok := req.Header["Authorization"]; !ok {
   440  		if u.User != nil {
   441  			password, _ := u.User.Password()
   442  			req.SetBasicAuth(u.User.Username(), password)
   443  			return req
   444  		}
   445  
   446  		if c.username != "" && c.password != "" {
   447  			req.SetBasicAuth(c.username, c.password)
   448  			return req
   449  		}
   450  	}
   451  
   452  	return req
   453  }
   454  
   455  func (c *Client) signRequest(req *http.Request) error {
   456  	if c.signer != nil {
   457  		return c.signer.SignRequest(req)
   458  	}
   459  	return nil
   460  }
   461  
   462  func (c *Client) setReqUserAgent(req *http.Request) *http.Request {
   463  	req.Header.Set("User-Agent", userAgent)
   464  	return req
   465  }
   466  
   467  func (c *Client) setReqGlobalHeader(req *http.Request) *http.Request {
   468  	if len(c.header) > 0 {
   469  		for k, v := range c.header {
   470  			if req.Header.Get(k) != k {
   471  				for _, vv := range v {
   472  					req.Header.Add(k, vv)
   473  				}
   474  			}
   475  		}
   476  	}
   477  	return req
   478  }
   479  
   480  func (c *Client) logRoundTrip(
   481  	req *http.Request,
   482  	res *http.Response,
   483  	err error,
   484  	start time.Time,
   485  	dur time.Duration,
   486  ) {
   487  	var dupRes http.Response
   488  	if res != nil {
   489  		dupRes = *res
   490  	}
   491  	if c.logger.ResponseBodyEnabled() {
   492  		if res != nil && res.Body != nil && res.Body != http.NoBody {
   493  			b1, b2, _ := duplicateBody(res.Body)
   494  			dupRes.Body = b1
   495  			res.Body = b2
   496  		}
   497  	}
   498  	c.logger.LogRoundTrip(req, &dupRes, err, start, dur) // errcheck exclude
   499  }
   500  
   501  func initUserAgent() string {
   502  	var b strings.Builder
   503  
   504  	b.WriteString("opensearch-go")
   505  	b.WriteRune('/')
   506  	b.WriteString(Version)
   507  	b.WriteRune(' ')
   508  	b.WriteRune('(')
   509  	b.WriteString(runtime.GOOS)
   510  	b.WriteRune(' ')
   511  	b.WriteString(runtime.GOARCH)
   512  	b.WriteString("; ")
   513  	b.WriteString("Go ")
   514  	if v := reGoVersion.ReplaceAllString(runtime.Version(), "$1"); v != "" {
   515  		b.WriteString(v)
   516  	} else {
   517  		b.WriteString(runtime.Version())
   518  	}
   519  	b.WriteRune(')')
   520  
   521  	return b.String()
   522  }