github.com/cloudwego/kitex@v0.9.0/server/service_test.go (about)

     1  /*
     2   * Copyright 2024 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package server
    18  
    19  import (
    20  	"fmt"
    21  	"testing"
    22  
    23  	"github.com/cloudwego/kitex/internal/mocks"
    24  	"github.com/cloudwego/kitex/internal/test"
    25  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    26  )
    27  
    28  func TestAddService(t *testing.T) {
    29  	svcs := newServices()
    30  	err := svcs.addService(mocks.ServiceInfo(), mocks.MyServiceHandler(), &RegisterOptions{})
    31  	test.Assert(t, err == nil)
    32  	test.Assert(t, len(svcs.svcMap) == 1)
    33  	fmt.Println(svcs.svcSearchMap)
    34  	test.Assert(t, len(svcs.svcSearchMap) == 10)
    35  	test.Assert(t, len(svcs.conflictingMethodHasFallbackSvcMap) == 0)
    36  	test.Assert(t, svcs.fallbackSvc == nil)
    37  
    38  	err = svcs.addService(mocks.Service3Info(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true})
    39  	test.Assert(t, err == nil)
    40  	test.Assert(t, len(svcs.svcMap) == 2)
    41  	test.Assert(t, len(svcs.svcSearchMap) == 11)
    42  	test.Assert(t, len(svcs.conflictingMethodHasFallbackSvcMap) == 1)
    43  	test.Assert(t, svcs.conflictingMethodHasFallbackSvcMap["mock"])
    44  
    45  	err = svcs.addService(mocks.Service2Info(), mocks.MyServiceHandler(), &RegisterOptions{IsFallbackService: true})
    46  	test.Assert(t, err != nil)
    47  	test.Assert(t, err.Error() == "multiple fallback services cannot be registered. [MockService3] is already registered as a fallback service")
    48  }
    49  
    50  func TestCheckCombineServiceWithOtherService(t *testing.T) {
    51  	svcs := newServices()
    52  	combineSvcInfo := &serviceinfo.ServiceInfo{ServiceName: "CombineService"}
    53  	svcs.svcMap[combineSvcInfo.ServiceName] = newService(combineSvcInfo, nil)
    54  	err := svcs.checkCombineServiceWithOtherService(mocks.ServiceInfo())
    55  	test.Assert(t, err != nil)
    56  	test.Assert(t, err.Error() == "only one service can be registered when registering combine service")
    57  
    58  	svcs = newServices()
    59  	svcs.svcMap[mocks.MockServiceName] = newService(mocks.ServiceInfo(), mocks.MyServiceHandler())
    60  	err = svcs.checkCombineServiceWithOtherService(combineSvcInfo)
    61  	test.Assert(t, err != nil)
    62  	test.Assert(t, err.Error() == "only one service can be registered when registering combine service")
    63  }
    64  
    65  func TestCheckMultipleFallbackService(t *testing.T) {
    66  	svcs := newServices()
    67  	svc := newService(mocks.ServiceInfo(), mocks.MyServiceHandler())
    68  	registerOpts := &RegisterOptions{IsFallbackService: true}
    69  	err := svcs.checkMultipleFallbackService(registerOpts, svc)
    70  	test.Assert(t, err == nil)
    71  	test.Assert(t, svcs.fallbackSvc == svc)
    72  
    73  	err = svcs.checkMultipleFallbackService(registerOpts, newService(mocks.Service2Info(), nil))
    74  	test.Assert(t, err != nil)
    75  	test.Assert(t, err.Error() == "multiple fallback services cannot be registered. [MockService] is already registered as a fallback service", err)
    76  }
    77  
    78  func TestRegisterConflictingMethodHasFallbackSvcMap(t *testing.T) {
    79  	svcs := newServices()
    80  	svcFromMap := newService(mocks.ServiceInfo(), mocks.MyServiceHandler())
    81  	svcs.registerConflictingMethodHasFallbackSvcMap(svcFromMap, mocks.MockMethod)
    82  	test.Assert(t, !svcs.conflictingMethodHasFallbackSvcMap[mocks.MockMethod])
    83  
    84  	svcs = newServices()
    85  	svcs.fallbackSvc = svcFromMap
    86  	svcs.registerConflictingMethodHasFallbackSvcMap(svcFromMap, mocks.MockMethod)
    87  	test.Assert(t, svcs.conflictingMethodHasFallbackSvcMap[mocks.MockMethod])
    88  }