github.com/m3db/m3@v1.5.0/src/x/net/http/errors_test.go (about)

     1  // Copyright (c) 2018 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package xhttp
    22  
    23  import (
    24  	"context"
    25  	"errors"
    26  	"fmt"
    27  	"net/http/httptest"
    28  	"testing"
    29  
    30  	"github.com/prometheus/prometheus/promql"
    31  	"github.com/stretchr/testify/assert"
    32  	"github.com/stretchr/testify/require"
    33  	"github.com/uber/tchannel-go"
    34  
    35  	terrors "github.com/m3db/m3/src/dbnode/network/server/tchannelthrift/errors"
    36  	xerrors "github.com/m3db/m3/src/x/errors"
    37  )
    38  
    39  func TestErrorStatus(t *testing.T) {
    40  	a := terrors.NewInternalError(tchannel.ErrTimeout)
    41  
    42  	tests := []struct {
    43  		name           string
    44  		err            error
    45  		expectedStatus int
    46  	}{
    47  		{
    48  			name:           "generic error",
    49  			err:            errors.New("generic error"),
    50  			expectedStatus: 500,
    51  		},
    52  		{
    53  			name:           "invalid params",
    54  			err:            xerrors.NewInvalidParamsError(errors.New("bad param")),
    55  			expectedStatus: 400,
    56  		},
    57  		{
    58  			name:           "deadline exceeded",
    59  			err:            context.DeadlineExceeded,
    60  			expectedStatus: 504,
    61  		},
    62  		{
    63  			name:           "canceled",
    64  			err:            context.Canceled,
    65  			expectedStatus: 499,
    66  		},
    67  		{
    68  			name:           "prom canceled",
    69  			err:            promql.ErrQueryCanceled("canceled"),
    70  			expectedStatus: 499,
    71  		},
    72  		{
    73  			name:           "client timeout",
    74  			err:            terrors.NewTimeoutError(fmt.Errorf("timeout")),
    75  			expectedStatus: 504,
    76  		},
    77  		{
    78  			name:           "tchannel timeout",
    79  			err:            tchannel.ErrTimeout,
    80  			expectedStatus: 504,
    81  		},
    82  		{
    83  			name:           "http error",
    84  			err:            NewError(errors.New("some error"), 504),
    85  			expectedStatus: 504,
    86  		},
    87  	}
    88  
    89  	for _, tt := range tests {
    90  		t.Run(tt.name, func(t *testing.T) {
    91  			assert.Error(t, a)
    92  			recorder := httptest.NewRecorder()
    93  			WriteError(recorder, tt.err)
    94  			assert.Equal(t, tt.expectedStatus, recorder.Code)
    95  		})
    96  	}
    97  }
    98  
    99  func TestErrorRewrite(t *testing.T) {
   100  	tests := []struct {
   101  		name           string
   102  		err            error
   103  		expectedBody   string
   104  		expectedStatus int
   105  	}{
   106  		{
   107  			name:           "error that should not be rewritten",
   108  			err:            errors.New("random error"),
   109  			expectedBody:   `{"status":"error","error":"random error"}`,
   110  			expectedStatus: 500,
   111  		},
   112  		{
   113  			name:           "error that should be rewritten",
   114  			err:            xerrors.NewInvalidParamsError(errors.New("to be rewritten")),
   115  			expectedBody:   `{"status":"error","error":"rewritten error"}`,
   116  			expectedStatus: 500,
   117  		},
   118  	}
   119  
   120  	invalidParamsRewriteFn := func(err error) error {
   121  		if xerrors.IsInvalidParams(err) {
   122  			return errors.New("rewritten error")
   123  		}
   124  		return err
   125  	}
   126  
   127  	SetErrorRewriteFn(invalidParamsRewriteFn)
   128  
   129  	for _, tt := range tests {
   130  		t.Run(tt.name, func(t *testing.T) {
   131  			recorder := httptest.NewRecorder()
   132  			WriteError(recorder, tt.err)
   133  			assert.JSONEq(t, tt.expectedBody, recorder.Body.String())
   134  			assert.Equal(t, tt.expectedStatus, recorder.Code)
   135  		})
   136  	}
   137  }
   138  
   139  func TestIsClientError(t *testing.T) {
   140  	tests := []struct {
   141  		err      error
   142  		expected bool
   143  	}{
   144  		{NewError(fmt.Errorf("xhttp.Error(400)"), 400), true},
   145  		{NewError(fmt.Errorf("xhttp.Error(499)"), 499), true},
   146  		{xerrors.NewInvalidParamsError(fmt.Errorf("InvalidParamsError")), true},
   147  		{xerrors.NewRetryableError(xerrors.NewInvalidParamsError(
   148  			fmt.Errorf("InvalidParamsError insde RetyrableError"))), true},
   149  
   150  		{NewError(fmt.Errorf("xhttp.Error(399)"), 399), false},
   151  		{NewError(fmt.Errorf("xhttp.Error(500)"), 500), false},
   152  		{xerrors.NewRetryableError(fmt.Errorf("any error inside RetryableError")), false},
   153  		{fmt.Errorf("any error"), false},
   154  	}
   155  
   156  	for _, tt := range tests {
   157  		t.Run(tt.err.Error(), func(t *testing.T) {
   158  			require.Equal(t, tt.expected, IsClientError(tt.err))
   159  		})
   160  	}
   161  }