// Copyright (c) 2018 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

20 21 package xhttp 22 23 import ( 24 "context" 25 "errors" 26 "fmt" 27 "net/http/httptest" 28 "testing" 29 30 "github.com/prometheus/prometheus/promql" 31 "github.com/stretchr/testify/assert" 32 "github.com/stretchr/testify/require" 33 "github.com/uber/tchannel-go" 34 35 terrors "github.com/m3db/m3/src/dbnode/network/server/tchannelthrift/errors" 36 xerrors "github.com/m3db/m3/src/x/errors" 37 ) 38 39 func TestErrorStatus(t *testing.T) { 40 a := terrors.NewInternalError(tchannel.ErrTimeout) 41 42 tests := []struct { 43 name string 44 err error 45 expectedStatus int 46 }{ 47 { 48 name: "generic error", 49 err: errors.New("generic error"), 50 expectedStatus: 500, 51 }, 52 { 53 name: "invalid params", 54 err: xerrors.NewInvalidParamsError(errors.New("bad param")), 55 expectedStatus: 400, 56 }, 57 { 58 name: "deadline exceeded", 59 err: context.DeadlineExceeded, 60 expectedStatus: 504, 61 }, 62 { 63 name: "canceled", 64 err: context.Canceled, 65 expectedStatus: 499, 66 }, 67 { 68 name: "prom canceled", 69 err: promql.ErrQueryCanceled("canceled"), 70 expectedStatus: 499, 71 }, 72 { 73 name: "client timeout", 74 err: terrors.NewTimeoutError(fmt.Errorf("timeout")), 75 expectedStatus: 504, 76 }, 77 { 78 name: "tchannel timeout", 79 err: tchannel.ErrTimeout, 80 expectedStatus: 504, 81 }, 82 { 83 name: "http error", 84 err: NewError(errors.New("some error"), 504), 85 expectedStatus: 504, 86 }, 87 } 88 89 for _, tt := range tests { 90 t.Run(tt.name, func(t *testing.T) { 91 assert.Error(t, a) 92 recorder := httptest.NewRecorder() 93 WriteError(recorder, tt.err) 94 assert.Equal(t, tt.expectedStatus, recorder.Code) 95 }) 96 } 97 } 98 99 func TestErrorRewrite(t *testing.T) { 100 tests := []struct { 101 name string 102 err error 103 expectedBody string 104 expectedStatus int 105 }{ 106 { 107 name: "error that should not be rewritten", 108 err: errors.New("random error"), 109 expectedBody: `{"status":"error","error":"random error"}`, 110 expectedStatus: 500, 111 }, 112 { 113 name: 
"error that should be rewritten", 114 err: xerrors.NewInvalidParamsError(errors.New("to be rewritten")), 115 expectedBody: `{"status":"error","error":"rewritten error"}`, 116 expectedStatus: 500, 117 }, 118 } 119 120 invalidParamsRewriteFn := func(err error) error { 121 if xerrors.IsInvalidParams(err) { 122 return errors.New("rewritten error") 123 } 124 return err 125 } 126 127 SetErrorRewriteFn(invalidParamsRewriteFn) 128 129 for _, tt := range tests { 130 t.Run(tt.name, func(t *testing.T) { 131 recorder := httptest.NewRecorder() 132 WriteError(recorder, tt.err) 133 assert.JSONEq(t, tt.expectedBody, recorder.Body.String()) 134 assert.Equal(t, tt.expectedStatus, recorder.Code) 135 }) 136 } 137 } 138 139 func TestIsClientError(t *testing.T) { 140 tests := []struct { 141 err error 142 expected bool 143 }{ 144 {NewError(fmt.Errorf("xhttp.Error(400)"), 400), true}, 145 {NewError(fmt.Errorf("xhttp.Error(499)"), 499), true}, 146 {xerrors.NewInvalidParamsError(fmt.Errorf("InvalidParamsError")), true}, 147 {xerrors.NewRetryableError(xerrors.NewInvalidParamsError( 148 fmt.Errorf("InvalidParamsError insde RetyrableError"))), true}, 149 150 {NewError(fmt.Errorf("xhttp.Error(399)"), 399), false}, 151 {NewError(fmt.Errorf("xhttp.Error(500)"), 500), false}, 152 {xerrors.NewRetryableError(fmt.Errorf("any error inside RetryableError")), false}, 153 {fmt.Errorf("any error"), false}, 154 } 155 156 for _, tt := range tests { 157 t.Run(tt.err.Error(), func(t *testing.T) { 158 require.Equal(t, tt.expected, IsClientError(tt.err)) 159 }) 160 } 161 }