github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/inject_test.go (about)

     1  package znet
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"reflect"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/sohaha/zlsgo"
    12  	"github.com/sohaha/zlsgo/zdi"
    13  	"github.com/sohaha/zlsgo/zjson"
    14  	"github.com/sohaha/zlsgo/zlog"
    15  )
    16  
    17  var _ zdi.PreInvoker = (*CustomInvoker)(nil)
    18  
    19  type CustomInvoker func(ctx *Context) (b []byte)
    20  
    21  func (fn CustomInvoker) Invoke(i []interface{}) ([]reflect.Value, error) {
    22  	c := i[0].(*Context)
    23  	b := fn(c)
    24  	c.Byte(404, b)
    25  	return nil, nil
    26  }
    27  
    28  func TestInject(t *testing.T) {
    29  	tt := zlsgo.NewTest(t)
    30  	r := newServer()
    31  
    32  	r.Log.Discard()
    33  
    34  	w := newRequest(r, "GET", "/NoInject", "/NoInject", func(c *Context) {
    35  		c.String(200, "NoInject")
    36  	})
    37  	tt.Equal(200, w.Code)
    38  	tt.Equal("NoInject", w.Body.String())
    39  
    40  	w = newRequest(r, "GET", "/Inject", "/Inject", func() (int, string) {
    41  		return 403, "Inject"
    42  	})
    43  	tt.Equal(403, w.Code)
    44  	tt.Equal("Inject", w.Body.String())
    45  
    46  	rewriteError := ""
    47  	w = newRequest(r, "GET", "/InjectErr", "/InjectErr", func() (int, string, error) {
    48  		return 403, "test InjectErr", errors.New("test InjectErr")
    49  	})
    50  	tt.Equal(500, w.Code)
    51  	tt.Equal("test InjectErr", w.Body.String())
    52  	tt.Equal("", rewriteError)
    53  	w = newRequest(r, "GET", "/InjectErrRewrite", "/InjectErrRewrite", func() (int, string, error) {
    54  		return 403, "test InjectErr", errors.New("InjectErrRewrite")
    55  	}, RewriteErrorHandler(func(c *Context, err error) {
    56  		tt.Equal("InjectErrRewrite", err.Error())
    57  		rewriteError = err.Error()
    58  	}))
    59  	tt.Equal(403, w.Code)
    60  	tt.Equal("test InjectErr", w.Body.String())
    61  	tt.Equal("InjectErrRewrite", rewriteError)
    62  
    63  	w = newRequest(r, "GET", "/InjectCustom", "/InjectCustom", CustomInvoker(func(ctx *Context) (b []byte) {
    64  		return []byte("InjectCustom")
    65  	}), func() {
    66  
    67  	})
    68  	tt.Equal(404, w.Code)
    69  	tt.Equal("InjectCustom", w.Body.String())
    70  
    71  	w = newRequest(r, "GET", "/InjectAny", "/InjectAny", func(ctx *Context) (c uint, api ApiData, err error) {
    72  		return 302, ApiData{Code: 301, Msg: "InjectAny"}, nil
    73  	})
    74  	tt.Equal(302, w.Code)
    75  	tt.Equal("application/json; charset=utf-8", w.Header().Get("Content-Type"))
    76  	tt.Equal("InjectAny", zjson.Get(w.Body.String(), "msg").String())
    77  
    78  	w = httptest.NewRecorder()
    79  	req, _ := http.NewRequest("GET", "__404__", nil)
    80  	r.ServeHTTP(w, req)
    81  	t.Log(w)
    82  }
    83  
    84  func TestInjectMiddleware(t *testing.T) {
    85  	tt := zlsgo.NewTest(t)
    86  	r := newServer()
    87  
    88  	now := time.Now()
    89  	r.Injector().Map(now)
    90  
    91  	w := newRequest(r, "GET", "/TestInjectMiddleware", "/TestInjectMiddleware", func() (int, string) {
    92  		t.Log("run")
    93  		return 403, "Inject"
    94  	}, Recovery(func(c *Context, err error) {
    95  		zlog.Error("Recovery", err)
    96  	}), func(c *Context) {
    97  		c.Next()
    98  	}, func(n time.Time) error {
    99  		tt.Equal(now, n)
   100  		return errors.New("return exit")
   101  	})
   102  	tt.Equal(500, w.Code)
   103  	tt.Equal("return exit", w.Body.String())
   104  
   105  	pc := make([]int, 0)
   106  	w = newRequest(r, "GET", "/TestInjectMiddleware2", "/TestInjectMiddleware2", func() (int, string) {
   107  		t.Log("run")
   108  		return 403, "Inject"
   109  	}, func(c *Context) {
   110  		pc = append(pc, 1)
   111  		c.Next()
   112  		pc = append(pc, 9)
   113  	}, func(c *Context) string {
   114  		pc = append(pc, 2)
   115  		c.Next()
   116  		pc = append(pc, 8)
   117  		return "middleware"
   118  	}, func(c *Context) error {
   119  		pc = append(pc, 3)
   120  		c.Next()
   121  		pc = append(pc, 7)
   122  		var s string
   123  		err := c.Injector().Resolve(&s)
   124  		tt.NoError(err)
   125  		tt.Equal("test", s)
   126  		return nil
   127  	}, func() {
   128  		pc = append(pc, 4)
   129  	}, func(c *Context) {
   130  		c.Next()
   131  		pc = append(pc, 6)
   132  	}, func(c *Context) {
   133  		pc = append(pc, 5)
   134  		c.Injector().Map("test")
   135  		c.Next()
   136  	})
   137  	tt.Equal(403, w.Code)
   138  	tt.Equal([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}, pc)
   139  	tt.Equal("middleware", w.Body.String())
   140  }
   141  
   142  func BenchmarkInjectNo(b *testing.B) {
   143  	r := newServer()
   144  	path := "/BenchmarkInjectNo"
   145  	r.SetMode(QuietMode)
   146  	r.GET(path, func(c *Context) {
   147  		c.String(200, path)
   148  	})
   149  	b.ResetTimer()
   150  	b.ReportAllocs()
   151  	for i := 0; i < b.N; i++ {
   152  		w := httptest.NewRecorder()
   153  		req, _ := http.NewRequest("GET", path, nil)
   154  		r.ServeHTTP(w, req)
   155  		if w.Code != 200 || w.Body.String() != path {
   156  			b.Fail()
   157  		}
   158  	}
   159  }
   160  
   161  func BenchmarkInjectFast(b *testing.B) {
   162  	r := newServer()
   163  	r.SetMode(QuietMode)
   164  	path := "/BenchmarkInjectFast"
   165  	r.GET(path, func() (int, string) {
   166  		return 200, path
   167  	})
   168  	b.ResetTimer()
   169  	b.ReportAllocs()
   170  	for i := 0; i < b.N; i++ {
   171  		w := httptest.NewRecorder()
   172  		req, _ := http.NewRequest("GET", path, nil)
   173  		r.ServeHTTP(w, req)
   174  		if w.Code != 200 || w.Body.String() != path {
   175  			b.Fail()
   176  		}
   177  	}
   178  }
   179  
   180  func BenchmarkInjectBasis(b *testing.B) {
   181  	r := newServer()
   182  	r.SetMode(QuietMode)
   183  	path := "/BenchmarkInjectBasis"
   184  	r.GET(path, func() (int, []byte) {
   185  		return 200, []byte(path)
   186  	})
   187  	b.ResetTimer()
   188  	b.ReportAllocs()
   189  	for i := 0; i < b.N; i++ {
   190  		w := httptest.NewRecorder()
   191  		req, _ := http.NewRequest("GET", path, nil)
   192  		r.ServeHTTP(w, req)
   193  		if w.Code != 200 || w.Body.String() != path {
   194  			b.Fail()
   195  		}
   196  	}
   197  }