github.com/cloudwego/kitex@v0.9.0/pkg/generic/thrift/parse_test.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package thrift
    18  
    19  import (
    20  	"reflect"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/cloudwego/thriftgo/parser"
    25  
    26  	"github.com/cloudwego/kitex/internal/test"
    27  	"github.com/cloudwego/kitex/pkg/generic/descriptor"
    28  )
    29  
// httpIDL is the Thrift IDL fixture parsed by TestParseHttpIDL. It declares
// BizService with three methods (BizMethod1..3) that take a BizRequest and
// return a BizResponse; struct fields and methods carry api.* annotations
// (query, body, header, path, http_code, get/post, baseurl, serializer, ...).
// BizRequest has exactly one required field, "required_filed"; the spelling
// is part of the fixture and is looked up by that name in the test.
var httpIDL = `
namespace go http

struct Item{
    1: optional i64 id(go.tag = "json:\"id\"")
    2: optional string text
}

struct JsonDict{
}

struct ReqItem{
    1: optional i64 id(go.tag = "json:\"id\"")
    2: optional string text
}


struct BizRequest {
    1: optional i64 v_int64(api.query = 'v_int64', api.vd = "$>0&&$<200")
    2: optional string text(api.body = 'text')
    3: optional i32 token(api.header = 'token')
    4: optional JsonDict json_header(api.header = 'json_header')
    5: optional Item some(api.body = 'some')
    6: optional list<ReqItem> req_items(api.query = 'req_items')
    7: optional i32 api_version(api.path = 'action')
    8: optional i64 uid(api.path = 'biz')
    9: optional list<i64> cids(api.query = 'cids')
    10: optional list<string> vids(api.query = 'vids')
    
    11: required string required_filed
}

struct RspItem{
    1: optional i64 item_id 
    2: optional string text
}
struct BizResponse {
    1: optional string T                             (api.header= 'T') 
    2: optional map<i64, RspItem> rsp_items           (api.body='rsp_items')
    3: optional i32 v_enum                       (api.none = '')
    4: optional list<RspItem> rsp_item_list            (api.body = 'rsp_item_list')
    5: optional i32 http_code                         (api.http_code = '') 
    6: optional list<i64> item_count (api.header = 'item_count')
}

exception Exception{
    1: i32 code
    2: string msg
}

service BizService{
BizResponse BizMethod1(1: BizRequest req) throws (1: Exception err) (api.get = '/life/client/:action/:biz', api.baseurl = 'ib.snssdk.com', api.param = 'true') 
BizResponse BizMethod2(1: BizRequest req)(api.post = '/life/client/:action/:biz', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'form')
BizResponse BizMethod3(1: BizRequest req)(api.post = '/life/client/:action/:biz/more', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'json')
}`
    85  
    86  func TestParseHttpIDL(t *testing.T) {
    87  	re, err := parser.ParseString("http.thrift", httpIDL)
    88  	test.Assert(t, err == nil, err)
    89  	svc, err := Parse(re, LastServiceOnly)
    90  	test.Assert(t, err == nil)
    91  	test.Assert(t, len(svc.Functions) == 3)
    92  	bizMethod1 := svc.Functions["BizMethod1"]
    93  	test.Assert(t, bizMethod1.Request.Type == descriptor.STRUCT)
    94  	test.Assert(t, len(bizMethod1.Request.Struct.FieldsByName["req"].Type.Struct.RequiredFields) == 1)
    95  	test.Assert(t, bizMethod1.Request.Struct.FieldsByName["req"].Type.Struct.FieldsByName["required_filed"].Required)
    96  	test.Assert(t, len(bizMethod1.Request.Struct.FieldsByID) == 1)
    97  	test.Assert(t, bizMethod1.Response.Type == descriptor.STRUCT)
    98  	test.Assert(t, len(bizMethod1.Response.Struct.FieldsByID) == 2)
    99  }
   100  
