github.com/yandex/pandora@v0.5.32/components/guns/http_scenario/gun.go (about)

     1  package httpscenario
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"io"
     7  	"math/rand"
     8  	"net"
     9  	"net/http"
    10  	"net/http/httptrace"
    11  	"net/http/httputil"
    12  	"strconv"
    13  	"strings"
    14  	"time"
    15  
    16  	phttp "github.com/yandex/pandora/components/guns/http"
    17  	"github.com/yandex/pandora/core"
    18  	"github.com/yandex/pandora/core/aggregator/netsample"
    19  	"github.com/yandex/pandora/core/warmup"
    20  	"go.uber.org/zap"
    21  )
    22  
    23  type Gun interface {
    24  	Shoot(ammo *Scenario)
    25  	Bind(sample netsample.Aggregator, deps core.GunDeps) error
    26  	WarmUp(opts *warmup.Options) (any, error)
    27  }
    28  
    29  const (
    30  	EmptyTag = "__EMPTY__"
    31  )
    32  
    33  type ScenarioGun struct {
    34  	base *phttp.BaseGun
    35  }
    36  
    37  var _ Gun = (*ScenarioGun)(nil)
    38  var _ io.Closer = (*ScenarioGun)(nil)
    39  
    40  func (g *ScenarioGun) WarmUp(opts *warmup.Options) (any, error) {
    41  	return g.base.WarmUp(opts)
    42  }
    43  
    44  func (g *ScenarioGun) Bind(aggregator netsample.Aggregator, deps core.GunDeps) error {
    45  	return g.base.Bind(aggregator, deps)
    46  }
    47  
    48  // Shoot is thread safe if Do and Connect hooks are thread safe.
    49  func (g *ScenarioGun) Shoot(ammo *Scenario) {
    50  	if g.base.Aggregator == nil {
    51  		zap.L().Panic("must bind before shoot")
    52  	}
    53  	if g.base.Connect != nil {
    54  		err := g.base.Connect(g.base.Ctx)
    55  		if err != nil {
    56  			g.base.Log.Warn("Connect fail", zap.Error(err))
    57  			return
    58  		}
    59  	}
    60  
    61  	templateVars := map[string]any{
    62  		"source": ammo.VariableStorage.Variables(),
    63  	}
    64  
    65  	err := g.shoot(ammo, templateVars)
    66  	if err != nil {
    67  		g.base.Log.Warn("Invalid ammo", zap.Uint64("request", ammo.ID), zap.Error(err))
    68  		return
    69  	}
    70  }
    71  
    72  func (g *ScenarioGun) Do(req *http.Request) (*http.Response, error) {
    73  	return g.base.Client.Do(req)
    74  }
    75  
    76  func (g *ScenarioGun) Close() error {
    77  	if g.base.OnClose != nil {
    78  		return g.base.OnClose()
    79  	}
    80  	return nil
    81  }
    82  
    83  func (g *ScenarioGun) shoot(ammo *Scenario, templateVars map[string]any) error {
    84  	if templateVars == nil {
    85  		templateVars = map[string]any{}
    86  	}
    87  
    88  	requestVars := map[string]any{}
    89  	templateVars["request"] = requestVars
    90  
    91  	startAt := time.Now()
    92  	var idBuilder strings.Builder
    93  	rnd := strconv.Itoa(rand.Int())
    94  	for _, req := range ammo.Requests {
    95  		tag := ammo.Name + "." + req.Name
    96  		g.buildLogID(&idBuilder, tag, ammo.ID, rnd)
    97  		sample := netsample.Acquire(tag)
    98  
    99  		err := g.shootStep(req, sample, ammo.Name, templateVars, requestVars, idBuilder.String())
   100  		if err != nil {
   101  			g.reportErr(sample, err)
   102  			return err
   103  		}
   104  	}
   105  	spent := time.Since(startAt)
   106  	if ammo.MinWaitingTime > spent {
   107  		time.Sleep(ammo.MinWaitingTime - spent)
   108  	}
   109  	return nil
   110  }
   111  
   112  func (g *ScenarioGun) shootStep(step Request, sample *netsample.Sample, ammoName string, templateVars map[string]any, requestVars map[string]any, stepLogID string) error {
   113  	const op = "base_gun.shootStep"
   114  
   115  	stepVars := map[string]any{}
   116  	requestVars[step.Name] = stepVars
   117  
   118  	// Preprocessor
   119  	if step.Preprocessor != nil {
   120  		preProcVars, err := step.Preprocessor.Process(templateVars)
   121  		if err != nil {
   122  			return fmt.Errorf("%s preProcessor %w", op, err)
   123  		}
   124  		stepVars["preprocessor"] = preProcVars
   125  		if g.base.DebugLog {
   126  			g.base.GunDeps.Log.Debug("Preprocessor variables", zap.Any(fmt.Sprintf(".request.%s.preprocessor", step.Name), preProcVars))
   127  		}
   128  	}
   129  
   130  	// Entities
   131  	reqParts := RequestParts{
   132  		URL:     step.URI,
   133  		Method:  step.Method,
   134  		Body:    step.GetBody(),
   135  		Headers: step.GetHeaders(),
   136  	}
   137  
   138  	// Template
   139  	if err := step.Templater.Apply(&reqParts, templateVars, ammoName, step.Name); err != nil {
   140  		return fmt.Errorf("%s templater.Apply %w", op, err)
   141  	}
   142  
   143  	// Prepare request
   144  	req, err := g.prepareRequest(reqParts)
   145  	if err != nil {
   146  		return fmt.Errorf("%s prepareRequest %w", op, err)
   147  	}
   148  
   149  	var reqBytes []byte
   150  	if g.base.Config.AnswLog.Enabled {
   151  		var dumpErr error
   152  		reqBytes, dumpErr = httputil.DumpRequestOut(req, true)
   153  		if dumpErr != nil {
   154  			g.base.Log.Error("Error dumping request:", zap.Error(dumpErr))
   155  		}
   156  	}
   157  
   158  	timings, req := g.initTracing(req, sample)
   159  
   160  	resp, err := g.base.Client.Do(req)
   161  
   162  	g.saveTrace(timings, sample, resp)
   163  
   164  	if err != nil {
   165  		return fmt.Errorf("%s g.Do %w", op, err)
   166  	}
   167  
   168  	// Log
   169  	processors := step.Postprocessors
   170  	var respBody *bytes.Reader
   171  	var respBodyBytes []byte
   172  	if g.base.Config.AnswLog.Enabled || g.base.DebugLog || len(processors) > 0 {
   173  		respBodyBytes, err = io.ReadAll(resp.Body)
   174  		if err == nil {
   175  			respBody = bytes.NewReader(respBodyBytes)
   176  		}
   177  	} else {
   178  		_, err = io.Copy(io.Discard, resp.Body)
   179  	}
   180  	if err != nil {
   181  		return fmt.Errorf("%s io.Copy %w", op, err)
   182  	}
   183  	defer func() {
   184  		closeErr := resp.Body.Close()
   185  		if closeErr != nil {
   186  			g.base.GunDeps.Log.Error("resp.Body.Close", zap.Error(closeErr))
   187  		}
   188  	}()
   189  
   190  	if g.base.DebugLog {
   191  		g.verboseLogging(resp, reqBytes, respBodyBytes)
   192  	}
   193  	if g.base.Config.AnswLog.Enabled {
   194  		g.answReqRespLogging(reqBytes, resp, respBodyBytes, stepLogID)
   195  	}
   196  
   197  	// Postprocessor
   198  	postprocessorVars := map[string]any{}
   199  	var vars map[string]any
   200  	for _, postprocessor := range processors {
   201  		vars, err = postprocessor.Process(resp, respBody)
   202  		if err != nil {
   203  			return fmt.Errorf("%s postprocessor.Postprocess %w", op, err)
   204  		}
   205  		for k, v := range vars {
   206  			postprocessorVars[k] = v
   207  		}
   208  		_, err = respBody.Seek(0, io.SeekStart)
   209  		if err != nil {
   210  			return fmt.Errorf("%s postprocessor.Postprocess %w", op, err)
   211  		}
   212  	}
   213  	stepVars["postprocessor"] = postprocessorVars
   214  
   215  	sample.SetProtoCode(resp.StatusCode)
   216  	g.base.Aggregator.Report(sample)
   217  
   218  	if g.base.DebugLog {
   219  		g.base.GunDeps.Log.Debug("Postprocessor variables", zap.Any(fmt.Sprintf(".request.%s.postprocessor", step.Name), postprocessorVars))
   220  	}
   221  
   222  	if step.Sleep > 0 {
   223  		time.Sleep(step.Sleep)
   224  	}
   225  	return nil
   226  }
   227  
   228  func (g *ScenarioGun) buildLogID(idBuilder *strings.Builder, tag string, ammoID uint64, rnd string) {
   229  	idBuilder.Reset()
   230  	idBuilder.WriteString(tag)
   231  	idBuilder.WriteByte('.')
   232  	idBuilder.WriteString(rnd)
   233  	idBuilder.WriteByte('.')
   234  	idBuilder.WriteString(strconv.Itoa(int(ammoID)))
   235  }
   236  
   237  func (g *ScenarioGun) prepareRequest(reqParts RequestParts) (*http.Request, error) {
   238  	const op = "base_gun.prepareRequest"
   239  
   240  	var reader io.Reader
   241  	if reqParts.Body != nil {
   242  		reader = bytes.NewReader(reqParts.Body)
   243  	}
   244  
   245  	req, err := http.NewRequest(reqParts.Method, reqParts.URL, reader)
   246  	if err != nil {
   247  		return nil, fmt.Errorf("%s http.NewRequest %w", op, err)
   248  	}
   249  	for k, v := range reqParts.Headers {
   250  		req.Header.Set(k, v)
   251  	}
   252  
   253  	if g.base.Config.SSL {
   254  		req.URL.Scheme = "https"
   255  	} else {
   256  		req.URL.Scheme = "http"
   257  	}
   258  	if req.Host == "" {
   259  		req.Host = getHostWithoutPort(g.base.Config.Target)
   260  	}
   261  	req.URL.Host = g.base.Config.TargetResolved
   262  
   263  	return req, err
   264  }
   265  
   266  func (g *ScenarioGun) initTracing(req *http.Request, sample *netsample.Sample) (*phttp.TraceTimings, *http.Request) {
   267  	var timings *phttp.TraceTimings
   268  	if g.base.Config.HTTPTrace.TraceEnabled {
   269  		var clientTracer *httptrace.ClientTrace
   270  		clientTracer, timings = phttp.CreateHTTPTrace()
   271  		req = req.WithContext(httptrace.WithClientTrace(req.Context(), clientTracer))
   272  	}
   273  	if g.base.Config.HTTPTrace.DumpEnabled {
   274  		requestDump, err := httputil.DumpRequest(req, true)
   275  		if err != nil {
   276  			g.base.Log.Error("DumpRequest error", zap.Error(err))
   277  		} else {
   278  			sample.SetRequestBytes(len(requestDump))
   279  		}
   280  	}
   281  	return timings, req
   282  }
   283  
   284  func (g *ScenarioGun) saveTrace(timings *phttp.TraceTimings, sample *netsample.Sample, resp *http.Response) {
   285  	if g.base.Config.HTTPTrace.TraceEnabled && timings != nil {
   286  		sample.SetReceiveTime(timings.GetReceiveTime())
   287  	}
   288  	if g.base.Config.HTTPTrace.DumpEnabled && resp != nil {
   289  		responseDump, e := httputil.DumpResponse(resp, true)
   290  		if e != nil {
   291  			g.base.Log.Error("DumpResponse error", zap.Error(e))
   292  		} else {
   293  			sample.SetResponseBytes(len(responseDump))
   294  		}
   295  	}
   296  	if g.base.Config.HTTPTrace.TraceEnabled && timings != nil {
   297  		sample.SetConnectTime(timings.GetConnectTime())
   298  		sample.SetSendTime(timings.GetSendTime())
   299  		sample.SetLatency(timings.GetLatency())
   300  	}
   301  }
   302  
   303  func (g *ScenarioGun) verboseLogging(resp *http.Response, reqBody, respBody []byte) {
   304  	if resp == nil {
   305  		g.base.Log.Error("Response is nil")
   306  		return
   307  	}
   308  	fields := make([]zap.Field, 0, 4)
   309  	fields = append(fields, zap.String("URL", resp.Request.URL.String()))
   310  	fields = append(fields, zap.String("Host", resp.Request.Host))
   311  	fields = append(fields, zap.Any("Headers", resp.Request.Header))
   312  	if reqBody != nil {
   313  		fields = append(fields, zap.ByteString("Body", reqBody))
   314  	}
   315  	g.base.Log.Debug("Request debug info", fields...)
   316  
   317  	fields = fields[:0]
   318  	fields = append(fields, zap.Int("Status Code", resp.StatusCode))
   319  	fields = append(fields, zap.String("Status", resp.Status))
   320  	fields = append(fields, zap.Any("Headers", resp.Header))
   321  	if reqBody != nil {
   322  		fields = append(fields, zap.ByteString("Body", respBody))
   323  	}
   324  	g.base.Log.Debug("Response debug info", fields...)
   325  }
   326  
   327  func (g *ScenarioGun) answLogging(bodyBytes []byte, resp *http.Response, respBytes []byte, stepName string) {
   328  	msg := fmt.Sprintf("REQUEST[%s]:\n%s\n", stepName, string(bodyBytes))
   329  	g.base.AnswLog.Debug(msg)
   330  
   331  	headers := ""
   332  	var writer bytes.Buffer
   333  	err := resp.Header.Write(&writer)
   334  	if err == nil {
   335  		headers = writer.String()
   336  	} else {
   337  		g.base.AnswLog.Error("error writing header", zap.Error(err))
   338  	}
   339  
   340  	msg = fmt.Sprintf("RESPONSE[%s]:\n%s %s\n%s\n%s\n", stepName, resp.Proto, resp.Status, headers, string(respBytes))
   341  	g.base.AnswLog.Debug(msg)
   342  }
   343  
   344  func (g *ScenarioGun) answReqRespLogging(reqBytes []byte, resp *http.Response, respBytes []byte, stepName string) {
   345  	switch g.base.Config.AnswLog.Filter {
   346  	case "all":
   347  		g.answLogging(reqBytes, resp, respBytes, stepName)
   348  	case "warning":
   349  		if resp.StatusCode >= 400 {
   350  			g.answLogging(reqBytes, resp, respBytes, stepName)
   351  		}
   352  	case "error":
   353  		if resp.StatusCode >= 500 {
   354  			g.answLogging(reqBytes, resp, respBytes, stepName)
   355  		}
   356  	}
   357  }
   358  
   359  func (g *ScenarioGun) reportErr(sample *netsample.Sample, err error) {
   360  	if err == nil {
   361  		return
   362  	}
   363  	sample.AddTag(EmptyTag)
   364  	sample.SetProtoCode(0)
   365  	sample.SetErr(err)
   366  	g.base.Aggregator.Report(sample)
   367  }
   368  
   369  func getHostWithoutPort(target string) string {
   370  	host, _, err := net.SplitHostPort(target)
   371  	if err != nil {
   372  		host = target
   373  	}
   374  	return host
   375  }