github.com/waldiirawan/apm-agent-go/v2@v2.2.2/tracer_test.go (about)

     1  // Licensed to Elasticsearch B.V. under one or more contributor
     2  // license agreements. See the NOTICE file distributed with
     3  // this work for additional information regarding copyright
     4  // ownership. Elasticsearch B.V. licenses this file to you under
     5  // the Apache License, Version 2.0 (the "License"); you may
     6  // not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing,
    12  // software distributed under the License is distributed on an
    13  // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    14  // KIND, either express or implied.  See the License for the
    15  // specific language governing permissions and limitations
    16  // under the License.
    17  
    18  package apm_test
    19  
    20  import (
    21  	"bufio"
    22  	"compress/zlib"
    23  	"context"
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"io/ioutil"
    29  	"net/http"
    30  	"net/http/httptest"
    31  	"os"
    32  	"runtime"
    33  	"strconv"
    34  	"sync"
    35  	"sync/atomic"
    36  	"testing"
    37  	"time"
    38  
    39  	"github.com/stretchr/testify/assert"
    40  	"github.com/stretchr/testify/require"
    41  
    42  	"github.com/waldiirawan/apm-agent-go/v2"
    43  	"github.com/waldiirawan/apm-agent-go/v2/apmtest"
    44  	"github.com/waldiirawan/apm-agent-go/v2/internal/apmhostutil"
    45  	"github.com/waldiirawan/apm-agent-go/v2/internal/apmversion"
    46  	"github.com/waldiirawan/apm-agent-go/v2/model"
    47  	"github.com/waldiirawan/apm-agent-go/v2/transport"
    48  	"github.com/waldiirawan/apm-agent-go/v2/transport/transporttest"
    49  )
    50  
    51  func TestDefaultTracer(t *testing.T) {
    52  	defer apm.SetDefaultTracer(nil)
    53  
    54  	// Call DefaultTracer concurrently to ensure there are
    55  	// no races in creating the default tracer.
    56  	tracers := make(chan *apm.Tracer, 1000)
    57  	for i := 0; i < cap(tracers); i++ {
    58  		go func() {
    59  			tracers <- apm.DefaultTracer()
    60  		}()
    61  	}
    62  
    63  	tracer0 := <-tracers
    64  	for i := 1; i < cap(tracers); i++ {
    65  		assert.Same(t, tracer0, <-tracers)
    66  	}
    67  }
    68  
    69  func TestTracerStats(t *testing.T) {
    70  	tracer := apmtest.NewDiscardTracer()
    71  	defer tracer.Close()
    72  
    73  	for i := 0; i < 500; i++ {
    74  		tracer.StartTransaction("name", "type").End()
    75  	}
    76  	tracer.Flush(nil)
    77  	assert.Equal(t, apm.TracerStats{
    78  		TransactionsSent: 500,
    79  	}, tracer.Stats())
    80  }
    81  
    82  func TestTracerUserAgent(t *testing.T) {
    83  	sendRequest := func(serviceVersion string) string {
    84  		waitc := make(chan string, 1)
    85  		srv := httptest.NewServer(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
    86  			select {
    87  			case waitc <- r.UserAgent():
    88  			default:
    89  			}
    90  		}))
    91  		defer func() {
    92  			srv.Close()
    93  			close(waitc)
    94  		}()
    95  
    96  		os.Setenv("ELASTIC_APM_SERVER_URL", srv.URL)
    97  		defer os.Unsetenv("ELASTIC_APM_SERVER_URL")
    98  		tracer, err := apm.NewTracerOptions(apm.TracerOptions{
    99  			ServiceName:    "apmtest",
   100  			ServiceVersion: serviceVersion,
   101  		})
   102  		require.NoError(t, err)
   103  		defer tracer.Close()
   104  
   105  		tracer.StartTransaction("name", "type").End()
   106  		tracer.Flush(nil)
   107  		return <-waitc
   108  	}
   109  	assert.Equal(t, fmt.Sprintf("apm-agent-go/%s (apmtest)", apmversion.AgentVersion), sendRequest(""))
   110  	assert.Equal(t, fmt.Sprintf("apm-agent-go/%s (apmtest 1.0.0)", apmversion.AgentVersion), sendRequest("1.0.0"))
   111  }
   112  
   113  func TestTracerClosedSendNonBlocking(t *testing.T) {
   114  	tracer, err := apm.NewTracer("tracer_testing", "")
   115  	assert.NoError(t, err)
   116  	tracer.Close()
   117  
   118  	for i := 0; i < 1001; i++ {
   119  		tracer.StartTransaction("name", "type").End()
   120  	}
   121  	assert.Equal(t, uint64(1), tracer.Stats().TransactionsDropped)
   122  }
   123  
   124  func TestNewTracerNonBlocking(t *testing.T) {
   125  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   126  		<-req.Context().Done()
   127  	}))
   128  	defer server.Close()
   129  	os.Setenv("ELASTIC_APM_SERVER_URL", server.URL)
   130  	defer os.Unsetenv("ELASTIC_APM_SERVER_URL")
   131  
   132  	// NewTracer should not block for any significant amount of time,
   133  	// even if the server is initially unresponsive.
   134  	before := time.Now()
   135  	tracer, err := apm.NewTracer("tracer_testing", "")
   136  	assert.NoError(t, err)
   137  	tracer.Close()
   138  	newTracerTime := time.Since(before)
   139  	assert.Less(t, int64(newTracerTime), int64(time.Second))
   140  }
   141  
   142  func TestTracerCloseImmediately(t *testing.T) {
   143  	tracer, err := apm.NewTracer("tracer_testing", "")
   144  	assert.NoError(t, err)
   145  	tracer.Close()
   146  }
   147  
   148  func TestTracerFlushEmpty(t *testing.T) {
   149  	tracer, err := apm.NewTracer("tracer_testing", "")
   150  	assert.NoError(t, err)
   151  	defer tracer.Close()
   152  	tracer.Flush(nil)
   153  }
   154  
   155  func TestTracerMaxSpans(t *testing.T) {
   156  	test := func(n int) {
   157  		t.Run(fmt.Sprint(n), func(t *testing.T) {
   158  			tracer, r := transporttest.NewRecorderTracer()
   159  			defer tracer.Close()
   160  
   161  			tracer.SetMaxSpans(n)
   162  			tx := tracer.StartTransaction("name", "type")
   163  			defer tx.End()
   164  
   165  			// SetMaxSpans only affects transactions started
   166  			// after the call.
   167  			tracer.SetMaxSpans(99)
   168  
   169  			for i := 0; i < n; i++ {
   170  				span := tx.StartSpan("name", "type", nil)
   171  				assert.False(t, span.Dropped())
   172  				span.End()
   173  			}
   174  			span := tx.StartSpan("name", "type", nil)
   175  			assert.True(t, span.Dropped())
   176  			span.End()
   177  
   178  			tracer.Flush(nil)
   179  			assert.Len(t, r.Payloads().Spans, n)
   180  		})
   181  	}
   182  	test(0)
   183  	test(23)
   184  }
   185  
   186  func TestTracerErrors(t *testing.T) {
   187  	tracer, r := transporttest.NewRecorderTracer()
   188  	defer tracer.Close()
   189  
   190  	error_ := tracer.NewError(errors.New("zing"))
   191  	error_.Send()
   192  	tracer.Flush(nil)
   193  
   194  	payloads := r.Payloads()
   195  	exception := payloads.Errors[0].Exception
   196  	stacktrace := exception.Stacktrace
   197  	assert.Equal(t, "zing", exception.Message)
   198  	assert.Equal(t, "errors", exception.Module)
   199  	assert.Equal(t, "errorString", exception.Type)
   200  	require.NotEmpty(t, stacktrace)
   201  	assert.Equal(t, "TestTracerErrors", stacktrace[0].Function)
   202  }
   203  