// httpConflictPathIDL is the Thrift IDL fixture used by TestPanicRecover.
// Its three methods are all api.post routes under the same
// '/life/client/:action/:biz/' prefix and differ only in the name of the
// trailing wildcard segment (*one, *two, *three); the test expects Parse to
// report the second route as already registered.
var httpConflictPathIDL = `
namespace go http

struct BizRequest {
    1: optional i32 api_version(api.path = 'action')
    2: optional i64 uid(api.path = 'biz')
}

struct BizResponse {
    1: optional string T(api.header= 'T')
}

service BizService{
BizResponse BizMethod1(1: BizRequest req)(api.post = '/life/client/:action/:biz/*one', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'form')
BizResponse BizMethod2(1: BizRequest req)(api.post = '/life/client/:action/:biz/*two', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'json')
BizResponse BizMethod3(1: BizRequest req)(api.post = '/life/client/:action/:biz/*three', api.baseurl = 'ib.snssdk.com', api.param = 'true', api.serializer = 'json')
}`
   118  
   119  func TestPanicRecover(t *testing.T) {
   120  	re, err := parser.ParseString("http.thrift", httpConflictPathIDL)
   121  	test.Assert(t, err == nil, err)
   122  	_, err = Parse(re, LastServiceOnly)
   123  	test.Assert(t, err != nil)
   124  	test.DeepEqual(t, err.Error(), "router handle failed, err=handlers are already registered for path '/life/client/:action/:biz/*two'")
   125  }
   126  
// selfReferenceIDL is the Thrift IDL fixture used by TestSelfReferenceParse.
// Struct A contains a field of its own type, and service B uses A as both
// the argument and the return type of Foo.
var selfReferenceIDL = `
namespace go http

struct A {
    1: A self
}

service B{
    A Foo(1: A req) 
    string Ping(1: string msg)
}
`
   139  
   140  func TestSelfReferenceParse(t *testing.T) {
   141  	re, err := parser.ParseString("a.thrift", selfReferenceIDL)
   142  	test.Assert(t, err == nil, err)
   143  	svc, err := Parse(re, LastServiceOnly)
   144  	test.Assert(t, err == nil, err)
   145  	for _, fn := range svc.Functions {
   146  		for _, i := range fn.Request.Struct.FieldsByID {
   147  			_ = i
   148  		}
   149  	}
   150  	_ = svc
   151  }
   152  
// notSupportAnnotationIDL is the Thrift IDL fixture used by
// TestNotSupportAnnotation. It matches the self-referencing fixture except
// that field A.self carries the annotation key "whoami" with an empty value,
// which is the key notSupportAnnotation.Equal matches on.
var notSupportAnnotationIDL = `
namespace go http

struct A {
    1: A self (whoami = '')
}

service B{
    A Foo(1: A req)
    string Ping(1: string msg)
}
`
   165  
   166  type notSupportAnnotation struct{}
   167  
   168  func (notSupportAnnotation) Equal(key, value string) bool {
   169  	return key == "whoami"
   170  }
   171  
   172  func (notSupportAnnotation) Handle() interface{} {
   173  	return 1
   174  }
   175  
   176  func TestNotSupportAnnotation(t *testing.T) {
   177  	descriptor.RegisterAnnotation(notSupportAnnotation{})
   178  	re, err := parser.ParseString("a.thrift", notSupportAnnotationIDL)
   179  	test.Assert(t, err == nil)
   180  	_, err = Parse(re, LastServiceOnly)
   181  	test.Assert(t, err != nil)
   182  }
   183  
// multiServicesIDL is the Thrift IDL fixture used by TestCombineService. It
// declares two services, A with method1 and B with method2; the test also
// derives a conflicting variant by renaming method2 to method1.
var multiServicesIDL = `
namespace go test

service A {
	string method1(1: string req)
}

service B {
	string method2(1: string req)
}
`
   195  
   196  func TestCombineService(t *testing.T) {
   197  	content := multiServicesIDL
   198  	tree, err := parser.ParseString("a.thrift", content)
   199  	test.Assert(t, err == nil)
   200  
   201  	dp, err := Parse(tree, LastServiceOnly)
   202  	test.Assert(t, err == nil)
   203  	test.Assert(t, dp.Name == "B")
   204  	test.Assert(t, len(dp.Functions) == 1 && dp.Functions["method2"] != nil)
   205  
   206  	dp, err = Parse(tree, CombineServices)
   207  	test.Assert(t, err == nil)
   208  	test.Assert(t, dp.Name == "CombinedServices")
   209  	test.Assert(t, len(dp.Functions) == 2 && dp.Functions["method1"] != nil && dp.Functions["method2"] != nil)
   210  
   211  	content = strings.ReplaceAll(content, "method2", "method1")
   212  	tree, err = parser.ParseString("a.thrift", content)
   213  	test.Assert(t, err == nil)
   214  
   215  	dp, err = Parse(tree, CombineServices)
   216  	test.Assert(t, err != nil && dp == nil)
   217  	test.Assert(t, err.Error() == "duplicate method name: method1")
   218  }
   219  
