istio.io/istio@v0.0.0-20240520182934-d79c90f27776/pkg/test/echo/server/forwarder/util.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package forwarder
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"fmt"
    21  	"net"
    22  	"net/http"
    23  	"net/url"
    24  	"sync"
    25  	"time"
    26  
    27  	"github.com/hashicorp/go-multierror"
    28  	"golang.org/x/net/proxy"
    29  
    30  	"istio.io/istio/pkg/hbone"
    31  	"istio.io/istio/pkg/log"
    32  	"istio.io/istio/pkg/test/echo"
    33  	"istio.io/istio/pkg/test/echo/common"
    34  	"istio.io/istio/pkg/test/echo/proto"
    35  )
    36  
    37  const (
    38  	hostHeader = "Host"
    39  )
    40  
    41  var fwLog = log.RegisterScope("forwarder", "echo clientside")
    42  
    43  func writeForwardedHeaders(out *bytes.Buffer, requestID int, header http.Header) {
    44  	for key, values := range header {
    45  		for _, v := range values {
    46  			echo.ForwarderHeaderField.WriteKeyValueForRequest(out, requestID, key, v)
    47  		}
    48  	}
    49  }
    50  
    51  func newDialer(cfg *Config) hbone.Dialer {
    52  	if cfg.Request.Hbone.GetAddress() != "" {
    53  		out := hbone.NewDialer(hbone.Config{
    54  			ProxyAddress: cfg.Request.Hbone.GetAddress(),
    55  			Headers:      cfg.hboneHeaders,
    56  			TLS:          cfg.hboneTLSConfig,
    57  		})
    58  		return out
    59  	}
    60  	proxyURL, _ := url.Parse(cfg.Proxy)
    61  	if len(cfg.Proxy) > 0 && proxyURL.Scheme == "socks5" {
    62  		dialer, _ := proxy.SOCKS5("tcp", proxyURL.Host, nil, proxy.Direct)
    63  		return dialer.(hbone.Dialer)
    64  	}
    65  	out := &net.Dialer{
    66  		Timeout: common.ConnectionTimeout,
    67  	}
    68  	if cfg.forceDNSLookup {
    69  		out.Resolver = newResolver(common.ConnectionTimeout, "", "")
    70  	}
    71  	return out
    72  }
    73  
    // newResolver returns a net.Resolver that prefers Go's built-in DNS client
    // and opens its connections with a net.Dialer limited to timeout.
    //
    // protocol, when non-empty, replaces the network passed to the resolver's
    // Dial hook; dnsServer, when non-empty, replaces the address. Empty values
    // leave the resolver's own network/address untouched.
    func newResolver(timeout time.Duration, protocol, dnsServer string) *net.Resolver {
    	dial := func(ctx context.Context, network, address string) (net.Conn, error) {
    		if protocol != "" {
    			network = protocol
    		}
    		if dnsServer != "" {
    			address = dnsServer
    		}
    		dialer := net.Dialer{Timeout: timeout}
    		return dialer.DialContext(ctx, network, address)
    	}
    	return &net.Resolver{PreferGo: true, Dial: dial}
    }
    93  
    94  // doForward sends the requests and collect the responses.
    95  func doForward(ctx context.Context, cfg *Config, e *executor, doReq func(context.Context, *Config, int) (string, error)) (*proto.ForwardEchoResponse, error) {
    96  	// make the timeout apply to the entire set of requests
    97  	ctx, cancel := context.WithTimeout(ctx, cfg.timeout)
    98  	defer cancel()
    99  
   100  	responses := make([]string, cfg.count)
   101  	responseTimes := make([]time.Duration, cfg.count)
   102  	var responsesMu sync.Mutex
   103  
   104  	var throttle *time.Ticker
   105  	qps := int(cfg.Request.Qps)
   106  	if qps > 0 {
   107  		sleepTime := time.Second / time.Duration(qps)
   108  		fwLog.Debugf("Sleeping %v between requests", sleepTime)
   109  		throttle = time.NewTicker(sleepTime)
   110  		defer throttle.Stop()
   111  	}
   112  
   113  	g := e.NewGroup()
   114  	for index := 0; index < cfg.count; index++ {
   115  		index := index
   116  		workFn := func() error {
   117  			st := time.Now()
   118  			resp, err := doReq(ctx, cfg, index)
   119  			if err != nil {
   120  				fwLog.Debugf("request failed: %v", err)
   121  				return err
   122  			}
   123  			fwLog.Debugf("got resp: %v", resp)
   124  
   125  			responsesMu.Lock()
   126  			responses[index] = resp
   127  			responseTimes[index] = time.Since(st)
   128  			responsesMu.Unlock()
   129  			return nil
   130  		}
   131  		if throttle != nil {
   132  			select {
   133  			case <-ctx.Done():
   134  				break
   135  			case <-throttle.C:
   136  			}
   137  		}
   138  
   139  		if cfg.PropagateResponse != nil {
   140  			workFn() // nolint: errcheck
   141  		} else {
   142  			g.Go(ctx, workFn)
   143  		}
   144  	}
   145  
   146  	// Convert the result of the wait into a channel.
   147  	requestsDone := make(chan *multierror.Error)
   148  	go func() {
   149  		requestsDone <- g.Wait()
   150  	}()
   151  
   152  	select {
   153  	case merr := <-requestsDone:
   154  		if err := merr.ErrorOrNil(); err != nil {
   155  			return nil, fmt.Errorf("%d/%d requests had errors; first error: %v", merr.Len(), cfg.count, merr.Errors[0])
   156  		}
   157  
   158  		return &proto.ForwardEchoResponse{
   159  			Output: responses,
   160  		}, nil
   161  	case <-ctx.Done():
   162  		responsesMu.Lock()
   163  		defer responsesMu.Unlock()
   164  
   165  		var c int
   166  		var tt time.Duration
   167  		for id, res := range responses {
   168  			if res != "" && responseTimes[id] != 0 {
   169  				c++
   170  				tt += responseTimes[id]
   171  			}
   172  		}
   173  		var avgTime time.Duration
   174  		if c > 0 {
   175  			avgTime = tt / time.Duration(c)
   176  		}
   177  		return nil, fmt.Errorf("request set timed out after %v and only %d/%d requests completed (%v avg)",
   178  			cfg.timeout, c, cfg.count, avgTime)
   179  	}
   180  }