github.com/cloudwego/kitex@v0.9.0/pkg/generic/http_test/generic_init.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  // Package test ...
    18  package test
    19  
    20  import (
    21  	"bytes"
    22  	"context"
    23  	"encoding/base64"
    24  	"errors"
    25  	"fmt"
    26  	"math"
    27  	"net"
    28  	"reflect"
    29  	"strconv"
    30  	"strings"
    31  	"time"
    32  
    33  	"github.com/tidwall/gjson"
    34  
    35  	"github.com/cloudwego/kitex/client"
    36  	"github.com/cloudwego/kitex/client/genericclient"
    37  	kt "github.com/cloudwego/kitex/internal/mocks/thrift"
    38  	"github.com/cloudwego/kitex/internal/test"
    39  	"github.com/cloudwego/kitex/pkg/generic"
    40  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    41  	"github.com/cloudwego/kitex/server"
    42  	"github.com/cloudwego/kitex/server/genericserver"
    43  	"github.com/cloudwego/kitex/transport"
    44  )
    45  
    46  type Simple struct {
    47  	ByteField   int8    `thrift:"ByteField,1" json:"ByteField"`
    48  	I64Field    int64   `thrift:"I64Field,2" json:"I64Field"`
    49  	DoubleField float64 `thrift:"DoubleField,3" json:"DoubleField"`
    50  	I32Field    int32   `thrift:"I32Field,4" json:"I32Field"`
    51  	StringField string  `thrift:"StringField,5" json:"StringField"`
    52  	BinaryField []byte  `thrift:"BinaryField,6" json:"BinaryField"`
    53  }
    54  
    55  type Nesting struct {
    56  	String_         string             `thrift:"String,1" json:"String"`
    57  	ListSimple      []*Simple          `thrift:"ListSimple,2" json:"ListSimple"`
    58  	Double          float64            `thrift:"Double,3" json:"Double"`
    59  	I32             int32              `thrift:"I32,4" json:"I32"`
    60  	ListI32         []int32            `thrift:"ListI32,5" json:"ListI32"`
    61  	I64             int64              `thrift:"I64,6" json:"I64"`
    62  	MapStringString map[string]string  `thrift:"MapStringString,7" json:"MapStringString"`
    63  	SimpleStruct    *Simple            `thrift:"SimpleStruct,8" json:"SimpleStruct"`
    64  	MapI32I64       map[int32]int64    `thrift:"MapI32I64,9" json:"MapI32I64"`
    65  	ListString      []string           `thrift:"ListString,10" json:"ListString"`
    66  	Binary          []byte             `thrift:"Binary,11" json:"Binary"`
    67  	MapI64String    map[int64]string   `thrift:"MapI64String,12" json:"MapI64String"`
    68  	ListI64         []int64            `thrift:"ListI64,13" json:"ListI64"`
    69  	Byte            int8               `thrift:"Byte,14" json:"Byte"`
    70  	MapStringSimple map[string]*Simple `thrift:"MapStringSimple,15" json:"MapStringSimple"`
    71  }
    72  
    73  func getString() string {
    74  	return strings.Repeat("你好,\b\n\r\t世界", 2)
    75  }
    76  
    77  func getBytes() []byte {
    78  	return bytes.Repeat([]byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, 2)
    79  }
    80  
    81  func getSimpleValue() *Simple {
    82  	return &Simple{
    83  		ByteField:   math.MaxInt8,
    84  		I64Field:    math.MaxInt64,
    85  		DoubleField: math.MaxFloat64,
    86  		I32Field:    math.MaxInt32,
    87  		StringField: getString(),
    88  		BinaryField: getBytes(),
    89  	}
    90  }
    91  
    92  func getNestingValue() *Nesting {
    93  	ret := &Nesting{
    94  		String_:         getString(),
    95  		ListSimple:      []*Simple{},
    96  		Double:          math.MaxFloat64,
    97  		I32:             math.MaxInt32,
    98  		ListI32:         []int32{},
    99  		I64:             math.MaxInt64,
   100  		MapStringString: map[string]string{},
   101  		SimpleStruct:    getSimpleValue(),
   102  		MapI32I64:       map[int32]int64{},
   103  		ListString:      []string{},
   104  		Binary:          getBytes(),
   105  		MapI64String:    map[int64]string{},
   106  		ListI64:         []int64{},
   107  		Byte:            math.MaxInt8,
   108  		MapStringSimple: map[string]*Simple{},
   109  	}
   110  
   111  	for i := 0; i < 16; i++ {
   112  		ret.ListSimple = append(ret.ListSimple, getSimpleValue())
   113  		ret.ListI32 = append(ret.ListI32, math.MinInt32)
   114  		ret.ListI64 = append(ret.ListI64, math.MinInt64)
   115  		ret.ListString = append(ret.ListString, getString())
   116  	}
   117  
   118  	for i := 0; i < 16; i++ {
   119  		ret.MapStringString[strconv.Itoa(i)] = getString()
   120  		ret.MapI32I64[int32(i)] = math.MinInt64
   121  		ret.MapI64String[int64(i)] = getString()
   122  		ret.MapStringSimple[strconv.Itoa(i)] = getSimpleValue()
   123  	}
   124  
   125  	return ret
   126  }
   127  
   128  func newGenericClient(tp transport.Protocol, destService string, g generic.Generic, targetIPPort string) genericclient.Client {
   129  	var opts []client.Option
   130  	opts = append(opts, client.WithHostPorts(targetIPPort), client.WithTransportProtocol(tp))
   131  	genericCli, _ := genericclient.NewClient(destService, g, opts...)
   132  	return genericCli
   133  }
   134  
   135  func newGenericServer(g generic.Generic, addr net.Addr, handler generic.Service) server.Server {
   136  	var opts []server.Option
   137  	opts = append(opts, server.WithServiceAddr(addr), server.WithExitWaitTime(time.Microsecond*10))
   138  	svr := genericserver.NewServer(handler, g, opts...)
   139  	go func() {
   140  		err := svr.Run()
   141  		if err != nil {
   142  			panic(err)
   143  		}
   144  	}()
   145  	test.WaitServerStart(addr.String())
   146  	return svr
   147  }
   148  
   149  // GenericServiceReadRequiredFiledImpl ...
   150  type GenericServiceBinaryEchoImpl struct{}
   151  
   152  const mockMyMsg = "my msg"
   153  
   154  // GenericCall ...
   155  func (g *GenericServiceBinaryEchoImpl) GenericCall(ctx context.Context, method string, request interface{}) (response interface{}, err error) {
   156  	req := request.(map[string]interface{})
   157  	gotBase64 := req["got_base64"].(bool)
   158  	fmt.Printf("Recv: (%T)%s\n", req["msg"], req["msg"])
   159  	if !gotBase64 && req["msg"].(string) != mockMyMsg {
   160  		return nil, errors.New("call failed, msg type mismatch")
   161  	}
   162  	if gotBase64 && req["msg"].(string) != base64.StdEncoding.EncodeToString([]byte(mockMyMsg)) {
   163  		return nil, errors.New("call failed, incorrect base64 data")
   164  	}
   165  	num := req["num"].(string)
   166  	if num != "0" {
   167  		return nil, errors.New("call failed, incorrect num")
   168  	}
   169  	return req, nil
   170  }
   171  
   172  // GenericServiceBenchmarkImpl ...
   173  type GenericServiceBenchmarkImpl struct{}
   174  
   175  // GenericCall ...
   176  func (g *GenericServiceBenchmarkImpl) GenericCall(ctx context.Context, method string, request interface{}) (response interface{}, err error) {
   177  	return request, nil
   178  }
   179  
   180  // GenericServiceAnnotationImpl ...
   181  type GenericServiceAnnotationImpl struct{}
   182  
   183  // GenericCall ...
   184  func (g *GenericServiceAnnotationImpl) GenericCall(ctx context.Context, method string, request interface{}) (response interface{}, err error) {
   185  	// check request
   186  	req := request.(map[string]interface{})
   187  	if req["v_int64"] != int64(1) {
   188  		return nil, assertErr("v_int64", int64(1), req["v_int64"])
   189  	}
   190  	if req["text"] != "text" {
   191  		return nil, assertErr("text", "text", req["text"])
   192  	}
   193  	if req["token"] != int32(1) {
   194  		return nil, assertErr("token", int32(1), req["token"])
   195  	}
   196  	if req["api_version"] != int32(1) {
   197  		return nil, assertErr("api_version", int32(1), req["api_version"])
   198  	}
   199  	if req["uid"] != int64(1) {
   200  		return nil, assertErr("uid", int64(1), req["uid"])
   201  	}
   202  	if req["cookie"] != "cookie_val" {
   203  		return nil, assertErr("cookie", "cookie_val", req["cookie"])
   204  	}
   205  	if req["req_items_map"].(map[interface{}]interface{})[int64(1)].(map[string]interface{})["MyID"] != "1" {
   206  		return nil, assertErr("req_items_map/1/MyID", "1", req["req_items_map"].(map[interface{}]interface{})[int64(1)].(map[string]interface{})["MyID"])
   207  	}
   208  	if req["req_items_map"].(map[interface{}]interface{})[int64(1)].(map[string]interface{})["text"] != "text" {
   209  		return nil, assertErr("req_items_map/1/text", "text", req["req_items_map"].(map[interface{}]interface{})[int64(1)].(map[string]interface{})["text"])
   210  	}
   211  	if req["some"].(map[string]interface{})["MyID"] != "1" {
   212  		return nil, assertErr("some/MyID", "1", req["some"].(map[string]interface{})["MyID"])
   213  	}
   214  	if req["some"].(map[string]interface{})["text"] != "text" {
   215  		return nil, assertErr("some/text", "text", req["some"].(map[string]interface{})["text"])
   216  	}
   217  	if !reflect.DeepEqual(req["req_items"], []interface{}{"item1", "item2", "item3"}) {
   218  		return nil, assertErr("req_items", []interface{}{"item1", "item2", "item3"}, req["req_items"])
   219  	}
   220  	if !reflect.DeepEqual(req["cids"], []interface{}{int64(1), int64(2), int64(3)}) {
   221  		return nil, assertErr("cids", []interface{}{int64(1), int64(2), int64(3)}, req["cids"])
   222  	}
   223  	if !reflect.DeepEqual(req["vids"], []interface{}{"1", "2", "3"}) {
   224  		return nil, assertErr("vids", []interface{}{"1", "2", "3"}, req["vids"])
   225  	}
   226  
   227  	resp := map[string]interface{}{
   228  		"rsp_items": map[interface{}]interface{}{
   229  			int64(1): map[string]interface{}{
   230  				"item_id": int64(1),
   231  				"text":    "1",
   232  			},
   233  		},
   234  		"v_enum": int32(1),
   235  		"rsp_item_list": []interface{}{
   236  			map[string]interface{}{
   237  				"item_id": int64(1),
   238  				"text":    "1",
   239  			},
   240  		},
   241  		"http_code": int32(1),
   242  		"b":         true,
   243  		"eight":     int8(8),
   244  		"sixteen":   int16(16),
   245  		"thirtytwo": int32(32),
   246  		"sixtyfour": int64(64),
   247  		"d":         float64(123.45),
   248  		"T":         "1",
   249  		"item_count": []interface{}{
   250  			int64(1), int64(2), int64(3),
   251  		},
   252  		"header_map": map[interface{}]interface{}{
   253  			"map1": int64(1),
   254  			"map2": int64(2),
   255  		},
   256  		"header_struct": map[string]interface{}{
   257  			"item_id": int64(1),
   258  			"text":    "1",
   259  		},
   260  		"string_set": []interface{}{"a", "b", "c"},
   261  	}
   262  	return resp, nil
   263  }
   264  
   265  func assertErr(field string, expected, actual interface{}) error {
   266  	return fmt.Errorf("field name: %s, expected: %#v, but get: %#v", field, expected, actual)
   267  }
   268  
   269  var (
   270  	mockReq  = `{"Msg":"hello","strMap":{"mk1":"mv1","mk2":"mv2"},"strList":["lv1","lv2"]} `
   271  	mockResp = "this is response"
   272  )
   273  
   274  // normal server
   275  func newMockServer(handler kt.Mock, addr net.Addr, opts ...server.Option) server.Server {
   276  	var options []server.Option
   277  	opts = append(opts, server.WithServiceAddr(addr), server.WithExitWaitTime(time.Millisecond*10))
   278  	options = append(options, opts...)
   279  
   280  	svr := server.NewServer(options...)
   281  	if err := svr.RegisterService(serviceInfo(), handler); err != nil {
   282  		panic(err)
   283  	}
   284  	go func() {
   285  		err := svr.Run()
   286  		if err != nil {
   287  			panic(err)
   288  		}
   289  	}()
   290  	test.WaitServerStart(addr.String())
   291  	return svr
   292  }
   293  
   294  func serviceInfo() *serviceinfo.ServiceInfo {
   295  	destService := "Mock"
   296  	handlerType := (*kt.Mock)(nil)
   297  	methods := map[string]serviceinfo.MethodInfo{
   298  		"Test":          serviceinfo.NewMethodInfo(testHandler, newMockTestArgs, newMockTestResult, false),
   299  		"ExceptionTest": serviceinfo.NewMethodInfo(exceptionHandler, newMockExceptionTestArgs, newMockExceptionTestResult, false),
   300  	}
   301  	svcInfo := &serviceinfo.ServiceInfo{
   302  		ServiceName: destService,
   303  		HandlerType: handlerType,
   304  		Methods:     methods,
   305  		Extra:       make(map[string]interface{}),
   306  	}
   307  	return svcInfo
   308  }
   309  
   310  func newMockTestArgs() interface{} {
   311  	return kt.NewMockTestArgs()
   312  }
   313  
   314  func newMockTestResult() interface{} {
   315  	return kt.NewMockTestResult()
   316  }
   317  
   318  func testHandler(ctx context.Context, handler, arg, result interface{}) error {
   319  	realArg := arg.(*kt.MockTestArgs)
   320  	realResult := result.(*kt.MockTestResult)
   321  	success, err := handler.(kt.Mock).Test(ctx, realArg.Req)
   322  	if err != nil {
   323  		return err
   324  	}
   325  	realResult.Success = &success
   326  	return nil
   327  }
   328  
   329  func newMockExceptionTestArgs() interface{} {
   330  	return kt.NewMockExceptionTestArgs()
   331  }
   332  
   333  func newMockExceptionTestResult() interface{} {
   334  	return &kt.MockExceptionTestResult{}
   335  }
   336  
   337  func exceptionHandler(ctx context.Context, handler, args, result interface{}) error {
   338  	a := args.(*kt.MockExceptionTestArgs)
   339  	r := result.(*kt.MockExceptionTestResult)
   340  	reply, err := handler.(kt.Mock).ExceptionTest(ctx, a.Req)
   341  	if err != nil {
   342  		switch v := err.(type) {
   343  		case *kt.Exception:
   344  			r.Err = v
   345  		default:
   346  			return err
   347  		}
   348  	} else {
   349  		r.Success = &reply
   350  	}
   351  	return nil
   352  }
   353  
   354  type mockImpl struct{}
   355  
   356  // Test ...
   357  func (m *mockImpl) Test(ctx context.Context, req *kt.MockReq) (r string, err error) {
   358  	msg := gjson.Get(mockReq, "Msg")
   359  	if req.Msg != msg.String() {
   360  		return "", fmt.Errorf("msg is not %s", mockReq)
   361  	}
   362  	strMap := gjson.Get(mockReq, "strMap")
   363  	if len(strMap.Map()) == 0 {
   364  		return "", fmt.Errorf("strmsg is not map[interface{}]interface{}")
   365  	}
   366  	for k, v := range strMap.Map() {
   367  		if req.StrMap[k] != v.String() {
   368  			return "", fmt.Errorf("strMsg is not %s", req.StrMap)
   369  		}
   370  	}
   371  
   372  	strList := gjson.Get(mockReq, "strList")
   373  	array := strList.Array()
   374  	if len(array) == 0 {
   375  		return "", fmt.Errorf("strlist is not %v", strList)
   376  	}
   377  	for idx := range array {
   378  		if array[idx].Value() != req.StrList[idx] {
   379  			return "", fmt.Errorf("strlist is not %s", mockReq)
   380  		}
   381  	}
   382  	return mockResp, nil
   383  }
   384  
   385  // ExceptionTest ...
   386  func (m *mockImpl) ExceptionTest(ctx context.Context, req *kt.MockReq) (r string, err error) {
   387  	msg := gjson.Get(mockReq, "Msg")
   388  	if req.Msg != msg.String() {
   389  		return "", fmt.Errorf("msg is not %s", mockReq)
   390  	}
   391  	strMap := gjson.Get(mockReq, "strMap")
   392  	if len(strMap.Map()) == 0 {
   393  		return "", fmt.Errorf("strmsg is not map[interface{}]interface{}")
   394  	}
   395  	for k, v := range strMap.Map() {
   396  		if req.StrMap[k] != v.String() {
   397  			return "", fmt.Errorf("strMsg is not %s", req.StrMap)
   398  		}
   399  	}
   400  
   401  	strList := gjson.Get(mockReq, "strList")
   402  	array := strList.Array()
   403  	if len(array) == 0 {
   404  		return "", fmt.Errorf("strlist is not %v", strList)
   405  	}
   406  	for idx := range array {
   407  		if array[idx].Value() != req.StrList[idx] {
   408  			return "", fmt.Errorf("strlist is not %s", mockReq)
   409  		}
   410  	}
   411  	return "", &kt.Exception{Code: 400, Msg: "this is an exception"}
   412  }