github.com/newrelic/newrelic-client-go@v1.1.0/internal/http/client_test.go (about)

     1  //go:build unit
     2  // +build unit
     3  
     4  package http
     5  
     6  import (
     7  	"context"
     8  	"encoding/json"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/newrelic/newrelic-client-go/pkg/contextkeys"
    15  
    16  	"github.com/stretchr/testify/assert"
    17  	"github.com/stretchr/testify/require"
    18  
    19  	"github.com/newrelic/newrelic-client-go/pkg/config"
    20  	"github.com/newrelic/newrelic-client-go/pkg/errors"
    21  	"github.com/newrelic/newrelic-client-go/pkg/logging"
    22  	mock "github.com/newrelic/newrelic-client-go/pkg/testhelpers"
    23  )
    24  
    25  const (
    26  	testServiceName = "serviceName"
    27  )
    28  
    29  func TestConfig(t *testing.T) {
    30  	t.Parallel()
    31  	testRestURL := "https://www.mocky.io"
    32  	testTimeout := time.Second * 5
    33  	testTransport := http.DefaultTransport
    34  
    35  	tc := config.New()
    36  
    37  	tc.HTTPTransport = testTransport
    38  	tc.Region().SetRestBaseURL(testRestURL)
    39  	tc.ServiceName = testServiceName
    40  	tc.Timeout = &testTimeout
    41  	tc.UserAgent = mock.UserAgent
    42  
    43  	c := NewClient(tc)
    44  
    45  	require.NotNil(t, c.logger)
    46  	require.Equal(t, &testTimeout, c.config.Timeout)
    47  	require.Equal(t, testRestURL, c.config.Region().RestURL())
    48  	require.Equal(t, mock.UserAgent, c.config.UserAgent)
    49  	require.Equal(t, c.config.ServiceName, testServiceName+"|newrelic-client-go")
    50  
    51  	require.Same(t, testTransport, c.config.HTTPTransport)
    52  }
    53  
    54  func TestConfigDefaults(t *testing.T) {
    55  	t.Parallel()
    56  	tc := mock.NewTestConfig(t, nil)
    57  	tc.ServiceName = testServiceName
    58  
    59  	c := NewClient(tc)
    60  
    61  	assert.Contains(t, c.config.UserAgent, "newrelic/newrelic-client-go")
    62  	assert.Equal(t, c.config.ServiceName, testServiceName+"|newrelic-client-go")
    63  }
    64  
    65  func TestConfigLogger(t *testing.T) {
    66  	t.Parallel()
    67  	tc := mock.NewTestConfig(t, nil)
    68  
    69  	tc.Logger = logging.NewMockLogger(t)
    70  
    71  	c := NewClient(tc)
    72  	// The logger used should be the same as the config
    73  	require.Same(t, tc.Logger, c.logger)
    74  }
    75  
    76  func TestDefaultErrorValue(t *testing.T) {
    77  	t.Parallel()
    78  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    79  		w.Header().Set("Content-Type", "application/json")
    80  		w.WriteHeader(http.StatusBadRequest)
    81  		_, _ = w.Write([]byte(`{"error":{"title":"error message"}}`))
    82  	}))
    83  
    84  	_, err := c.Get(c.config.Region().RestURL("path"), nil, nil)
    85  
    86  	assert.Contains(t, err.Error(), "error message")
    87  }
    88  
    89  func TestUnauthorizedErrorValue(t *testing.T) {
    90  	t.Parallel()
    91  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    92  		w.Header().Set("Content-Type", "application/json")
    93  		w.WriteHeader(http.StatusUnauthorized)
    94  		_, _ = w.Write([]byte(`{"error":{"title": "No API key specified"}}`)) // REST API 401 response body
    95  	}))
    96  
    97  	_, err := c.Get(c.config.Region().RestURL("path"), nil, nil)
    98  
    99  	// Ensure our custom 401 unauthorized error message is returned
   100  	assert.Contains(t, err.Error(), "Invalid credentials provided")
   101  	assert.IsType(t, &errors.UnauthorizedError{}, err)
   102  }
   103  
   104  type CustomErrorResponse struct {
   105  	CustomError string `json:"custom"`
   106  }
   107  
   108  func (c *CustomErrorResponse) New() ErrorResponse {
   109  	return &CustomErrorResponse{}
   110  }
   111  
   112  func (c *CustomErrorResponse) Error() string {
   113  	return c.CustomError
   114  }
   115  
   116  func (c *CustomErrorResponse) IsNotFound() bool {
   117  	return false
   118  }
   119  
   120  func (c *CustomErrorResponse) IsRetryableError() bool {
   121  	return false
   122  }
   123  
   124  func (c *CustomErrorResponse) IsDeprecated() bool {
   125  	return false
   126  }
   127  
   128  func (c *CustomErrorResponse) IsUnauthorized(resp *http.Response) bool {
   129  	return resp.StatusCode == http.StatusUnauthorized
   130  }
   131  
   132  func (c *CustomErrorResponse) IsPaymentRequired(resp *http.Response) bool {
   133  	return resp.StatusCode == http.StatusPaymentRequired
   134  }
   135  
   136  func TestCustomErrorValue(t *testing.T) {
   137  	t.Parallel()
   138  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   139  		w.Header().Set("Content-Type", "application/json")
   140  		w.WriteHeader(http.StatusBadRequest)
   141  		_, _ = w.Write([]byte(`{"custom":"error message"}`))
   142  	}))
   143  
   144  	c.SetErrorValue(&CustomErrorResponse{})
   145  
   146  	_, err := c.Get(c.config.Region().RestURL("path"), nil, nil)
   147  
   148  	assert.Contains(t, err.Error(), "error message")
   149  }
   150  
   151  type CustomResponseValue struct {
   152  	Custom string `json:"custom"`
   153  }
   154  
   155  func TestResponseValue(t *testing.T) {
   156  	t.Parallel()
   157  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   158  		w.Header().Set("Content-Type", "application/json")
   159  		_, _ = w.Write([]byte(`{"custom":"custom response string"}`))
   160  	}))
   161  
   162  	v := &CustomResponseValue{}
   163  	_, err := c.Get(c.config.Region().RestURL("path"), nil, v)
   164  
   165  	assert.NoError(t, err)
   166  	assert.Equal(t, &CustomResponseValue{Custom: "custom response string"}, v)
   167  }
   168  
   169  func TestQueryParams(t *testing.T) {
   170  	t.Parallel()
   171  	queryParams := struct {
   172  		A int `url:"a,omitempty"`
   173  		B int `url:"b,omitempty"`
   174  	}{
   175  		A: 1,
   176  		B: 2,
   177  	}
   178  
   179  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   180  		w.Header().Set("Content-Type", "application/json")
   181  		w.WriteHeader(http.StatusOK)
   182  
   183  		a := r.URL.Query().Get("a")
   184  		assert.Equal(t, "1", a)
   185  
   186  		b := r.URL.Query().Get("b")
   187  		assert.Equal(t, "2", b)
   188  	}))
   189  
   190  	_, _ = c.Get(c.config.Region().RestURL("path"), &queryParams, nil)
   191  }
   192  
   193  type TestRequestBody struct {
   194  	A string `json:"a"`
   195  	B string `json:"b"`
   196  }
   197  
   198  func TestRequestBodyMarshal(t *testing.T) {
   199  	t.Parallel()
   200  	expected := TestRequestBody{
   201  		A: "1",
   202  		B: "2",
   203  	}
   204  
   205  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   206  		w.Header().Set("Content-Type", "application/json")
   207  		w.WriteHeader(http.StatusOK)
   208  
   209  		actual := &TestRequestBody{}
   210  		err := json.NewDecoder(r.Body).Decode(&actual)
   211  
   212  		assert.NoError(t, err)
   213  		assert.Equal(t, &expected, actual)
   214  	}))
   215  
   216  	_, _ = c.Post(c.config.Region().RestURL("path"), nil, expected, nil)
   217  }
   218  
   219  type TestInvalidRequestBody struct {
   220  	Channel chan int `json:"a"`
   221  }
   222  
   223  func TestRequestBodyMarshalError(t *testing.T) {
   224  	t.Parallel()
   225  	b := TestInvalidRequestBody{
   226  		Channel: make(chan int),
   227  	}
   228  
   229  	c := NewTestAPIClient(t, nil)
   230  
   231  	_, err := c.Post(c.config.Region().RestURL("/path"), nil, b, nil)
   232  	assert.Error(t, err)
   233  }
   234  
   235  func TestUrlParseError(t *testing.T) {
   236  	t.Parallel()
   237  	c := NewTestAPIClient(t, nil)
   238  
   239  	_, err := c.Get(c.config.Region().RestURL("\\"), nil, nil)
   240  	assert.Error(t, err)
   241  }
   242  
   243  type TestInvalidReponseBody struct {
   244  	Channel chan int `json:"channel"`
   245  }
   246  
   247  func TestResponseUnmarshalError(t *testing.T) {
   248  	t.Parallel()
   249  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   250  		w.Header().Set("Content-Type", "application/json")
   251  		_, _ = w.Write([]byte(`{"channel": "test"}`))
   252  	}))
   253  
   254  	_, err := c.Get(c.config.Region().RestURL("path"), nil, &TestInvalidReponseBody{})
   255  
   256  	assert.Error(t, err)
   257  }
   258  
   259  func TestHeaders(t *testing.T) {
   260  	t.Parallel()
   261  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   262  		w.Header().Set("Content-Type", "application/json")
   263  		w.WriteHeader(http.StatusOK)
   264  
   265  		assert.Equal(t, mock.UserAgent, r.Header.Get("user-agent"))
   266  		assert.Equal(t, "newrelic-client-go", r.Header.Get("newrelic-requesting-services"))
   267  	}))
   268  
   269  	_, err := c.Get(c.config.Region().RestURL("path"), nil, nil)
   270  
   271  	assert.Nil(t, err)
   272  }
   273  
   274  func TestCustomClientHeaders(t *testing.T) {
   275  	t.Parallel()
   276  
   277  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   278  		w.Header().Set("Content-Type", "application/json")
   279  		w.WriteHeader(http.StatusOK)
   280  
   281  		assert.Equal(t, "custom-user-agent", r.Header.Get("user-agent"))
   282  		assert.Equal(t, "custom-requesting-service|newrelic-client-go", r.Header.Get("newrelic-requesting-services"))
   283  	}))
   284  
   285  	tc := mock.NewTestConfig(t, ts)
   286  	tc.UserAgent = "custom-user-agent"
   287  	tc.ServiceName = "custom-requesting-service"
   288  
   289  	c := NewClient(tc)
   290  
   291  	_, err := c.Get(c.config.Region().RestURL("path"), nil, nil)
   292  
   293  	assert.Nil(t, err)
   294  }
   295  
   296  func TestCustomRequestHeaders(t *testing.T) {
   297  	t.Parallel()
   298  
   299  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   300  		w.Header().Set("Content-Type", "application/json")
   301  		w.WriteHeader(http.StatusOK)
   302  
   303  		assert.Equal(t, "custom-user-agent", r.Header.Get("user-agent"))
   304  		assert.Equal(t, "custom-requesting-service|newrelic-client-go", r.Header.Get("newrelic-requesting-services"))
   305  	}))
   306  
   307  	tc := mock.NewTestConfig(t, ts)
   308  
   309  	c := NewClient(tc)
   310  
   311  	req, err := c.NewRequest("GET", c.config.Region().RestURL("path"), nil, nil, nil)
   312  
   313  	req.SetHeader("user-agent", "custom-user-agent")
   314  	req.SetServiceName("custom-requesting-service")
   315  
   316  	_, err = c.Do(req)
   317  
   318  	assert.Nil(t, err)
   319  }
   320  
   321  func TestAccountIDHeaderWithPersonalAPIKeyCapableV2Authorizer(t *testing.T) {
   322  	// Given mock server
   323  	t.Parallel()
   324  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   325  		// Then X-Account-ID should be set in the header
   326  		assert.Equal(t, "custom-account-id", r.Header.Get("X-Account-ID"))
   327  	}))
   328  	tc := mock.NewTestConfig(t, ts)
   329  
   330  	// Given a client with PersonalAPIKeyCapableV2Authorizer Auth Strategy
   331  	c := NewClient(tc)
   332  	c.SetAuthStrategy(&PersonalAPIKeyCapableV2Authorizer{})
   333  
   334  	// When a request is made with context
   335  	req, err := c.NewRequest("GET", c.config.Region().RestURL("path"), nil, nil, nil)
   336  	ctx := contextkeys.SetAccountID(context.Background(), "custom-account-id")
   337  	req.WithContext(ctx)
   338  
   339  	// Then there are no errors with the request
   340  	_, err = c.Do(req)
   341  	assert.Nil(t, err)
   342  }
   343  
   344  func TestErrNotFound(t *testing.T) {
   345  	t.Parallel()
   346  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   347  		w.WriteHeader(http.StatusNotFound)
   348  	}))
   349  
   350  	_, err := c.Get(c.config.Region().RestURL("path"), nil, nil)
   351  
   352  	assert.IsType(t, &errors.NotFound{}, err)
   353  }
   354  
   355  func TestRetryOnNerdGraphTooManyRequests(t *testing.T) {
   356  	t.Parallel()
   357  	attempts := 0
   358  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   359  		w.Header().Set("Content-Type", "application/json")
   360  		_, _ = w.Write([]byte(`{"errors":[{"message": "some error", "extensions":{"errorClass":"TOO_MANY_REQUESTS"}}]}`))
   361  		attempts++
   362  	}))
   363  
   364  	c.client.RetryWaitMax = 10 * time.Millisecond
   365  	c.errorValue = &GraphQLErrorResponse{}
   366  	_, err := c.Get(c.config.Region().NerdGraphURL("graphql"), nil, nil)
   367  
   368  	assert.Equal(t, 4, attempts)
   369  	assert.Error(t, err)
   370  	assert.Contains(t, err.Error(), ErrClassTooManyRequests.Error())
   371  }
   372  
   373  func TestRetryOnNerdGraphTimeout(t *testing.T) {
   374  	t.Parallel()
   375  	attempts := 0
   376  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   377  		w.Header().Set("Content-Type", "application/json")
   378  		_, _ = w.Write([]byte(`{"errors":[{"message": "some error", "extensions":{"errorClass":"TIMEOUT"}}]}`))
   379  		attempts++
   380  	}))
   381  
   382  	c.client.RetryWaitMax = 10 * time.Millisecond
   383  
   384  	c.errorValue = &GraphQLErrorResponse{}
   385  	_, err := c.Get(c.config.Region().NerdGraphURL("path"), nil, nil)
   386  
   387  	assert.Equal(t, 4, attempts)
   388  	assert.Error(t, err)
   389  	assert.IsType(t, &errors.MaxRetriesReached{}, err)
   390  	assert.Contains(t, err.Error(), "maximum retries reached")
   391  	assert.Contains(t, err.Error(), "some error")
   392  }
   393  
   394  func TestRetryOnNerdGraphInternalServerError(t *testing.T) {
   395  	t.Parallel()
   396  	attempts := 0
   397  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   398  		w.Header().Set("Content-Type", "application/json")
   399  		_, _ = w.Write([]byte(`{"errors":[{"message": "some error", "extensions":{"errorClass":"INTERNAL_SERVER_ERROR"}}]}`))
   400  		attempts++
   401  	}))
   402  
   403  	c.client.RetryWaitMax = 10 * time.Millisecond
   404  
   405  	c.errorValue = &GraphQLErrorResponse{}
   406  	_, err := c.Get(c.config.Region().NerdGraphURL("path"), nil, nil)
   407  
   408  	assert.Equal(t, 4, attempts)
   409  	assert.Error(t, err)
   410  	assert.IsType(t, &errors.MaxRetriesReached{}, err)
   411  	assert.Contains(t, err.Error(), "maximum retries reached")
   412  	assert.Contains(t, err.Error(), "some error")
   413  }
   414  
   415  func TestInternalServerError(t *testing.T) {
   416  	t.Parallel()
   417  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   418  		w.WriteHeader(http.StatusBadRequest)
   419  	}))
   420  
   421  	_, err := c.Get(c.config.Region().RestURL("path"), nil, nil)
   422  
   423  	assert.IsType(t, &errors.UnexpectedStatusCode{}, err)
   424  }
   425  
   426  func TestPost(t *testing.T) {
   427  	t.Parallel()
   428  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   429  		w.WriteHeader(http.StatusOK)
   430  		_, _ = w.Write([]byte(`{}`))
   431  	}))
   432  
   433  	// string
   434  	_, err := c.Post(c.config.Region().RestURL("path"), &struct{}{}, "test string payload", &struct{}{})
   435  	assert.NoError(t, err)
   436  
   437  	// []byte
   438  	_, err = c.Post(c.config.Region().RestURL("path"), &struct{}{}, []byte(`bytes`), &struct{}{})
   439  	assert.NoError(t, err)
   440  
   441  	// other data type
   442  	_, err = c.Post(c.config.Region().RestURL("path"), &struct{}{}, &struct{}{}, &struct{}{})
   443  	assert.NoError(t, err)
   444  }
   445  
   446  func TestPut(t *testing.T) {
   447  	t.Parallel()
   448  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   449  		w.WriteHeader(http.StatusOK)
   450  		_, _ = w.Write([]byte(`{}`))
   451  	}))
   452  
   453  	_, err := c.Put(c.config.Region().RestURL("path"), &struct{}{}, &struct{}{}, &struct{}{})
   454  
   455  	assert.NoError(t, err)
   456  }
   457  
   458  func TestDelete(t *testing.T) {
   459  	t.Parallel()
   460  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   461  		w.WriteHeader(http.StatusOK)
   462  
   463  		_, _ = w.Write([]byte(`{}`))
   464  	}))
   465  
   466  	_, err := c.Delete(c.config.Region().RestURL("path"), &struct{}{}, &struct{}{})
   467  
   468  	assert.NoError(t, err)
   469  }
   470  
   471  func TestPaymentRequiredError(t *testing.T) {
   472  	t.Parallel()
   473  	c := NewTestAPIClient(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   474  		w.WriteHeader(http.StatusPaymentRequired)
   475  	}))
   476  
   477  	_, err := c.Get(c.config.Region().RestURL("path"), nil, nil)
   478  
   479  	assert.IsType(t, &errors.PaymentRequiredError{}, err)
   480  }