github.com/Bytom/bytom@v1.1.2-0.20210127130405-ae40204c0b09/net/http/httpjson/handler_test.go (about)

     1  package httpjson
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"reflect"
     8  	"strings"
     9  	"testing"
    10  	"testing/iotest"
    11  
    12  	"github.com/bytom/bytom/errors"
    13  )
    14  
    15  func TestHandler(t *testing.T) {
    16  	errX := errors.New("x")
    17  
    18  	cases := []struct {
    19  		rawQuery string
    20  		input    string
    21  		output   string
    22  		f        interface{}
    23  		wantErr  error
    24  	}{
    25  		{"", ``, `{"message":"ok"}`, func() {}, nil},
    26  		{"", ``, `1`, func() int { return 1 }, nil},
    27  		{"", ``, `{"message":"ok"}`, func() error { return nil }, nil},
    28  		{"", ``, ``, func() error { return errX }, errX},
    29  		{"", ``, `1`, func() (int, error) { return 1, nil }, nil},
    30  		{"", ``, ``, func() (int, error) { return 0, errX }, errX},
    31  		{"", `1`, `1`, func(i int) int { return i }, nil},
    32  		{"", `1`, `1`, func(i *int) int { return *i }, nil},
    33  		{"", `"foo"`, `"foo"`, func(s string) string { return s }, nil},
    34  		{"", `{"x":1}`, `1`, func(x struct{ X int }) int { return x.X }, nil},
    35  		{"", `{"x":1}`, `1`, func(x *struct{ X int }) int { return x.X }, nil},
    36  		{"", ``, `1`, func(ctx context.Context) int { return ctx.Value("k").(int) }, nil},
    37  	}
    38  
    39  	for _, test := range cases {
    40  		var gotErr error
    41  		errFunc := func(ctx context.Context, w http.ResponseWriter, err error) {
    42  			gotErr = err
    43  		}
    44  		h, err := Handler(test.f, errFunc)
    45  		if err != nil {
    46  			t.Errorf("Handler(%v) got err %v", test.f, err)
    47  			continue
    48  		}
    49  
    50  		resp := httptest.NewRecorder()
    51  		req, _ := http.NewRequest("GET", "/", strings.NewReader(test.input))
    52  		req.URL.RawQuery = test.rawQuery
    53  		ctx := context.WithValue(context.Background(), "k", 1)
    54  		h.ServeHTTP(resp, req.WithContext(ctx))
    55  		if resp.Code != 200 {
    56  			t.Errorf("%T response code = %d want 200", test.f, resp.Code)
    57  		}
    58  		got := strings.TrimSpace(resp.Body.String())
    59  		if got != test.output {
    60  			t.Errorf("%T response body = %#q want %#q", test.f, got, test.output)
    61  		}
    62  		if gotErr != test.wantErr {
    63  			t.Errorf("%T err = %v want %v", test.f, gotErr, test.wantErr)
    64  		}
    65  	}
    66  }
    67  
    68  func TestReadErr(t *testing.T) {
    69  	var gotErr error
    70  	errFunc := func(ctx context.Context, w http.ResponseWriter, err error) {
    71  		gotErr = errors.Root(err)
    72  	}
    73  	h, _ := Handler(func(int) {}, errFunc)
    74  
    75  	resp := httptest.NewRecorder()
    76  	body := iotest.OneByteReader(iotest.TimeoutReader(strings.NewReader("123456")))
    77  	req, _ := http.NewRequest("GET", "/", body)
    78  	h.ServeHTTP(resp, req)
    79  	if got := resp.Body.Len(); got != 0 {
    80  		t.Errorf("len(response) = %d want 0", got)
    81  	}
    82  	wantErr := ErrBadRequest
    83  	if gotErr != wantErr {
    84  		t.Errorf("err = %v want %v", gotErr, wantErr)
    85  	}
    86  }
    87  
    88  func TestFuncInputTypeError(t *testing.T) {
    89  	cases := []interface{}{
    90  		0,
    91  		"foo",
    92  		func() (int, int) { return 0, 0 },
    93  		func(string, int) {},
    94  		func() (int, int, error) { return 0, 0, nil },
    95  	}
    96  
    97  	for _, testf := range cases {
    98  		_, _, err := funcInputType(reflect.ValueOf(testf))
    99  		if err == nil {
   100  			t.Errorf("funcInputType(%T) want error", testf)
   101  		}
   102  
   103  		_, err = Handler(testf, nil)
   104  		if err == nil {
   105  			t.Errorf("funcInputType(%T) want error", testf)
   106  		}
   107  	}
   108  }
   109  
   110  var (
   111  	intType    = reflect.TypeOf(0)
   112  	intpType   = reflect.TypeOf((*int)(nil))
   113  	stringType = reflect.TypeOf("")
   114  )
   115  
   116  func TestFuncInputTypeOk(t *testing.T) {
   117  	cases := []struct {
   118  		f       interface{}
   119  		wantCtx bool
   120  		wantT   reflect.Type
   121  	}{
   122  		{func() {}, false, nil},
   123  		{func() int { return 0 }, false, nil},
   124  		{func() error { return nil }, false, nil},
   125  		{func() (int, error) { return 0, nil }, false, nil},
   126  		{func(int) {}, false, intType},
   127  		{func(*int) {}, false, intpType},
   128  		{func(context.Context) {}, true, nil},
   129  		{func(string) {}, false, stringType}, // req body is string
   130  	}
   131  
   132  	for _, test := range cases {
   133  		gotCtx, gotT, err := funcInputType(reflect.ValueOf(test.f))
   134  		if err != nil {
   135  			t.Errorf("funcInputType(%T) got error: %v", test.f, err)
   136  		}
   137  		if gotCtx != test.wantCtx {
   138  			t.Errorf("funcInputType(%T) context = %v want %v", test.f, gotCtx, test.wantCtx)
   139  		}
   140  		if gotT != test.wantT {
   141  			t.Errorf("funcInputType(%T) = %v want %v", test.f, gotT, test.wantT)
   142  		}
   143  	}
   144  }