github.com/cloudwego/kitex@v0.9.0/server/middlewares_test.go (about)

     1  /*
     2   * Copyright 2024 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package server
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"net"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/cloudwego/kitex/internal/test"
    27  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    28  )
    29  
    30  var _ context.Context = (*mockCtx)(nil)
    31  
    32  type mockCtx struct {
    33  	err    error
    34  	ddl    time.Time
    35  	hasDDL bool
    36  	done   chan struct{}
    37  	data   map[interface{}]interface{}
    38  }
    39  
    40  func (m *mockCtx) Deadline() (deadline time.Time, ok bool) {
    41  	return m.ddl, m.hasDDL
    42  }
    43  
    44  func (m *mockCtx) Done() <-chan struct{} {
    45  	return m.done
    46  }
    47  
    48  func (m *mockCtx) Err() error {
    49  	return m.err
    50  }
    51  
    52  func (m *mockCtx) Value(key interface{}) interface{} {
    53  	return m.data[key]
    54  }
    55  
    56  func Test_serverTimeoutMW(t *testing.T) {
    57  	addr, _ := net.ResolveTCPAddr("tcp", "127.0.0.1:8080")
    58  	from := rpcinfo.NewEndpointInfo("from_service", "from_method", addr, nil)
    59  	to := rpcinfo.NewEndpointInfo("to_service", "to_method", nil, nil)
    60  	newCtxWithRPCInfo := func(timeout time.Duration) context.Context {
    61  		cfg := rpcinfo.NewRPCConfig()
    62  		_ = rpcinfo.AsMutableRPCConfig(cfg).SetRPCTimeout(timeout)
    63  		ri := rpcinfo.NewRPCInfo(from, to, nil, cfg, nil)
    64  		return rpcinfo.NewCtxWithRPCInfo(context.Background(), ri)
    65  	}
    66  	timeoutMW := serverTimeoutMW(context.Background())
    67  
    68  	t.Run("no_timeout(fastPath)", func(t *testing.T) {
    69  		// prepare
    70  		ctx := newCtxWithRPCInfo(0)
    71  
    72  		// test
    73  		err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
    74  			ddl, ok := ctx.Deadline()
    75  			test.Assert(t, !ok)
    76  			test.Assert(t, ddl.IsZero())
    77  			return nil
    78  		})(ctx, nil, nil)
    79  
    80  		// assert
    81  		test.Assert(t, err == nil, err)
    82  	})
    83  
    84  	t.Run("finish_before_timeout_without_error", func(t *testing.T) {
    85  		// prepare
    86  		ctx := newCtxWithRPCInfo(time.Millisecond * 50)
    87  		waitFinish := make(chan struct{})
    88  
    89  		// test
    90  		err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
    91  			go func() {
    92  				timer := time.NewTimer(time.Millisecond * 20)
    93  				select {
    94  				case <-ctx.Done():
    95  					t.Errorf("ctx done, error: %v", ctx.Err())
    96  				case <-timer.C:
    97  					t.Logf("(expected) ctx not done")
    98  				}
    99  				waitFinish <- struct{}{}
   100  			}()
   101  			return nil
   102  		})(ctx, nil, nil)
   103  
   104  		// assert
   105  		test.Assert(t, err == nil, err)
   106  		<-waitFinish
   107  	})
   108  
   109  	t.Run("finish_before_timeout_with_error", func(t *testing.T) {
   110  		// prepare
   111  		ctx := newCtxWithRPCInfo(time.Millisecond * 50)
   112  		waitFinish := make(chan struct{})
   113  
   114  		// test
   115  		err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
   116  			go func() {
   117  				timer := time.NewTimer(time.Millisecond * 20)
   118  				select {
   119  				case <-ctx.Done():
   120  					if errors.Is(ctx.Err(), context.Canceled) {
   121  						t.Logf("(expected) cancel called")
   122  					} else {
   123  						t.Errorf("cancel not called, error: %v", ctx.Err())
   124  					}
   125  				case <-timer.C:
   126  					t.Error("ctx not done")
   127  				}
   128  				waitFinish <- struct{}{}
   129  			}()
   130  			return errors.New("error")
   131  		})(ctx, nil, nil)
   132  
   133  		// assert
   134  		test.Assert(t, err.Error() == "error", err)
   135  		<-waitFinish
   136  	})
   137  
   138  	t.Run("finish_after_timeout_without_error", func(t *testing.T) {
   139  		// prepare
   140  		ctx := newCtxWithRPCInfo(time.Millisecond * 20)
   141  		waitFinish := make(chan struct{})
   142  
   143  		// test
   144  		err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
   145  			go func() {
   146  				timer := time.NewTimer(time.Millisecond * 60)
   147  				select {
   148  				case <-ctx.Done():
   149  					if errors.Is(ctx.Err(), context.DeadlineExceeded) {
   150  						t.Logf("(expected) deadline exceeded")
   151  					} else {
   152  						t.Error("deadline not exceeded, error: ", ctx.Err())
   153  					}
   154  				case <-timer.C:
   155  					t.Error("ctx not done")
   156  				}
   157  				waitFinish <- struct{}{}
   158  			}()
   159  			time.Sleep(time.Millisecond * 40)
   160  			return nil
   161  		})(ctx, nil, nil)
   162  
   163  		// assert
   164  		test.Assert(t, err == nil, err)
   165  		<-waitFinish
   166  	})
   167  
   168  	t.Run("finish_after_timeout_with_error", func(t *testing.T) {
   169  		// prepare
   170  		ctx := newCtxWithRPCInfo(time.Millisecond * 20)
   171  		waitFinish := make(chan struct{})
   172  
   173  		// test
   174  		err := timeoutMW(func(ctx context.Context, req, resp interface{}) (err error) {
   175  			go func() {
   176  				timer := time.NewTimer(time.Millisecond * 60)
   177  				select {
   178  				case <-ctx.Done():
   179  					if errors.Is(ctx.Err(), context.DeadlineExceeded) {
   180  						t.Logf("(expected) deadline exceeded")
   181  					} else {
   182  						t.Error("deadline not exceeded, error: ", ctx.Err())
   183  					}
   184  				case <-timer.C:
   185  					t.Error("ctx not done")
   186  				}
   187  				waitFinish <- struct{}{}
   188  			}()
   189  			time.Sleep(time.Millisecond * 40)
   190  			return errors.New("error")
   191  		})(ctx, nil, nil)
   192  
   193  		// assert
   194  		test.Assert(t, err.Error() == "error", err)
   195  		<-waitFinish
   196  	})
   197  }