// TestTracerErrorFlushes checks that ending a transaction alone does not
// produce recorded payloads within 200ms, while sending an error causes
// payloads containing that error to appear within two seconds, without an
// explicit Flush.
func TestTracerErrorFlushes(t *testing.T) {
	tracer, recorder := transporttest.NewRecorderTracer()
	defer tracer.Close()

	// A background goroutine polls the recorder every 10ms and publishes a
	// snapshot on payloads whenever the combined number of recorded errors
	// and transactions has grown since the last snapshot.
	payloads := make(chan transporttest.Payloads, 1)
	var wg sync.WaitGroup
	wg.Add(1)
	done := make(chan struct{})
	go func() {
		defer wg.Done()
		var last int
		for {
			select {
			case <-time.After(10 * time.Millisecond):
				p := recorder.Payloads()
				if n := len(p.Errors) + len(p.Transactions); n > last {
					last = n
					payloads <- p
				}
			case <-done:
				return
			}
		}
	}()
	// Deferred calls run in reverse order: done is closed first, then the
	// goroutine is waited for, and finally the tracer is closed.
	defer wg.Wait()
	defer close(done)

	// Sending a transaction should not cause a request
	// to be sent immediately.
	tracer.StartTransaction("name", "type").End()
	select {
	case <-time.After(200 * time.Millisecond):
	case p := <-payloads:
		t.Fatalf("unexpected payloads: %+v", p)
	}

	// Sending an error flushes the request body.
	tracer.NewError(errors.New("zing")).Send()
	deadline := time.After(2 * time.Second)
	for {
		var p transporttest.Payloads
		select {
		case <-deadline:
			t.Fatalf("timed out waiting for request")
		case p = <-payloads:
		}
		if len(p.Errors) != 0 {
			assert.Len(t, p.Errors, 1)
			// This break is outside the select, so it exits the for loop.
			break
		}
		// The transport may not have decoded
		// the error yet, continue waiting.
	}
}
   258  
   259  func TestTracerRecovered(t *testing.T) {
   260  	tracer, r := transporttest.NewRecorderTracer()
   261  	defer tracer.Close()
   262  
   263  	capturePanic(tracer, "blam")
   264  	tracer.Flush(nil)
   265  
   266  	payloads := r.Payloads()
   267  	error0 := payloads.Errors[0]
   268  	transaction := payloads.Transactions[0]
   269  	span := payloads.Spans[0]
   270  	assert.Equal(t, "blam", error0.Exception.Message)
   271  	assert.Equal(t, transaction.ID, error0.TransactionID)
   272  	assert.Equal(t, span.ID, error0.ParentID)
   273  }
   274  
   275  func capturePanic(tracer *apm.Tracer, v interface{}) {
   276  	tx := tracer.StartTransaction("name", "type")
   277  	defer tx.End()
   278  	span := tx.StartSpan("name", "type", nil)
   279  	defer span.End()
   280  	defer func() {
   281  		if v := recover(); v != nil {
   282  			e := tracer.Recovered(v)
   283  			e.SetSpan(span)
   284  			e.Send()
   285  		}
   286  	}()
   287  	panic(v)
   288  }
   289  
   290  func TestTracerServiceNameValidation(t *testing.T) {
   291  	_, err := apm.NewTracer("wot!", "")
   292  	assert.EqualError(t, err, `invalid service name "wot!": character '!' is not in the allowed set (a-zA-Z0-9 _-)`)
   293  }
   294  
   295  func TestSpanStackTrace(t *testing.T) {
   296  	tracer, r := transporttest.NewRecorderTracer()
   297  	defer tracer.Close()
   298  	tracer.SetSpanStackTraceMinDuration(10 * time.Millisecond)
   299  
   300  	tx := tracer.StartTransaction("name", "type")
   301  	s := tx.StartSpan("name", "type", nil)
   302  	s.Duration = 9 * time.Millisecond
   303  	s.End()
   304  	s = tx.StartSpan("name", "type", nil)
   305  	s.Duration = 10 * time.Millisecond
   306  	s.End()
   307  	s = tx.StartSpan("name", "type", nil)
   308  	s.SetStacktrace(1)
   309  	s.Duration = 11 * time.Millisecond
   310  	s.End()
   311  	tx.End()
   312  	tracer.Flush(nil)
   313  
   314  	spans := r.Payloads().Spans
   315  	require.Len(t, spans, 3)
   316  
   317  	// Span 0 took only 9ms, so we don't set its stacktrace.
   318  	assert.Nil(t, spans[0].Stacktrace)
   319  
   320  	// Span 1 took the required 10ms, so we set its stacktrace.
   321  	assert.NotNil(t, spans[1].Stacktrace)
   322  	assert.NotEqual(t, spans[1].Stacktrace[0].Function, "TestSpanStackTrace")
   323  
   324  	// Span 2 took more than the required 10ms, but its stacktrace
   325  	// was already set; we don't replace it.
   326  	assert.NotNil(t, spans[2].Stacktrace)
   327  	assert.Equal(t, spans[2].Stacktrace[0].Function, "TestSpanStackTrace")
   328  }
   329  
   330  func TestTracerRequestSize(t *testing.T) {
   331  	os.Setenv("ELASTIC_APM_API_REQUEST_SIZE", "1KB")
   332  	defer os.Unsetenv("ELASTIC_APM_API_REQUEST_SIZE")
   333  
   334  	// Set the request time to some very long duration,
   335  	// to highlight the fact that the request size is
   336  	// the cause of request completion.
   337  	os.Setenv("ELASTIC_APM_API_REQUEST_TIME", "60s")
   338  	defer os.Unsetenv("ELASTIC_APM_API_REQUEST_TIME")
   339  
   340  	requestHandled := make(chan struct{}, 1)
   341  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   342  		if req.URL.Path == "/" {
   343  			return
   344  		}
   345  		io.Copy(ioutil.Discard, req.Body)
   346  		requestHandled <- struct{}{}
   347  	}))
   348  	defer server.Close()
   349  
   350  	os.Setenv("ELASTIC_APM_SERVER_URL", server.URL)
   351  	defer os.Unsetenv("ELASTIC_APM_SERVER_URL")
   352  
   353  	httpTransport, err := transport.NewHTTPTransport(transport.HTTPTransportOptions{})
   354  	require.NoError(t, err)
   355  	tracer, err := apm.NewTracerOptions(apm.TracerOptions{
   356  		ServiceName: "tracer_testing",
   357  		Transport:   httpTransport,
   358  	})
   359  	require.NoError(t, err)
   360  	defer tracer.Close()
   361  
   362  	// Send through a bunch of transactions, filling up the API request
   363  	// size, causing the request to be immediately completed.
   364  	clientStart := time.Now()
   365  	for i := 0; i < 500; i++ {
   366  		tracer.StartTransaction("name", "type").End()
   367  		// Yield to the tracer for more predictable timing.
   368  		runtime.Gosched()
   369  	}
   370  	<-requestHandled
   371  	clientEnd := time.Now()
   372  	assert.Condition(t, func() bool {
   373  		// Should be considerably less than 10s, which is
   374  		// considerably less than the configured 60s limit.
   375  		return clientEnd.Sub(clientStart) < 10*time.Second
   376  	})
   377  }
   378  