// baseIDL is the "base.thrift" fixture used by TestExtends. It declares
// BaseStruct (one default, one required and one optional string field) and
// BaseService with the single method "simple"; demoIDL includes this file
// and extends BaseService.
const baseIDL = `
namespace * base

struct BaseStruct {
    1: string def
    2: required string req
    3: optional string opt
}

service BaseService {
    BaseStruct simple(1: BaseStruct req)
}
`
   233  
// demoIDL is the "demo.thrift" fixture used by TestExtends. It includes
// base.thrift and declares a two-level inheritance chain: DemoBaseService
// (method int2bin) extends base.BaseService, and DemoService (method
// req2res) extends DemoBaseService.
const demoIDL = `
namespace * demo

include "base.thrift"

struct Request {
    1: string str
    2: required string str2
    3: optional string str3
}

struct Response {
    1: string str
    2: required string str2
    3: optional string str3
}

exception Exception {
    1: string error
    2: required string error2
    3: optional string error3
}

service DemoBaseService extends base.BaseService {
    binary int2bin(1: i32 arg);
}

service DemoService extends DemoBaseService {
    Response req2res(1: Request req);
}

`
   266  
   267  func TestExtends(t *testing.T) {
   268  	demo, err := parser.ParseString("demo.thrift", demoIDL)
   269  	test.Assert(t, err == nil)
   270  
   271  	base, err := parser.ParseString("base.thrift", baseIDL)
   272  	test.Assert(t, err == nil)
   273  
   274  	demo.Includes[0].Reference = base
   275  
   276  	dp, err := Parse(demo, LastServiceOnly)
   277  	test.Assert(t, err == nil)
   278  	test.Assert(t, dp.Name == "DemoService")
   279  	test.Assert(t, len(dp.Functions) == 3)
   280  	for _, fn := range []string{"simple", "int2bin", "req2res"} {
   281  		test.Assert(t, dp.Functions[fn] != nil)
   282  	}
   283  
   284  	dp, err = Parse(demo, FirstServiceOnly)
   285  	test.Assert(t, err == nil)
   286  	test.Assert(t, dp.Name == "DemoBaseService")
   287  	test.Assert(t, len(dp.Functions) == 2)
   288  	for _, fn := range []string{"simple", "int2bin"} {
   289  		test.Assert(t, dp.Functions[fn] != nil)
   290  	}
   291  
   292  	dp, err = Parse(demo, CombineServices)
   293  	test.Assert(t, err == nil)
   294  	test.Assert(t, dp.Name == "CombinedServices")
   295  	test.Assert(t, len(dp.Functions) == 3)
   296  	for _, fn := range []string{"simple", "int2bin", "req2res"} {
   297  		test.Assert(t, dp.Functions[fn] != nil)
   298  	}
   299  }
   300  
// defaultValueBaseIDL is the "base.thrift" fixture used by TestDefaultValue.
// It declares TestEnum, BaseElem and BaseStruct (fields a..n covering bool,
// byte, i16, i32, i64, double, string, binary, enum, set, list, two maps and
// a nested struct), plus the constants defaultBaseElem, defaultJ, defaultL
// and defaultBase that the demo fixture references as field defaults.
const defaultValueBaseIDL = `
namespace * base

enum TestEnum {
    FIRST = 1,
    SECOND = 2,
    THIRD = 3,
    FOURTH = 4,
}

struct BaseElem {
	1: optional i32 a,
}

const BaseElem defaultBaseElem = {
	"a": 56
}

struct BaseStruct {
    1: optional bool a,
    2: optional byte b,
    3: optional i16 c,
	4: optional i32 d,
	5: optional i64 e,
	6: optional double f,
	7: optional string g,
	8: optional binary h,
	9: optional TestEnum i,
	10: optional set<string> j,
	11: optional list<BaseElem> k,
	12: optional map<string, i32> l,
	13: optional map<i32, BaseElem> m,
	14: optional BaseElem n,
}

const set<string> defaultJ = ["123","456"]

const map<string, i32> defaultL = {
	"123": 12,
	"456": 45,
}

const BaseStruct defaultBase = {
	"a": true,
	"b": 1,
	"c": 2,
	"d": 3,
	"e": 4,
	"f": 5.1,
	"g": "123",
	"h": "456",
	"i": TestEnum.THIRD,
	"j": defaultJ,
	"k": [{"a": 34}, defaultBaseElem],
	"l": defaultL,
	"m": {
		12: {
			"a": 34,
		},
	},
	"n": {
		"a": 56,
	},
}
`
   366  
