github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/server/server_test.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"github.com/nyan233/littlerpc/core/common/errorhandler"
     9  	"github.com/nyan233/littlerpc/core/common/metadata"
    10  	msgparser2 "github.com/nyan233/littlerpc/core/common/msgparser"
    11  	"github.com/nyan233/littlerpc/core/common/msgwriter"
    12  	transport2 "github.com/nyan233/littlerpc/core/common/transport"
    13  	"github.com/nyan233/littlerpc/core/container"
    14  	message2 "github.com/nyan233/littlerpc/core/protocol/message"
    15  	"github.com/nyan233/littlerpc/core/utils/random"
    16  	"github.com/nyan233/littlerpc/internal/pool"
    17  	"math"
    18  	"reflect"
    19  	"testing"
    20  	"time"
    21  )
    22  
    23  type testObject struct {
    24  	userName string
    25  	userId   int
    26  }
    27  
    28  func (t *testObject) SetUserName(ctx context.Context, userName string) error {
    29  	t.userName = userName
    30  	return nil
    31  }
    32  
    33  func (t *testObject) SetUserId(ctx context.Context, userId int) error {
    34  	t.userId = userId
    35  	return nil
    36  }
    37  
    38  func (t *testObject) GetUserId(ctx context.Context) (int, error) {
    39  	return t.userId, nil
    40  }
    41  
    42  func (t *testObject) GetUserName(ctx context.Context) (string, error) {
    43  	return t.userName, nil
    44  }
    45  
    46  func newTestServer(nilConn transport2.ConnAdapter) (*Server, error) {
    47  	server := &Server{
    48  		services: container.NewRCUMap[string, *metadata.Process](128),
    49  		sources:  container.NewRCUMap[string, *metadata.Source](128),
    50  	}
    51  	sc := new(Config)
    52  	WithDefaultServer()(sc)
    53  	server.config.Store(sc)
    54  	err := server.RegisterClass(ReflectionSource, new(LittleRpcReflection), nil)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	nilConn.SetSource(&connSourceDesc{
    59  		Parser: msgparser2.NewLRPCTrait(msgparser2.NewDefaultSimpleAllocTor(), 4096),
    60  		Writer: msgwriter.NewLRPCTrait(),
    61  	})
    62  	server.eHandle = sc.ErrHandler
    63  	server.taskPool = pool.NewTaskPool[string](sc.PoolBufferSize, sc.PoolMinSize, sc.PoolMaxSize, nil)
    64  	server.logger = &testLogger{logger: sc.Logger}
    65  	server.pManager = &pluginManager{plugins: sc.Plugins}
    66  	server.config.Store(sc)
    67  	return server, nil
    68  }
    69  
    70  func TestOnMessage(t *testing.T) {
    71  	nc := &transport2.NilConn{}
    72  	server, err := newTestServer(nc)
    73  	if err != nil {
    74  		t.Fatal(err)
    75  	}
    76  	obj := reflect.ValueOf(new(testObject))
    77  	err = server.RegisterClass("littlerpc/test/testObject", new(testObject), nil)
    78  	if err != nil {
    79  		t.Fatal(err)
    80  	}
    81  	// open debug
    82  	server.config.Load().Debug = true
    83  	msg := message2.New()
    84  	for i := 0; i < obj.NumMethod(); i++ {
    85  		msg.SetMsgType(message2.Call)
    86  		method := obj.Method(i)
    87  		msg.SetServiceName(fmt.Sprintf("littlerpc/test/testObject.%s", obj.Type().Method(i).Name))
    88  		for j := 1; j < method.Type().NumIn(); j++ {
    89  			payloads, err := baseTypeGenToJson(method.Type().In(j))
    90  			if err != nil {
    91  				t.Fatal(err)
    92  			}
    93  			msg.AppendPayloads(payloads)
    94  		}
    95  		var bytes container.Slice[byte]
    96  		err = message2.Marshal(msg, &bytes)
    97  		if err != nil {
    98  			t.Fatal(err)
    99  		}
   100  		func() {
   101  			defer func() {
   102  				if err := recover(); err != nil {
   103  					t.Fatal(err)
   104  				}
   105  			}()
   106  			server.onMessage(nc, bytes)
   107  			time.Sleep(time.Millisecond * 100)
   108  			msg.Reset()
   109  		}()
   110  	}
   111  	t.Run("TestOnMessageRecover", func(t *testing.T) {
   112  		func() {
   113  			defer server.eventLoopTopRecover(nc, nc.Source().(*connSourceDesc))
   114  			a := make([]int, 10)
   115  			a[100] = 1
   116  		}()
   117  		func() {
   118  			defer server.eventLoopTopRecover(nc, nc.Source().(*connSourceDesc))
   119  			panic("Hello world")
   120  		}()
   121  		func() {
   122  			defer server.eventLoopTopRecover(nc, nc.Source().(*connSourceDesc))
   123  			panic(errorhandler.ContextNotFound)
   124  		}()
   125  	})
   126  }
   127  
   128  func baseTypeGenToJson(typ reflect.Type) ([]byte, error) {
   129  	switch typ.Kind() {
   130  	case reflect.String:
   131  		return json.Marshal(random.GenStringOnAscii(300))
   132  	case reflect.Int64, reflect.Int, reflect.Int32:
   133  		return json.Marshal(random.FastRandN(math.MaxUint32 / 2))
   134  	default:
   135  		return nil, errors.New("no match for base type")
   136  	}
   137  }