github.com/newrelic/go-agent@v3.26.0+incompatible/internal/collector_test.go (about)

     1  // Copyright 2020 New Relic Corporation. All rights reserved.
     2  // SPDX-License-Identifier: Apache-2.0
     3  
     4  package internal
     5  
     6  import (
     7  	"compress/gzip"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"net/http"
    14  	"net/url"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  
    19  	"github.com/newrelic/go-agent/internal/crossagent"
    20  	"github.com/newrelic/go-agent/internal/logger"
    21  )
    22  
    23  func TestResponseCodeError(t *testing.T) {
    24  	testcases := []struct {
    25  		code            int
    26  		success         bool
    27  		disconnect      bool
    28  		restart         bool
    29  		saveHarvestData bool
    30  	}{
    31  		// success
    32  		{code: 200, success: true, disconnect: false, restart: false, saveHarvestData: false},
    33  		{code: 202, success: true, disconnect: false, restart: false, saveHarvestData: false},
    34  		// disconnect
    35  		{code: 410, success: false, disconnect: true, restart: false, saveHarvestData: false},
    36  		// restart
    37  		{code: 401, success: false, disconnect: false, restart: true, saveHarvestData: false},
    38  		{code: 409, success: false, disconnect: false, restart: true, saveHarvestData: false},
    39  		// save data
    40  		{code: 408, success: false, disconnect: false, restart: false, saveHarvestData: true},
    41  		{code: 429, success: false, disconnect: false, restart: false, saveHarvestData: true},
    42  		{code: 500, success: false, disconnect: false, restart: false, saveHarvestData: true},
    43  		{code: 503, success: false, disconnect: false, restart: false, saveHarvestData: true},
    44  		// other errors
    45  		{code: 400, success: false, disconnect: false, restart: false, saveHarvestData: false},
    46  		{code: 403, success: false, disconnect: false, restart: false, saveHarvestData: false},
    47  		{code: 404, success: false, disconnect: false, restart: false, saveHarvestData: false},
    48  		{code: 405, success: false, disconnect: false, restart: false, saveHarvestData: false},
    49  		{code: 407, success: false, disconnect: false, restart: false, saveHarvestData: false},
    50  		{code: 411, success: false, disconnect: false, restart: false, saveHarvestData: false},
    51  		{code: 413, success: false, disconnect: false, restart: false, saveHarvestData: false},
    52  		{code: 414, success: false, disconnect: false, restart: false, saveHarvestData: false},
    53  		{code: 415, success: false, disconnect: false, restart: false, saveHarvestData: false},
    54  		{code: 417, success: false, disconnect: false, restart: false, saveHarvestData: false},
    55  		{code: 431, success: false, disconnect: false, restart: false, saveHarvestData: false},
    56  		// unexpected weird codes
    57  		{code: -1, success: false, disconnect: false, restart: false, saveHarvestData: false},
    58  		{code: 1, success: false, disconnect: false, restart: false, saveHarvestData: false},
    59  		{code: 999999, success: false, disconnect: false, restart: false, saveHarvestData: false},
    60  	}
    61  	for _, tc := range testcases {
    62  		resp := newRPMResponse(tc.code)
    63  		if tc.success != (nil == resp.Err) {
    64  			t.Error("error", tc.code, tc.success, resp.Err)
    65  		}
    66  		if tc.disconnect != resp.IsDisconnect() {
    67  			t.Error("disconnect", tc.code, tc.disconnect, resp.Err)
    68  		}
    69  		if tc.restart != resp.IsRestartException() {
    70  			t.Error("restart", tc.code, tc.restart, resp.Err)
    71  		}
    72  		if tc.saveHarvestData != resp.ShouldSaveHarvestData() {
    73  			t.Error("save harvest data", tc.code, tc.saveHarvestData, resp.Err)
    74  		}
    75  	}
    76  }
    77  
    78  type roundTripperFunc func(*http.Request) (*http.Response, error)
    79  
    80  func (fn roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
    81  	return fn(r)
    82  }
    83  
    84  func TestCollectorRequest(t *testing.T) {
    85  	cmd := RpmCmd{
    86  		Name:              "cmd_name",
    87  		Collector:         "collector.com",
    88  		RunID:             "run_id",
    89  		Data:              nil,
    90  		RequestHeadersMap: map[string]string{"zip": "zap"},
    91  		MaxPayloadSize:    maxPayloadSizeInBytes,
    92  	}
    93  	testField := func(name, v1, v2 string) {
    94  		if v1 != v2 {
    95  			t.Error(name, v1, v2)
    96  		}
    97  	}
    98  	cs := RpmControls{
    99  		License: "the_license",
   100  		Client: &http.Client{
   101  			Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
   102  				testField("method", r.Method, "POST")
   103  				testField("url", r.URL.String(), "https://collector.com/agent_listener/invoke_raw_method?license_key=the_license&marshal_format=json&method=cmd_name&protocol_version=17&run_id=run_id")
   104  				testField("Accept-Encoding", r.Header.Get("Accept-Encoding"), "identity, deflate")
   105  				testField("Content-Type", r.Header.Get("Content-Type"), "application/octet-stream")
   106  				testField("User-Agent", r.Header.Get("User-Agent"), "NewRelic-Go-Agent/agent_version")
   107  				testField("Content-Encoding", r.Header.Get("Content-Encoding"), "gzip")
   108  				testField("zip", r.Header.Get("zip"), "zap")
   109  				return &http.Response{
   110  					StatusCode: 200,
   111  					Body:       ioutil.NopCloser(strings.NewReader("body")),
   112  				}, nil
   113  			}),
   114  		},
   115  		Logger:       logger.ShimLogger{IsDebugEnabled: true},
   116  		AgentVersion: "agent_version",
   117  		GzipWriterPool: &sync.Pool{
   118  			New: func() interface{} {
   119  				return gzip.NewWriter(io.Discard)
   120  			},
   121  		},
   122  	}
   123  	resp := CollectorRequest(cmd, cs)
   124  	if nil != resp.Err {
   125  		t.Error(resp.Err)
   126  	}
   127  }
   128  
   129  func TestCollectorBadRequest(t *testing.T) {
   130  	cmd := RpmCmd{
   131  		Name:              "cmd_name",
   132  		Collector:         "collector.com",
   133  		RunID:             "run_id",
   134  		Data:              nil,
   135  		RequestHeadersMap: map[string]string{"zip": "zap"},
   136  	}
   137  	cs := RpmControls{
   138  		License: "the_license",
   139  		Client: &http.Client{
   140  			Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
   141  				return &http.Response{
   142  					StatusCode: 200,
   143  					Body:       ioutil.NopCloser(strings.NewReader("body")),
   144  				}, nil
   145  			}),
   146  		},
   147  		Logger:       logger.ShimLogger{IsDebugEnabled: true},
   148  		AgentVersion: "agent_version",
   149  		GzipWriterPool: &sync.Pool{
   150  			New: func() interface{} {
   151  				return gzip.NewWriter(io.Discard)
   152  			},
   153  		},
   154  	}
   155  	u := ":" // bad url
   156  	resp := collectorRequestInternal(u, cmd, cs)
   157  	if nil == resp.Err {
   158  		t.Error("missing expected error")
   159  	}
   160  
   161  }
   162  
   163  func TestUrl(t *testing.T) {
   164  	cmd := RpmCmd{
   165  		Name:      "foo_method",
   166  		Collector: "example.com",
   167  	}
   168  	cs := RpmControls{
   169  		License:      "123abc",
   170  		Client:       nil,
   171  		Logger:       nil,
   172  		AgentVersion: "1",
   173  		GzipWriterPool: &sync.Pool{
   174  			New: func() interface{} {
   175  				return gzip.NewWriter(io.Discard)
   176  			},
   177  		},
   178  	}
   179  
   180  	out := rpmURL(cmd, cs)
   181  	u, err := url.Parse(out)
   182  	if err != nil {
   183  		t.Fatalf("url.Parse(%q) = %q", out, err)
   184  	}
   185  
   186  	got := u.Query().Get("license_key")
   187  	if got != cs.License {
   188  		t.Errorf("got=%q cmd.License=%q", got, cs.License)
   189  	}
   190  	if u.Scheme != "https" {
   191  		t.Error(u.Scheme)
   192  	}
   193  }
   194  
   195  const (
   196  	unknownRequiredPolicyBody = `{"return_value":{"redirect_host":"special_collector","security_policies":{"unknown_policy":{"enabled":true,"required":true}}}}`
   197  	redirectBody              = `{"return_value":{"redirect_host":"special_collector"}}`
   198  	connectBody               = `{"return_value":{"agent_run_id":"my_agent_run_id"}}`
   199  	malformedBody             = `{"return_value":}}`
   200  )
   201  
   202  func makeResponse(code int, body string) *http.Response {
   203  	return &http.Response{
   204  		StatusCode: code,
   205  		Body:       ioutil.NopCloser(strings.NewReader(body)),
   206  	}
   207  }
   208  
   209  type endpointResult struct {
   210  	response *http.Response
   211  	err      error
   212  }
   213  
   214  type connectMock struct {
   215  	redirect endpointResult
   216  	connect  endpointResult
   217  	// testConfig will be used if this is nil
   218  	config ConnectJSONCreator
   219  }
   220  
   221  func (m connectMock) RoundTrip(r *http.Request) (*http.Response, error) {
   222  	cmd := r.URL.Query().Get("method")
   223  	switch cmd {
   224  	case cmdPreconnect:
   225  		return m.redirect.response, m.redirect.err
   226  	case cmdConnect:
   227  		return m.connect.response, m.connect.err
   228  	default:
   229  		return nil, fmt.Errorf("unknown cmd: %s", cmd)
   230  	}
   231  }
   232  
   233  func (m connectMock) CancelRequest(req *http.Request) {}
   234  
   235  type testConfig struct{}
   236  
   237  func (tc testConfig) CreateConnectJSON(*SecurityPolicies) ([]byte, error) {
   238  	return []byte(`"connect-json"`), nil
   239  }
   240  
   241  type errorConfig struct{}
   242  
   243  func (c errorConfig) CreateConnectJSON(*SecurityPolicies) ([]byte, error) {
   244  	return nil, errors.New("error creating config JSON")
   245  }
   246  
   247  func testConnectHelper(cm connectMock) (*ConnectReply, RPMResponse) {
   248  	config := cm.config
   249  	if nil == config {
   250  		config = testConfig{}
   251  	}
   252  	cs := RpmControls{
   253  		License:      "12345",
   254  		Client:       &http.Client{Transport: cm},
   255  		Logger:       logger.ShimLogger{IsDebugEnabled: true},
   256  		AgentVersion: "1",
   257  		GzipWriterPool: &sync.Pool{
   258  			New: func() interface{} {
   259  				return gzip.NewWriter(io.Discard)
   260  			},
   261  		},
   262  	}
   263  
   264  	return ConnectAttempt(config, "", false, cs)
   265  }
   266  
   267  func TestConnectAttemptSuccess(t *testing.T) {
   268  	run, resp := testConnectHelper(connectMock{
   269  		redirect: endpointResult{response: makeResponse(200, redirectBody)},
   270  		connect:  endpointResult{response: makeResponse(200, connectBody)},
   271  	})
   272  	if nil == run || nil != resp.Err {
   273  		t.Fatal(run, resp.Err)
   274  	}
   275  	if run.Collector != "special_collector" {
   276  		t.Error(run.Collector)
   277  	}
   278  	if run.RunID != "my_agent_run_id" {
   279  		t.Error(run)
   280  	}
   281  }
   282  
   283  func TestConnectClientError(t *testing.T) {
   284  	run, resp := testConnectHelper(connectMock{
   285  		redirect: endpointResult{response: makeResponse(200, redirectBody)},
   286  		connect:  endpointResult{err: errors.New("client error")},
   287  	})
   288  	if nil != run {
   289  		t.Fatal(run)
   290  	}
   291  	if resp.Err == nil {
   292  		t.Fatal("missing expected error")
   293  	}
   294  }
   295  
   296  func TestConnectConfigJSONError(t *testing.T) {
   297  	run, resp := testConnectHelper(connectMock{
   298  		redirect: endpointResult{response: makeResponse(200, redirectBody)},
   299  		connect:  endpointResult{response: makeResponse(200, connectBody)},
   300  		config:   errorConfig{},
   301  	})
   302  	if nil != run {
   303  		t.Fatal(run)
   304  	}
   305  	if resp.Err == nil {
   306  		t.Fatal("missing expected error")
   307  	}
   308  }
   309  
   310  func TestConnectAttemptDisconnectOnRedirect(t *testing.T) {
   311  	run, resp := testConnectHelper(connectMock{
   312  		redirect: endpointResult{response: makeResponse(410, "")},
   313  		connect:  endpointResult{response: makeResponse(200, connectBody)},
   314  	})
   315  	if nil != run {
   316  		t.Error(run)
   317  	}
   318  	if nil == resp.Err {
   319  		t.Fatal("missing error")
   320  	}
   321  	if !resp.IsDisconnect() {
   322  		t.Fatal("should be disconnect")
   323  	}
   324  }
   325  
   326  func TestConnectAttemptDisconnectOnConnect(t *testing.T) {
   327  	run, resp := testConnectHelper(connectMock{
   328  		redirect: endpointResult{response: makeResponse(200, redirectBody)},
   329  		connect:  endpointResult{response: makeResponse(410, "")},
   330  	})
   331  	if nil != run {
   332  		t.Error(run)
   333  	}
   334  	if nil == resp.Err {
   335  		t.Fatal("missing error")
   336  	}
   337  	if !resp.IsDisconnect() {
   338  		t.Fatal("should be disconnect")
   339  	}
   340  }
   341  
   342  func TestConnectAttemptBadSecurityPolicies(t *testing.T) {
   343  	run, resp := testConnectHelper(connectMock{
   344  		redirect: endpointResult{response: makeResponse(200, unknownRequiredPolicyBody)},
   345  		connect:  endpointResult{response: makeResponse(200, connectBody)},
   346  	})
   347  	if nil != run {
   348  		t.Error(run)
   349  	}
   350  	if nil == resp.Err {
   351  		t.Fatal("missing error")
   352  	}
   353  	if !resp.IsDisconnect() {
   354  		t.Fatal("should be disconnect")
   355  	}
   356  }
   357  
   358  func TestConnectAttemptInvalidJSON(t *testing.T) {
   359  	run, resp := testConnectHelper(connectMock{
   360  		redirect: endpointResult{response: makeResponse(200, redirectBody)},
   361  		connect:  endpointResult{response: makeResponse(200, malformedBody)},
   362  	})
   363  	if nil != run {
   364  		t.Error(run)
   365  	}
   366  	if nil == resp.Err {
   367  		t.Fatal("missing error")
   368  	}
   369  }
   370  
   371  func TestConnectAttemptCollectorNotString(t *testing.T) {
   372  	run, resp := testConnectHelper(connectMock{
   373  		redirect: endpointResult{response: makeResponse(200, `{"return_value":123}`)},
   374  		connect:  endpointResult{response: makeResponse(200, connectBody)},
   375  	})
   376  	if nil != run {
   377  		t.Error(run)
   378  	}
   379  	if nil == resp.Err {
   380  		t.Fatal("missing error")
   381  	}
   382  }
   383  
   384  func TestConnectAttempt401(t *testing.T) {
   385  	run, resp := testConnectHelper(connectMock{
   386  		redirect: endpointResult{response: makeResponse(200, redirectBody)},
   387  		connect:  endpointResult{response: makeResponse(401, connectBody)},
   388  	})
   389  	if nil != run {
   390  		t.Error(run)
   391  	}
   392  	if nil == resp.Err {
   393  		t.Fatal("missing error")
   394  	}
   395  	if !resp.IsRestartException() {
   396  		t.Fatal("should be restart")
   397  	}
   398  }
   399  
   400  func TestConnectAttemptOtherReturnCode(t *testing.T) {
   401  	run, resp := testConnectHelper(connectMock{
   402  		redirect: endpointResult{response: makeResponse(200, redirectBody)},
   403  		connect:  endpointResult{response: makeResponse(413, connectBody)},
   404  	})
   405  	if nil != run {
   406  		t.Error(run)
   407  	}
   408  	if nil == resp.Err {
   409  		t.Fatal("missing error")
   410  	}
   411  }
   412  
   413  func TestConnectAttemptMissingRunID(t *testing.T) {
   414  	run, resp := testConnectHelper(connectMock{
   415  		redirect: endpointResult{response: makeResponse(200, redirectBody)},
   416  		connect:  endpointResult{response: makeResponse(200, `{"return_value":{}}`)},
   417  	})
   418  	if nil != run {
   419  		t.Error(run)
   420  	}
   421  	if errMissingAgentRunID != resp.Err {
   422  		t.Fatal("wrong error", resp.Err)
   423  	}
   424  }
   425  
   426  func TestCalculatePreconnectHost(t *testing.T) {
   427  	// non-region license
   428  	host := calculatePreconnectHost("0123456789012345678901234567890123456789", "")
   429  	if host != preconnectHostDefault {
   430  		t.Error(host)
   431  	}
   432  	// override present
   433  	override := "other-collector.newrelic.com"
   434  	host = calculatePreconnectHost("0123456789012345678901234567890123456789", override)
   435  	if host != override {
   436  		t.Error(host)
   437  	}
   438  	// four letter region
   439  	host = calculatePreconnectHost("eu01xx6789012345678901234567890123456789", "")
   440  	if host != "collector.eu01.nr-data.net" {
   441  		t.Error(host)
   442  	}
   443  	// five letter region
   444  	host = calculatePreconnectHost("gov01x6789012345678901234567890123456789", "")
   445  	if host != "collector.gov01.nr-data.net" {
   446  		t.Error(host)
   447  	}
   448  	// six letter region
   449  	host = calculatePreconnectHost("foo001x6789012345678901234567890123456789", "")
   450  	if host != "collector.foo001.nr-data.net" {
   451  		t.Error(host)
   452  	}
   453  }
   454  
   455  func TestPreconnectHostCrossAgent(t *testing.T) {
   456  	var testcases []struct {
   457  		Name               string `json:"name"`
   458  		ConfigFileKey      string `json:"config_file_key"`
   459  		EnvKey             string `json:"env_key"`
   460  		ConfigOverrideHost string `json:"config_override_host"`
   461  		EnvOverrideHost    string `json:"env_override_host"`
   462  		ExpectHostname     string `json:"hostname"`
   463  	}
   464  	err := crossagent.ReadJSON("collector_hostname.json", &testcases)
   465  	if err != nil {
   466  		t.Fatal(err)
   467  	}
   468  
   469  	for _, tc := range testcases {
   470  		// mimic file/environment precedence of other agents
   471  		configKey := tc.ConfigFileKey
   472  		if "" != tc.EnvKey {
   473  			configKey = tc.EnvKey
   474  		}
   475  		overrideHost := tc.ConfigOverrideHost
   476  		if "" != tc.EnvOverrideHost {
   477  			overrideHost = tc.EnvOverrideHost
   478  		}
   479  
   480  		host := calculatePreconnectHost(configKey, overrideHost)
   481  		if host != tc.ExpectHostname {
   482  			t.Errorf(`test="%s" got="%s" expected="%s"`, tc.Name, host, tc.ExpectHostname)
   483  		}
   484  	}
   485  }
   486  
   487  func TestCollectorRequestRespectsMaxPayloadSize(t *testing.T) {
   488  	// Test that CollectorRequest returns an error when MaxPayloadSize is
   489  	// exceeded
   490  	cmd := RpmCmd{
   491  		Name:           "cmd_name",
   492  		Collector:      "collector.com",
   493  		RunID:          "run_id",
   494  		Data:           []byte("abcdefghijklmnopqrstuvwxyz"),
   495  		MaxPayloadSize: 3,
   496  	}
   497  	cs := RpmControls{
   498  		Client: &http.Client{
   499  			Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
   500  				t.Error("no response should have gone out!")
   501  				return nil, nil
   502  			}),
   503  		},
   504  		Logger: logger.ShimLogger{IsDebugEnabled: true},
   505  		GzipWriterPool: &sync.Pool{
   506  			New: func() interface{} {
   507  				return gzip.NewWriter(io.Discard)
   508  			},
   509  		},
   510  	}
   511  	resp := CollectorRequest(cmd, cs)
   512  	if nil == resp.Err {
   513  		t.Error("response should have contained error")
   514  	}
   515  	if resp.ShouldSaveHarvestData() {
   516  		t.Error("harvest data should be discarded when max_payload_size_in_bytes is exceeded")
   517  	}
   518  }
   519  
   520  func TestConnectReplyMaxPayloadSize(t *testing.T) {
   521  	testcases := []struct {
   522  		replyBody              string
   523  		expectedMaxPayloadSize int
   524  	}{
   525  		{
   526  			replyBody:              `{"return_value":{"agent_run_id":"my_agent_run_id"}}`,
   527  			expectedMaxPayloadSize: 1000 * 1000,
   528  		},
   529  		{
   530  			replyBody:              `{"return_value":{"agent_run_id":"my_agent_run_id","max_payload_size_in_bytes":123}}`,
   531  			expectedMaxPayloadSize: 123,
   532  		},
   533  	}
   534  
   535  	controls := func(replyBody string) RpmControls {
   536  		return RpmControls{
   537  			Client: &http.Client{
   538  				Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) {
   539  					return &http.Response{
   540  						StatusCode: 200,
   541  						Body:       ioutil.NopCloser(strings.NewReader(replyBody)),
   542  					}, nil
   543  				}),
   544  			},
   545  			Logger: logger.ShimLogger{IsDebugEnabled: true},
   546  			GzipWriterPool: &sync.Pool{
   547  				New: func() interface{} {
   548  					return gzip.NewWriter(io.Discard)
   549  				},
   550  			},
   551  		}
   552  	}
   553  
   554  	for _, test := range testcases {
   555  		reply, resp := ConnectAttempt(testConfig{}, "", false, controls(test.replyBody))
   556  		if nil != resp.Err {
   557  			t.Error("resp returned unexpected error:", resp.Err)
   558  		}
   559  		if test.expectedMaxPayloadSize != reply.MaxPayloadSizeInBytes {
   560  			t.Errorf("incorrect MaxPayloadSizeInBytes: expected=%d actual=%d",
   561  				test.expectedMaxPayloadSize, reply.MaxPayloadSizeInBytes)
   562  		}
   563  	}
   564  }
   565  
   566  func TestPreconnectRequestMarshall(t *testing.T) {
   567  	tests := map[string]preconnectRequest{
   568  		`[{"security_policies_token":"securityPoliciesToken","high_security":false}]`: {
   569  			SecurityPoliciesToken: "securityPoliciesToken",
   570  			HighSecurity:          false,
   571  		},
   572  		`[{"security_policies_token":"securityPoliciesToken","high_security":true}]`: {
   573  			SecurityPoliciesToken: "securityPoliciesToken",
   574  			HighSecurity:          true,
   575  		},
   576  		`[{"high_security":true}]`: {
   577  			SecurityPoliciesToken: "",
   578  			HighSecurity:          true,
   579  		},
   580  		`[{"high_security":false}]`: {
   581  			SecurityPoliciesToken: "",
   582  			HighSecurity:          false,
   583  		},
   584  	}
   585  	for expected, request := range tests {
   586  		b, e := json.Marshal([]preconnectRequest{request})
   587  		if e != nil {
   588  			t.Fatal("Unable to marshall preconnect request", e)
   589  		}
   590  		result := string(b)
   591  		if result != expected {
   592  			t.Errorf("Invalid preconnect request marshall: expected %s, got %s", expected, result)
   593  		}
   594  	}
   595  }