github.com/lingyao2333/mo-zero@v1.4.1/rest/httpx/responses_test.go (about)

     1  package httpx
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"strings"
     7  	"testing"
     8  
     9  	"github.com/lingyao2333/mo-zero/core/logx"
    10  	"github.com/stretchr/testify/assert"
    11  	"google.golang.org/grpc/codes"
    12  	"google.golang.org/grpc/status"
    13  )
    14  
    // message is the small JSON payload used by the response tests in this
    // file; its only field is serialized under the key "name".
    type message struct {
    	Name string `json:"name"`
    }
    18  
    19  func init() {
    20  	logx.Disable()
    21  }
    22  
    23  func TestError(t *testing.T) {
    24  	const (
    25  		body        = "foo"
    26  		wrappedBody = `"foo"`
    27  	)
    28  
    29  	tests := []struct {
    30  		name          string
    31  		input         string
    32  		errorHandler  func(error) (int, interface{})
    33  		expectHasBody bool
    34  		expectBody    string
    35  		expectCode    int
    36  	}{
    37  		{
    38  			name:          "default error handler",
    39  			input:         body,
    40  			expectHasBody: true,
    41  			expectBody:    body,
    42  			expectCode:    http.StatusBadRequest,
    43  		},
    44  		{
    45  			name:  "customized error handler return string",
    46  			input: body,
    47  			errorHandler: func(err error) (int, interface{}) {
    48  				return http.StatusForbidden, err.Error()
    49  			},
    50  			expectHasBody: true,
    51  			expectBody:    wrappedBody,
    52  			expectCode:    http.StatusForbidden,
    53  		},
    54  		{
    55  			name:  "customized error handler return error",
    56  			input: body,
    57  			errorHandler: func(err error) (int, interface{}) {
    58  				return http.StatusForbidden, err
    59  			},
    60  			expectHasBody: true,
    61  			expectBody:    body,
    62  			expectCode:    http.StatusForbidden,
    63  		},
    64  		{
    65  			name:  "customized error handler return nil",
    66  			input: body,
    67  			errorHandler: func(err error) (int, interface{}) {
    68  				return http.StatusForbidden, nil
    69  			},
    70  			expectHasBody: false,
    71  			expectBody:    "",
    72  			expectCode:    http.StatusForbidden,
    73  		},
    74  	}
    75  
    76  	for _, test := range tests {
    77  		t.Run(test.name, func(t *testing.T) {
    78  			w := tracedResponseWriter{
    79  				headers: make(map[string][]string),
    80  			}
    81  			if test.errorHandler != nil {
    82  				lock.RLock()
    83  				prev := errorHandler
    84  				lock.RUnlock()
    85  				SetErrorHandler(test.errorHandler)
    86  				defer func() {
    87  					lock.Lock()
    88  					errorHandler = prev
    89  					lock.Unlock()
    90  				}()
    91  			}
    92  			Error(&w, errors.New(test.input))
    93  			assert.Equal(t, test.expectCode, w.code)
    94  			assert.Equal(t, test.expectHasBody, w.hasBody)
    95  			assert.Equal(t, test.expectBody, strings.TrimSpace(w.builder.String()))
    96  		})
    97  	}
    98  }
    99  
   100  func TestErrorWithGrpcError(t *testing.T) {
   101  	w := tracedResponseWriter{
   102  		headers: make(map[string][]string),
   103  	}
   104  	Error(&w, status.Error(codes.Unavailable, "foo"))
   105  	assert.Equal(t, http.StatusServiceUnavailable, w.code)
   106  	assert.True(t, w.hasBody)
   107  	assert.True(t, strings.Contains(w.builder.String(), "foo"))
   108  }
   109  
   110  func TestErrorWithHandler(t *testing.T) {
   111  	w := tracedResponseWriter{
   112  		headers: make(map[string][]string),
   113  	}
   114  	Error(&w, errors.New("foo"), func(w http.ResponseWriter, err error) {
   115  		http.Error(w, err.Error(), 499)
   116  	})
   117  	assert.Equal(t, 499, w.code)
   118  	assert.True(t, w.hasBody)
   119  	assert.Equal(t, "foo", strings.TrimSpace(w.builder.String()))
   120  }
   121  
   122  func TestOk(t *testing.T) {
   123  	w := tracedResponseWriter{
   124  		headers: make(map[string][]string),
   125  	}
   126  	Ok(&w)
   127  	assert.Equal(t, http.StatusOK, w.code)
   128  }
   129  
   130  func TestOkJson(t *testing.T) {
   131  	w := tracedResponseWriter{
   132  		headers: make(map[string][]string),
   133  	}
   134  	msg := message{Name: "anyone"}
   135  	OkJson(&w, msg)
   136  	assert.Equal(t, http.StatusOK, w.code)
   137  	assert.Equal(t, "{\"name\":\"anyone\"}", w.builder.String())
   138  }
   139  
   140  func TestWriteJsonTimeout(t *testing.T) {
   141  	// only log it and ignore
   142  	w := tracedResponseWriter{
   143  		headers: make(map[string][]string),
   144  		err:     http.ErrHandlerTimeout,
   145  	}
   146  	msg := message{Name: "anyone"}
   147  	WriteJson(&w, http.StatusOK, msg)
   148  	assert.Equal(t, http.StatusOK, w.code)
   149  }
   150  
   151  func TestWriteJsonError(t *testing.T) {
   152  	// only log it and ignore
   153  	w := tracedResponseWriter{
   154  		headers: make(map[string][]string),
   155  		err:     errors.New("foo"),
   156  	}
   157  	msg := message{Name: "anyone"}
   158  	WriteJson(&w, http.StatusOK, msg)
   159  	assert.Equal(t, http.StatusOK, w.code)
   160  }
   161  
   162  func TestWriteJsonLessWritten(t *testing.T) {
   163  	w := tracedResponseWriter{
   164  		headers:     make(map[string][]string),
   165  		lessWritten: true,
   166  	}
   167  	msg := message{Name: "anyone"}
   168  	WriteJson(&w, http.StatusOK, msg)
   169  	assert.Equal(t, http.StatusOK, w.code)
   170  }
   171  
   172  func TestWriteJsonMarshalFailed(t *testing.T) {
   173  	w := tracedResponseWriter{
   174  		headers: make(map[string][]string),
   175  	}
   176  	WriteJson(&w, http.StatusOK, map[string]interface{}{
   177  		"Data": complex(0, 0),
   178  	})
   179  	assert.Equal(t, http.StatusInternalServerError, w.code)
   180  }
   181  
   // tracedResponseWriter is a test double with the Header, Write and
   // WriteHeader methods of http.ResponseWriter. It records the status code and
   // body written to it, and can be configured to fail writes or to under-report
   // the number of bytes written.
   type tracedResponseWriter struct {
   	headers     map[string][]string // returned by Header
   	builder     strings.Builder     // accumulates the bytes accepted by Write
   	hasBody     bool                // set once Write has stored data
   	code        int                 // status from the first WriteHeader call
   	lessWritten bool                // when true, Write reports one byte fewer than stored
   	wroteHeader bool                // true after the first WriteHeader call
   	err         error               // when non-nil, Write fails with it and stores nothing
   }

   // Header exposes the configured header map as an http.Header.
   func (w *tracedResponseWriter) Header() http.Header {
   	return http.Header(w.headers)
   }

   // Write appends bytes to the recorded body and marks the writer as having a
   // body. If err is configured it returns (0, err) without recording anything;
   // if lessWritten is set the reported count is one less than what was stored.
   func (w *tracedResponseWriter) Write(bytes []byte) (n int, err error) {
   	if failure := w.err; failure != nil {
   		return 0, failure
   	}

   	written, writeErr := w.builder.Write(bytes)
   	w.hasBody = true
   	if w.lessWritten {
   		written--
   	}

   	return written, writeErr
   }

   // WriteHeader records the status code of the first call; later calls are
   // ignored.
   func (w *tracedResponseWriter) WriteHeader(code int) {
   	if !w.wroteHeader {
   		w.code = code
   		w.wroteHeader = true
   	}
   }
   216  }