// TestTracerBufferSize configures a small request size and buffer size,
// queues many transactions while the transport is blocked, then unblocks it
// and expects that some transactions were sent, some were dropped, and that
// the recorded transaction names show exactly one gap in their sequence.
func TestTracerBufferSize(t *testing.T) {
	os.Setenv("ELASTIC_APM_API_REQUEST_SIZE", "1KB")
	os.Setenv("ELASTIC_APM_API_BUFFER_SIZE", "10KB")
	defer os.Unsetenv("ELASTIC_APM_API_REQUEST_SIZE")
	defer os.Unsetenv("ELASTIC_APM_API_BUFFER_SIZE")

	// The recorder is wrapped in a blockedTransport so that no stream is
	// delivered to it until unblock is closed.
	var recorder transporttest.RecorderTransport
	unblock := make(chan struct{})
	tracer, err := apm.NewTracerOptions(apm.TracerOptions{
		ServiceName: "transporttest",
		Transport: blockedTransport{
			Transport: &recorder,
			unblocked: unblock,
		},
	})
	require.NoError(t, err)
	defer tracer.Close()

	// Send a bunch of transactions, which will be buffered. Because the
	// buffer cannot hold all of them we should expect to see some of the
	// older ones discarded.
	const N = 1000
	for i := 0; i < N; i++ {
		// Each transaction is named after its index so that gaps can
		// be detected from the recorded names below.
		tracer.StartTransaction(fmt.Sprint(i), "type").End()
	}
	close(unblock) // allow requests through now
	// Keep flushing until every transaction is accounted for as either
	// sent or dropped.
	for {
		stats := tracer.Stats()
		if stats.TransactionsSent+stats.TransactionsDropped == N {
			require.NotZero(t, stats.TransactionsSent)
			require.NotZero(t, stats.TransactionsDropped)
			break
		}
		tracer.Flush(nil)
	}

	stats := tracer.Stats()
	p := recorder.Payloads()
	assert.Equal(t, int(stats.TransactionsSent), len(p.Transactions))

	// It's possible that the tracer loop receives the flush request after
	// all transactions are in the channel buffer, before any individual
	// transactions make their way through. In most cases we would expect
	// to see the "0" transaction in the request, but that won't be the
	// case if the flush comes first.
	offset := 0
	for i, tx := range p.Transactions {
		if tx.Name != fmt.Sprint(i+offset) {
			// Only one gap is tolerated: offset must still be zero
			// the first (and only) time a mismatch is seen.
			require.Equal(t, 0, offset)
			n, err := strconv.Atoi(tx.Name)
			require.NoError(t, err)
			offset = n - i
			t.Logf("found gap of %d after first %d transactions", offset, i)
		}
	}
	assert.NotEqual(t, 0, offset)
}
   436  
   437  func TestTracerBodyUnread(t *testing.T) {
   438  	os.Setenv("ELASTIC_APM_API_REQUEST_SIZE", "1KB")
   439  	defer os.Unsetenv("ELASTIC_APM_API_REQUEST_SIZE")
   440  
   441  	// Don't consume the request body in the handler; close the connection.
   442  	var requests int64
   443  	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   444  		atomic.AddInt64(&requests, 1)
   445  		w.Header().Set("Connection", "close")
   446  	}))
   447  	defer server.Close()
   448  
   449  	os.Setenv("ELASTIC_APM_SERVER_URL", server.URL)
   450  	defer os.Unsetenv("ELASTIC_APM_SERVER_URL")
   451  
   452  	httpTransport, err := transport.NewHTTPTransport(transport.HTTPTransportOptions{})
   453  	require.NoError(t, err)
   454  	tracer, err := apm.NewTracerOptions(apm.TracerOptions{
   455  		ServiceName: "tracer_testing",
   456  		Transport:   httpTransport,
   457  	})
   458  	require.NoError(t, err)
   459  	defer tracer.Close()
   460  
   461  	for atomic.LoadInt64(&requests) <= 2 {
   462  		tracer.StartTransaction("name", "type").End()
   463  		tracer.Flush(nil)
   464  	}
   465  }
   466  
   467  func TestTracerMetadata(t *testing.T) {
   468  	tracer, recorder := transporttest.NewRecorderTracer()
   469  	defer tracer.Close()
   470  
   471  	tracer.StartTransaction("name", "type").End()
   472  	tracer.Flush(nil)
   473  
   474  	// TODO(axw) check other metadata
   475  	system, _, _, _ := recorder.Metadata()
   476  	container, err := apmhostutil.Container()
   477  	if err != nil {
   478  		assert.Nil(t, system.Container)
   479  	} else {
   480  		require.NotNil(t, system.Container)
   481  		assert.Equal(t, container, system.Container)
   482  	}
   483  
   484  	// Cloud metadata is disabled by apmtest by default.
   485  	assert.Equal(t, "none", os.Getenv("ELASTIC_APM_CLOUD_PROVIDER"))
   486  	assert.Zero(t, recorder.CloudMetadata())
   487  }
   488  
   489  func TestTracerKubernetesMetadata(t *testing.T) {
   490  	t.Run("no-env", func(t *testing.T) {
   491  		system, _, _, _ := getSubprocessMetadata(t)
   492  		assert.Nil(t, system.Kubernetes)
   493  	})
   494  
   495  	t.Run("namespace-only", func(t *testing.T) {
   496  		system, _, _, _ := getSubprocessMetadata(t, "KUBERNETES_NAMESPACE=myapp")
   497  		assert.Equal(t, &model.Kubernetes{
   498  			Namespace: "myapp",
   499  		}, system.Kubernetes)
   500  	})
   501  
   502  	t.Run("pod-only", func(t *testing.T) {
   503  		system, _, _, _ := getSubprocessMetadata(t, "KUBERNETES_POD_NAME=luna", "KUBERNETES_POD_UID=oneone!11")
   504  		assert.Equal(t, &model.Kubernetes{
   505  			Pod: &model.KubernetesPod{
   506  				Name: "luna",
   507  				UID:  "oneone!11",
   508  			},
   509  		}, system.Kubernetes)
   510  	})
   511  
   512  	t.Run("node-only", func(t *testing.T) {
   513  		system, _, _, _ := getSubprocessMetadata(t, "KUBERNETES_NODE_NAME=noddy")
   514  		assert.Equal(t, &model.Kubernetes{
   515  			Node: &model.KubernetesNode{
   516  				Name: "noddy",
   517  			},
   518  		}, system.Kubernetes)
   519  	})
   520  }
   521  
   522  func TestTracerActive(t *testing.T) {
   523  	tracer, _ := transporttest.NewRecorderTracer()
   524  	defer tracer.Close()
   525  	assert.True(t, tracer.Active())
   526  
   527  	// Kick off calls to tracer.Active concurrently
   528  	// with the tracer.Close, to test that we ensure
   529  	// there are no data races.
   530  	go func() {
   531  		for i := 0; i < 100; i++ {
   532  			tracer.Active()
   533  		}
   534  	}()
   535  }
   536  
   537  func TestTracerCaptureHeaders(t *testing.T) {
   538  	tracer, recorder := transporttest.NewRecorderTracer()
   539  	defer tracer.Close()
   540  
   541  	req, err := http.NewRequest("GET", "http://testing.invalid", nil)
   542  	require.NoError(t, err)
   543  	req.Header.Set("foo", "bar")
   544  	respHeaders := make(http.Header)
   545  	respHeaders.Set("baz", "qux")
   546  
   547  	for _, enabled := range []bool{false, true} {
   548  		tracer.SetCaptureHeaders(enabled)
   549  		tx := tracer.StartTransaction("name", "type")
   550  		tx.Context.SetHTTPRequest(req)
   551  		tx.Context.SetHTTPResponseHeaders(respHeaders)
   552  		tx.Context.SetHTTPStatusCode(202)
   553  		tx.End()
   554  	}
   555  
   556  	tracer.Flush(nil)
   557  	payloads := recorder.Payloads()
   558  	require.Len(t, payloads.Transactions, 2)
   559  
   560  	for i, enabled := range []bool{false, true} {
   561  		tx := payloads.Transactions[i]
   562  		require.NotNil(t, tx.Context.Request)
   563  		require.NotNil(t, tx.Context.Response)
   564  		if enabled {
   565  			assert.NotNil(t, tx.Context.Request.Headers)
   566  			assert.NotNil(t, tx.Context.Response.Headers)
   567  		} else {
   568  			assert.Nil(t, tx.Context.Request.Headers)
   569  			assert.Nil(t, tx.Context.Response.Headers)
   570  		}
   571  	}
   572  }
   573  
   574  func TestTracerDefaultTransport(t *testing.T) {
   575  	mux := http.NewServeMux()
   576  	mux.HandleFunc("/intake/v2/events", func(w http.ResponseWriter, r *http.Request) {})
   577  	srv := httptest.NewServer(mux)
   578  
   579  	t.Run("valid", func(t *testing.T) {
   580  		os.Setenv("ELASTIC_APM_SERVER_URL", srv.URL)
   581  		defer os.Unsetenv("ELASTIC_APM_SERVER_URL")
   582  		tracer, err := apm.NewTracer("", "")
   583  		require.NoError(t, err)
   584  		defer tracer.Close()
   585  
   586  		tracer.StartTransaction("name", "type").End()
   587  		tracer.Flush(nil)
   588  		assert.Equal(t, apm.TracerStats{TransactionsSent: 1}, tracer.Stats())
   589  	})
   590  
   591  	t.Run("invalid", func(t *testing.T) {
   592  		os.Setenv("ELASTIC_APM_SERVER_TIMEOUT", "never")
   593  		defer os.Unsetenv("ELASTIC_APM_SERVER_TIMEOUT")
   594  
   595  		// NewTracer returns errors.
   596  		tracer, err := apm.NewTracer("", "")
   597  		require.Error(t, err)
   598  		assert.EqualError(t, err, "failed to parse ELASTIC_APM_SERVER_TIMEOUT: invalid duration never")
   599  
   600  		// Implicitly created Tracers will have a discard tracer.
   601  		apm.SetDefaultTracer(nil)
   602  		tracer = apm.DefaultTracer()
   603  
   604  		tracer.StartTransaction("name", "type").End()
   605  		tracer.Flush(nil)
   606  		assert.Equal(t, apm.TracerStats{
   607  			Errors: apm.TracerStatsErrors{
   608  				SendStream: 1,
   609  			},
   610  		}, tracer.Stats())
   611  	})
   612  }
   613  
   614  func TestTracerUnsampledTransactions(t *testing.T) {
   615  	newTracer := func(v, remoteV uint32) (*apm.Tracer, *serverVersionRecorderTransport) {
   616  		transport := serverVersionRecorderTransport{
   617  			RecorderTransport:   &transporttest.RecorderTransport{},
   618  			ServerVersion:       v,
   619  			RemoteServerVersion: remoteV,
   620  		}
   621  		tracer, err := apm.NewTracerOptions(apm.TracerOptions{
   622  			ServiceName: "transporttest",
   623  			Transport:   &transport,
   624  		})
   625  		require.NoError(t, err)
   626  		return tracer, &transport
   627  	}
   628  
   629  	t.Run("drop", func(t *testing.T) {
   630  		tracer, recorder := newTracer(0, 8)
   631  		defer tracer.Close()
   632  		tracer.SetSampler(apm.NewRatioSampler(0.0))
   633  		tx := tracer.StartTransaction("tx", "unsampled")
   634  		tx.End()
   635  		tracer.Flush(nil)
   636  
   637  		txs := recorder.Payloads().Transactions
   638  		require.Empty(t, txs)
   639  	})
   640  	t.Run("send", func(t *testing.T) {
   641  		tracer, recorder := newTracer(0, 7)
   642  		defer tracer.Close()
   643  		tracer.SetSampler(apm.NewRatioSampler(0.0))
   644  		tx := tracer.StartTransaction("tx", "unsampled")
   645  		tx.End()
   646  		tracer.Flush(nil)
   647  
   648  		txs := recorder.Payloads().Transactions
   649  		require.NotEmpty(t, txs)
   650  		assert.Equal(t, txs[0].Type, "unsampled")
   651  	})
   652  	t.Run("send-sampled-7", func(t *testing.T) {
   653  		tracer, recorder := newTracer(0, 8)
   654  		defer tracer.Close()
   655  		tx := tracer.StartTransaction("tx", "sampled")
   656  		tx.End()
   657  		tracer.Flush(nil)
   658  
   659  		txs := recorder.Payloads().Transactions
   660  		require.NotEmpty(t, txs)
   661  		assert.Equal(t, txs[0].Type, "sampled")
   662  	})
   663  	t.Run("send-sampled-8", func(t *testing.T) {
   664  		tracer, recorder := newTracer(0, 8)
   665  		defer tracer.Close()
   666  		tx := tracer.StartTransaction("tx", "sampled")
   667  		tx.End()
   668  		tracer.Flush(nil)
   669  
   670  		txs := recorder.Payloads().Transactions
   671  		require.NotEmpty(t, txs)
   672  		assert.Equal(t, txs[0].Type, "sampled")
   673  	})
   674  	t.Run("send-unimplemented-interface", func(t *testing.T) {
   675  		tracer, recorder := transporttest.NewRecorderTracer()
   676  		defer tracer.Close()
   677  		tracer.SetSampler(apm.NewRatioSampler(0.0))
   678  		tx := tracer.StartTransaction("tx", "unsampled")
   679  		tx.End()
   680  		tracer.Flush(nil)
   681  
   682  		txs := recorder.Payloads().Transactions
   683  		require.NotEmpty(t, txs)
   684  		assert.Equal(t, txs[0].Type, "unsampled")
   685  	})
   686  	t.Run("send-onerror", func(t *testing.T) {
   687  		tracer, recorder := newTracer(0, 0)
   688  		defer tracer.Close()
   689  		tracer.SetSampler(apm.NewRatioSampler(0.0))
   690  		tx := tracer.StartTransaction("tx", "unsampled")
   691  		tx.End()
   692  		tracer.Flush(nil)
   693  
   694  		txs := recorder.Payloads().Transactions
   695  		require.NotEmpty(t, txs)
   696  		assert.Equal(t, txs[0].Type, "unsampled")
   697  	})
   698  }
   699  
   700  func TestTracerUnsampledTransactionsHTTPTransport(t *testing.T) {
   701  	newTracer := func(srvURL string) (*apm.Tracer, *transport.HTTPTransport) {
   702  		os.Setenv("ELASTIC_APM_SERVER_URL", srvURL)
   703  		defer os.Unsetenv("ELASTIC_APM_SERVER_URL")
   704  		transport, err := transport.NewHTTPTransport(transport.HTTPTransportOptions{})
   705  		require.NoError(t, err)
   706  		tracer, err := apm.NewTracerOptions(apm.TracerOptions{
   707  			ServiceName: "transporttest",
   708  			Transport:   transport,
   709  		})
   710  		require.NoError(t, err)
   711  		return tracer, transport
   712  	}
   713  
   714  	type event struct {
   715  		Tx *model.Transaction `json:"transaction,omitempty"`
   716  	}
   717  	countTransactions := func(body io.ReadCloser) uint32 {
   718  		reader, err := zlib.NewReader(body)
   719  		require.NoError(t, err)
   720  		scanner := bufio.NewScanner(reader)
   721  		var tCount uint32
   722  		for scanner.Scan() {
   723  			var e event
   724  			json.Unmarshal([]byte(scanner.Text()), &e)
   725  			assert.NoError(t, err)
   726  
   727  			if e.Tx != nil {
   728  				tCount++
   729  			}
   730  		}
   731  		return tCount
   732  	}
   733  
   734  	intakeHandlerFunc := func(tCounter *uint32) http.Handler {
   735  		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   736  			defer r.Body.Close()
   737  			atomic.AddUint32(tCounter, countTransactions(r.Body))
   738  			rw.WriteHeader(202)
   739  		})
   740  	}
   741  	// This handler is used to test for cache invalidation, it will return an
   742  	// error only once when the number of transactions is 100, so we can test
   743  	// the cache invalidation.
   744  	intakeHandlerErr100Func := func(tCounter *uint32) http.Handler {
   745  		var hasErrored bool
   746  		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   747  			defer r.Body.Close()
   748  			if atomic.LoadUint32(tCounter) == 100 && !hasErrored {
   749  				hasErrored = true
   750  				io.Copy(ioutil.Discard, r.Body)
   751  				http.Error(rw, "error-message", http.StatusInternalServerError)
   752  				return
   753  			}
   754  			atomic.AddUint32(tCounter, countTransactions(r.Body))
   755  			rw.WriteHeader(202)
   756  		})
   757  	}
   758  	rootHandlerFunc := func(v string, rootCounter *uint32) http.Handler {
   759  		return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   760  			// Only handle requests that match the path.
   761  			if r.URL.Path != "/" {
   762  				return
   763  			}
   764  			rw.WriteHeader(200)
   765  			rw.Write([]byte(fmt.Sprintf(`{"version":"%s"}`, v)))
   766  			atomic.AddUint32(rootCounter, 1)
   767  		})
   768  	}
   769  
   770  	generateTx := func(tracer *apm.Tracer) {
   771  		// Sends 100 unsampled transactions to the tracer.
   772  		tracer.SetSampler(apm.NewRatioSampler(0.0))
   773  		for i := 0; i < 100; i++ {
   774  			tx := tracer.StartTransaction("tx", "unsampled")
   775  			tx.End()
   776  		}
   777  		// Sends 100 sampled transactions to the tracer.
   778  		tracer.SetSampler(apm.NewRatioSampler(1.0))
   779  		for i := 0; i < 100; i++ {
   780  			tx := tracer.StartTransaction("tx", "sampled")
   781  			tx.End()
   782  		}
   783  		<-time.After(time.Millisecond)
   784  		tracer.Flush(nil)
   785  	}
   786  
   787  	waitMajorServerVersion := func(t *testing.T, transport *transport.HTTPTransport, expected int) {
   788  		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   789  		defer cancel()
   790  		for ctx.Err() == nil {
   791  			actual := int(transport.MajorServerVersion(ctx, false))
   792  			if actual == expected {
   793  				return
   794  			}
   795  		}
   796  		t.Fatalf("timed out waiting for major server version to become %d", expected)
   797  	}
   798  
   799  	t.Run("pre-8-sends-all", func(t *testing.T) {
   800  		var tCounter, rootCounter uint32
   801  		mux := http.NewServeMux()
   802  		mux.Handle("/intake/v2/events", intakeHandlerFunc(&tCounter))
   803  		mux.Handle("/", rootHandlerFunc("7.17.0", &rootCounter))
   804  		srv := httptest.NewServer(mux)
   805  		defer srv.Close()
   806  		tracer, transport := newTracer(srv.URL)
   807  
   808  		waitMajorServerVersion(t, transport, 7)
   809  		generateTx(tracer)
   810  
   811  		assert.Equal(t, uint32(200), atomic.LoadUint32(&tCounter))
   812  		assert.Equal(t, uint32(1), atomic.LoadUint32(&rootCounter))
   813  	})
   814  	t.Run("post-8-sends-sampled-only", func(t *testing.T) {
   815  		var tCounter, rootCounter uint32
   816  		mux := http.NewServeMux()
   817  		mux.Handle("/intake/v2/events", intakeHandlerFunc(&tCounter))
   818  		mux.Handle("/", rootHandlerFunc("8.0.0", &rootCounter))
   819  		srv := httptest.NewServer(mux)
   820  		defer srv.Close()
   821  		tracer, transport := newTracer(srv.URL)
   822  
   823  		waitMajorServerVersion(t, transport, 8)
   824  		generateTx(tracer)
   825  
   826  		assert.Equal(t, uint32(100), atomic.LoadUint32(&tCounter))
   827  		assert.Equal(t, uint32(1), atomic.LoadUint32(&rootCounter))
   828  	})
   829  	t.Run("post-8-sends-sampled-only-after-cache-invalidation-send-all", func(t *testing.T) {
   830  		// This test case asserts that when the server's major version is >= 8
   831  		// only the sampled transactions are sent. After 100 transactions have
   832  		// been sent to the server, the server will return a 500 error and will
   833  		// invalidate the cache, causing all transactions (sampled and unsampled)
   834  		// to be sent, until the version is refreshed. Since it will take 10s
   835  		// for the version to be refreshed, this test doesn't assert that.
   836  		var tCounter, rootCounter uint32
   837  		mux := http.NewServeMux()
   838  		mux.Handle("/intake/v2/events", intakeHandlerErr100Func(&tCounter))
   839  		mux.Handle("/", rootHandlerFunc("8.0.0", &rootCounter))
   840  		srv := httptest.NewServer(mux)
   841  		defer srv.Close()
   842  		tracer, transport := newTracer(srv.URL)
   843  
   844  		waitMajorServerVersion(t, transport, 8)
   845  		for i := 0; i < 3; i++ {
   846  			generateTx(tracer)
   847  		}
   848  		assert.Equal(t, uint32(300), atomic.LoadUint32(&tCounter))
   849  		assert.Equal(t, uint32(1), atomic.LoadUint32(&rootCounter))
   850  
   851  		// Manually refresh the remote version.
   852  		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   853  		defer cancel()
   854  		transport.MajorServerVersion(ctx, true)
   855  		assert.Equal(t, uint32(2), atomic.LoadUint32(&rootCounter))
   856  
   857  		// Send 100 sampled and 100 unsampled txs.
   858  		generateTx(tracer)
   859  		assert.Equal(t, uint32(400), atomic.LoadUint32(&tCounter))
   860  	})
   861  	t.Run("invalid-version-sends-all", func(t *testing.T) {
   862  		var tCounter, rootCounter uint32
   863  		mux := http.NewServeMux()
   864  		mux.Handle("/intake/v2/events", intakeHandlerFunc(&tCounter))
   865  		mux.Handle("/", rootHandlerFunc("invalid-version", &rootCounter))
   866  		srv := httptest.NewServer(mux)
   867  		defer srv.Close()
   868  
   869  		tracer, _ := newTracer(srv.URL)
   870  		generateTx(tracer)
   871  
   872  		assert.Equal(t, uint32(200), atomic.LoadUint32(&tCounter))
   873  		assert.Equal(t, uint32(1), atomic.LoadUint32(&rootCounter))
   874  	})
   875  }
   876  
