github.com/cloudwego/kitex@v0.9.0/server/middlewares_test.go (about) 1 /* 2 * Copyright 2024 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package server 18 19 import ( 20 "context" 21 "errors" 22 "net" 23 "testing" 24 "time" 25 26 "github.com/cloudwego/kitex/internal/test" 27 "github.com/cloudwego/kitex/pkg/rpcinfo" 28 ) 29 30 var _ context.Context = (*mockCtx)(nil) 31 32 type mockCtx struct { 33 err error 34 ddl time.Time 35 hasDDL bool 36 done chan struct{} 37 data map[interface{}]interface{} 38 } 39 40 func (m *mockCtx) Deadline() (deadline time.Time, ok bool) { 41 return m.ddl, m.hasDDL 42 } 43 44 func (m *mockCtx) Done() <-chan struct{} { 45 return m.done 46 } 47 48 func (m *mockCtx) Err() error { 49 return m.err 50 } 51 52 func (m *mockCtx) Value(key interface{}) interface{} { 53 return m.data[key] 54 } 55 56 func Test_serverTimeoutMW(t *testing.T) { 57 addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080") 58 from := rpcinfo.NewEndpointInfo("from_service", "from_method", addr, nil) 59 to := rpcinfo.NewEndpointInfo("to_service", "to_method", nil, nil) 60 newCtxWithRPCInfo := func(timeout time.Duration) context.Context { 61 cfg := rpcinfo.NewRPCConfig() 62 _ = rpcinfo.AsMutableRPCConfig(cfg).SetRPCTimeout(timeout) 63 ri := rpcinfo.NewRPCInfo(from, to, nil, cfg, nil) 64 return rpcinfo.NewCtxWithRPCInfo(context.Background(), ri) 65 } 66 timeoutMW := serverTimeoutMW(context.Background()) 67 68 t.Run("no_timeout(fastPath)", func(t *testing.T) { 69 // prepare 70 ctx := newCtxWithRPCInfo(0) 71 72 // test 73 err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { 74 ddl, ok := ctx.Deadline() 75 test.Assert(t, !ok) 76 test.Assert(t, ddl.IsZero()) 77 return nil 78 })(ctx, nil, nil) 79 80 // assert 81 test.Assert(t, err == nil, err) 82 }) 83 84 t.Run("finish_before_timeout_without_error", func(t *testing.T) { 85 // prepare 86 ctx := newCtxWithRPCInfo(time.Millisecond * 50) 87 waitFinish := make(chan struct{}) 88 89 // test 90 err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { 91 go func() { 92 timer := time.NewTimer(time.Millisecond * 20) 93 select { 94 case <-ctx.Done(): 95 t.Errorf("ctx done, error: %v", ctx.Err()) 96 case <-timer.C: 97 t.Logf("(expected) ctx not done") 98 } 99 waitFinish <- struct{}{} 100 }() 101 return nil 102 })(ctx, nil, nil) 103 104 // assert 105 test.Assert(t, err == nil, err) 106 <-waitFinish 107 }) 108 109 t.Run("finish_before_timeout_with_error", func(t *testing.T) { 110 // prepare 111 ctx := newCtxWithRPCInfo(time.Millisecond * 50) 112 waitFinish := make(chan struct{}) 113 114 // test 115 err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { 116 go func() { 117 timer := time.NewTimer(time.Millisecond * 20) 118 select { 119 case <-ctx.Done(): 120 if errors.Is(ctx.Err(), context.Canceled) { 121 t.Logf("(expected) cancel called") 122 } else { 123 t.Errorf("cancel not called, error: %v", ctx.Err()) 124 } 125 case <-timer.C: 126 t.Error("ctx not done") 127 } 128 waitFinish <- struct{}{} 129 }() 130 return errors.New("error") 131 })(ctx, nil, nil) 132 133 // assert 134 test.Assert(t, err.Error() == "error", err) 135 <-waitFinish 136 }) 137 138 t.Run("finish_after_timeout_without_error", func(t *testing.T) { 139 // prepare 140 ctx := newCtxWithRPCInfo(time.Millisecond * 20) 141 waitFinish := make(chan struct{}) 142 143 // test 144 err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { 145 go func() { 146 timer := time.NewTimer(time.Millisecond * 60) 147 select { 148 case <-ctx.Done(): 149 if errors.Is(ctx.Err(), context.DeadlineExceeded) { 150 t.Logf("(expected) deadline exceeded") 151 } else { 152 t.Error("deadline not exceeded, error: ", ctx.Err()) 153 } 154 case <-timer.C: 155 t.Error("ctx not done") 156 } 157 waitFinish <- struct{}{} 158 }() 159 time.Sleep(time.Millisecond * 40) 160 return nil 161 })(ctx, nil, nil) 162 163 // assert 164 test.Assert(t, err == nil, err) 165 <-waitFinish 166 }) 167 168 t.Run("finish_after_timeout_with_error", func(t *testing.T) { 169 // prepare 170 ctx := newCtxWithRPCInfo(time.Millisecond * 20) 171 waitFinish := make(chan struct{}) 172 173 // test 174 err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) { 175 go func() { 176 timer := time.NewTimer(time.Millisecond * 60) 177 select { 178 case <-ctx.Done(): 179 if errors.Is(ctx.Err(), context.DeadlineExceeded) { 180 t.Logf("(expected) deadline exceeded") 181 } else { 182 t.Error("deadline not exceeded, error: ", ctx.Err()) 183 } 184 case <-timer.C: 185 t.Error("ctx not done") 186 } 187 waitFinish <- struct{}{} 188 }() 189 time.Sleep(time.Millisecond * 40) 190 return errors.New("error") 191 })(ctx, nil, nil) 192 193 // assert 194 test.Assert(t, err.Error() == "error", err) 195 <-waitFinish 196 }) 197 }