// defaultValueDemoIDL is the "demo.thrift" fixture used by TestDefaultValue.
// Its Request struct gives every field a default value: fields a..n use
// literals, local constants (DefaultB..DefaultE) or constants from
// base.thrift, and fields "base" and "bases" default to base.defaultBase and
// a one-element list of it. DemoService exposes the single method req2res.
const defaultValueDemoIDL = `
namespace * demo

include "base.thrift"

const byte DefaultB = 1
const i16 DefaultC = 2
const i32 DefaultD = 3
const i64 DefaultE = 4

struct Request {
	1: optional bool a = true,
	2: optional byte b = DefaultB,
	3: optional i16 c = DefaultC,
	4: optional i32 d = DefaultD,
	5: optional i64 e = DefaultE,
	6: optional double f = 5.1,
	7: optional string g = "123",
	8: optional binary h = "456",
	9: optional base.TestEnum i = base.TestEnum.THIRD,
	10: optional set<string> j = base.defaultJ,
	11: optional list<base.BaseElem> k = [{"a": 34}, base.defaultBaseElem],
	12: optional map<string, i32> l = base.defaultL,
	13: optional map<i32, base.BaseElem> m = {
		12: {
			"a": 34,
		},
	},
	14: optional base.BaseElem n = {
		"a": 56,
	},
	15: optional base.BaseStruct base = base.defaultBase,
	16: optional list<base.BaseStruct> bases = [base.defaultBase],
}

struct Response {
}

service DemoService {
    Response req2res(1: Request req);
}
`
   409  
   410  func TestDefaultValue(t *testing.T) {
   411  	demo, err := parser.ParseString("demo.thrift", defaultValueDemoIDL)
   412  	test.Assert(t, err == nil)
   413  
   414  	base, err := parser.ParseString("base.thrift", defaultValueBaseIDL)
   415  	test.Assert(t, err == nil)
   416  
   417  	demo.Includes[0].Reference = base
   418  
   419  	dp, err := Parse(demo, DefaultParseMode())
   420  	test.Assert(t, err == nil, err)
   421  
   422  	fun, err := dp.LookupFunctionByMethod("req2res")
   423  	test.Assert(t, err == nil)
   424  
   425  	defaultValueDeepEqual(t, func(name string) interface{} {
   426  		return fun.Request.Struct.FieldsByName["req"].Type.Struct.FieldsByName[name].DefaultValue
   427  	})
   428  
   429  	defaultBase := fun.Request.Struct.FieldsByName["req"].Type.Struct.FieldsByName["base"].DefaultValue.(map[string]interface{})
   430  	defaultValueDeepEqual(t, func(name string) interface{} {
   431  		return defaultBase[name]
   432  	})
   433  
   434  	defaultBases := fun.Request.Struct.FieldsByName["req"].Type.Struct.FieldsByName["bases"].DefaultValue.([]interface{})
   435  	for i := range defaultBases {
   436  		defaultBase := defaultBases[i].(map[string]interface{})
   437  		defaultValueDeepEqual(t, func(name string) interface{} {
   438  			return defaultBase[name]
   439  		})
   440  	}
   441  }
   442  
   443  func defaultValueDeepEqual(t *testing.T, defaultValue func(name string) interface{}) {
   444  	test.Assert(t, defaultValue("a") == true)
   445  	test.Assert(t, defaultValue("b") == byte(1))
   446  	test.Assert(t, defaultValue("c") == int16(2))
   447  	test.Assert(t, defaultValue("d") == int32(3))
   448  	test.Assert(t, defaultValue("e") == int64(4))
   449  	test.Assert(t, defaultValue("f") == 5.1)
   450  	test.Assert(t, defaultValue("g") == "123")
   451  	test.Assert(t, reflect.DeepEqual(defaultValue("h"), []byte("456")))
   452  	test.Assert(t, defaultValue("i") == int64(3))
   453  	test.Assert(t, reflect.DeepEqual(defaultValue("j"), []interface{}{"123", "456"}))
   454  	test.Assert(t, reflect.DeepEqual(defaultValue("k"), []interface{}{
   455  		map[string]interface{}{"a": int32(34)},
   456  		map[string]interface{}{"a": int32(56)},
   457  	}))
   458  	test.Assert(t, reflect.DeepEqual(defaultValue("l"), map[interface{}]interface{}{
   459  		"123": int32(12),
   460  		"456": int32(45),
   461  	}))
   462  	test.Assert(t, reflect.DeepEqual(defaultValue("m"), map[interface{}]interface{}{
   463  		int32(12): map[string]interface{}{"a": int32(34)},
   464  	}))
   465  	test.Assert(t, reflect.DeepEqual(defaultValue("n"), map[string]interface{}{
   466  		"a": int32(56),
   467  	}))
   468  }