// blockedTransport wraps a transport.Transport and overrides SendStream so
// that streams are held back until a receive on unblocked succeeds (for
// example because the channel has been closed) or the context is done.
type blockedTransport struct {
	// Transport is the wrapped transport that streams are forwarded to.
	transport.Transport
	// unblocked gates SendStream; tests close it to let streams through.
	unblocked chan struct{}
}
   881  
   882  func (bt blockedTransport) SendStream(ctx context.Context, r io.Reader) error {
   883  	select {
   884  	case <-ctx.Done():
   885  		return ctx.Err()
   886  	case <-bt.unblocked:
   887  		return bt.Transport.SendStream(ctx, r)
   888  	}
   889  }
   890  
// serverVersionRecorderTransport wraps a RecorderTransport, adding a
// MajorServerVersion method that reports a configurable server version.
type serverVersionRecorderTransport struct {
	*transporttest.RecorderTransport
	// ServerVersion is the version currently reported by MajorServerVersion.
	ServerVersion uint32
	// RemoteServerVersion replaces ServerVersion when MajorServerVersion
	// is called with refreshStale set.
	RemoteServerVersion uint32
}
   897  
   898  // MajorServerVersion returns the stored version.
   899  func (r *serverVersionRecorderTransport) MajorServerVersion(_ context.Context, refreshStale bool) uint32 {
   900  	if refreshStale {
   901  		atomic.StoreUint32(&r.ServerVersion, r.RemoteServerVersion)
   902  	}
   903  	return atomic.LoadUint32(&r.ServerVersion)
   904  }