github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/znet/inject_test.go (about) 1 package znet 2 3 import ( 4 "errors" 5 "net/http" 6 "net/http/httptest" 7 "reflect" 8 "testing" 9 "time" 10 11 "github.com/sohaha/zlsgo" 12 "github.com/sohaha/zlsgo/zdi" 13 "github.com/sohaha/zlsgo/zjson" 14 "github.com/sohaha/zlsgo/zlog" 15 ) 16 17 var _ zdi.PreInvoker = (*CustomInvoker)(nil) 18 19 type CustomInvoker func(ctx *Context) (b []byte) 20 21 func (fn CustomInvoker) Invoke(i []interface{}) ([]reflect.Value, error) { 22 c := i[0].(*Context) 23 b := fn(c) 24 c.Byte(404, b) 25 return nil, nil 26 } 27 28 func TestInject(t *testing.T) { 29 tt := zlsgo.NewTest(t) 30 r := newServer() 31 32 r.Log.Discard() 33 34 w := newRequest(r, "GET", "/NoInject", "/NoInject", func(c *Context) { 35 c.String(200, "NoInject") 36 }) 37 tt.Equal(200, w.Code) 38 tt.Equal("NoInject", w.Body.String()) 39 40 w = newRequest(r, "GET", "/Inject", "/Inject", func() (int, string) { 41 return 403, "Inject" 42 }) 43 tt.Equal(403, w.Code) 44 tt.Equal("Inject", w.Body.String()) 45 46 rewriteError := "" 47 w = newRequest(r, "GET", "/InjectErr", "/InjectErr", func() (int, string, error) { 48 return 403, "test InjectErr", errors.New("test InjectErr") 49 }) 50 tt.Equal(500, w.Code) 51 tt.Equal("test InjectErr", w.Body.String()) 52 tt.Equal("", rewriteError) 53 w = newRequest(r, "GET", "/InjectErrRewrite", "/InjectErrRewrite", func() (int, string, error) { 54 return 403, "test InjectErr", errors.New("InjectErrRewrite") 55 }, RewriteErrorHandler(func(c *Context, err error) { 56 tt.Equal("InjectErrRewrite", err.Error()) 57 rewriteError = err.Error() 58 })) 59 tt.Equal(403, w.Code) 60 tt.Equal("test InjectErr", w.Body.String()) 61 tt.Equal("InjectErrRewrite", rewriteError) 62 63 w = newRequest(r, "GET", "/InjectCustom", "/InjectCustom", CustomInvoker(func(ctx *Context) (b []byte) { 64 return []byte("InjectCustom") 65 }), func() { 66 67 }) 68 tt.Equal(404, w.Code) 69 tt.Equal("InjectCustom", w.Body.String()) 70 71 w = newRequest(r, "GET", "/InjectAny", "/InjectAny", func(ctx *Context) (c uint, api ApiData, err error) { 72 return 302, ApiData{Code: 301, Msg: "InjectAny"}, nil 73 }) 74 tt.Equal(302, w.Code) 75 tt.Equal("application/json; charset=utf-8", w.Header().Get("Content-Type")) 76 tt.Equal("InjectAny", zjson.Get(w.Body.String(), "msg").String()) 77 78 w = httptest.NewRecorder() 79 req, _ := http.NewRequest("GET", "__404__", nil) 80 r.ServeHTTP(w, req) 81 t.Log(w) 82 } 83 84 func TestInjectMiddleware(t *testing.T) { 85 tt := zlsgo.NewTest(t) 86 r := newServer() 87 88 now := time.Now() 89 r.Injector().Map(now) 90 91 w := newRequest(r, "GET", "/TestInjectMiddleware", "/TestInjectMiddleware", func() (int, string) { 92 t.Log("run") 93 return 403, "Inject" 94 }, Recovery(func(c *Context, err error) { 95 zlog.Error("Recovery", err) 96 }), func(c *Context) { 97 c.Next() 98 }, func(n time.Time) error { 99 tt.Equal(now, n) 100 return errors.New("return exit") 101 }) 102 tt.Equal(500, w.Code) 103 tt.Equal("return exit", w.Body.String()) 104 105 pc := make([]int, 0) 106 w = newRequest(r, "GET", "/TestInjectMiddleware2", "/TestInjectMiddleware2", func() (int, string) { 107 t.Log("run") 108 return 403, "Inject" 109 }, func(c *Context) { 110 pc = append(pc, 1) 111 c.Next() 112 pc = append(pc, 9) 113 }, func(c *Context) string { 114 pc = append(pc, 2) 115 c.Next() 116 pc = append(pc, 8) 117 return "middleware" 118 }, func(c *Context) error { 119 pc = append(pc, 3) 120 c.Next() 121 pc = append(pc, 7) 122 var s string 123 err := c.Injector().Resolve(&s) 124 tt.NoError(err) 125 tt.Equal("test", s) 126 return nil 127 }, func() { 128 pc = append(pc, 4) 129 }, func(c *Context) { 130 c.Next() 131 pc = append(pc, 6) 132 }, func(c *Context) { 133 pc = append(pc, 5) 134 c.Injector().Map("test") 135 c.Next() 136 }) 137 tt.Equal(403, w.Code) 138 tt.Equal([]int{1, 2, 3, 4, 5, 6, 7, 8, 9}, pc) 139 tt.Equal("middleware", w.Body.String()) 140 } 141 142 func BenchmarkInjectNo(b *testing.B) { 143 r := newServer() 144 path := "/BenchmarkInjectNo" 145 r.SetMode(QuietMode) 146 r.GET(path, func(c *Context) { 147 c.String(200, path) 148 }) 149 b.ResetTimer() 150 b.ReportAllocs() 151 for i := 0; i < b.N; i++ { 152 w := httptest.NewRecorder() 153 req, _ := http.NewRequest("GET", path, nil) 154 r.ServeHTTP(w, req) 155 if w.Code != 200 || w.Body.String() != path { 156 b.Fail() 157 } 158 } 159 } 160 161 func BenchmarkInjectFast(b *testing.B) { 162 r := newServer() 163 r.SetMode(QuietMode) 164 path := "/BenchmarkInjectFast" 165 r.GET(path, func() (int, string) { 166 return 200, path 167 }) 168 b.ResetTimer() 169 b.ReportAllocs() 170 for i := 0; i < b.N; i++ { 171 w := httptest.NewRecorder() 172 req, _ := http.NewRequest("GET", path, nil) 173 r.ServeHTTP(w, req) 174 if w.Code != 200 || w.Body.String() != path { 175 b.Fail() 176 } 177 } 178 } 179 180 func BenchmarkInjectBasis(b *testing.B) { 181 r := newServer() 182 r.SetMode(QuietMode) 183 path := "/BenchmarkInjectBasis" 184 r.GET(path, func() (int, []byte) { 185 return 200, []byte(path) 186 }) 187 b.ResetTimer() 188 b.ReportAllocs() 189 for i := 0; i < b.N; i++ { 190 w := httptest.NewRecorder() 191 req, _ := http.NewRequest("GET", path, nil) 192 r.ServeHTTP(w, req) 193 if w.Code != 200 || w.Body.String() != path { 194 b.Fail() 195 } 196 } 197 }