github.com/cloudwego/kitex@v0.9.0/pkg/utils/kitexutil/kitexutil_test.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package kitexutil
    18  
    19  import (
    20  	"context"
    21  	"net"
    22  	"os"
    23  	"reflect"
    24  	"testing"
    25  
    26  	mocks "github.com/cloudwego/kitex/internal/mocks/thrift/fast"
    27  	"github.com/cloudwego/kitex/internal/test"
    28  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    29  	"github.com/cloudwego/kitex/pkg/utils"
    30  	"github.com/cloudwego/kitex/transport"
    31  )
    32  
    33  var (
    34  	testRi         rpcinfo.RPCInfo
    35  	testCtx        context.Context
    36  	panicCtx       context.Context
    37  	caller         = "kitexutil.from.service"
    38  	callee         = "kitexutil.to.service"
    39  	idlServiceName = "MockService"
    40  	fromAddr       = utils.NewNetAddr("test", "127.0.0.1:12345")
    41  	fromMethod     = "from_method"
    42  	method         = "method"
    43  	tp             = transport.TTHeader
    44  )
    45  
    46  func TestMain(m *testing.M) {
    47  	testRi = buildRPCInfo()
    48  
    49  	testCtx = context.Background()
    50  	testCtx = rpcinfo.NewCtxWithRPCInfo(testCtx, testRi)
    51  	panicCtx = rpcinfo.NewCtxWithRPCInfo(context.Background(), &panicRPCInfo{})
    52  
    53  	os.Exit(m.Run())
    54  }
    55  
    56  func TestGetCaller(t *testing.T) {
    57  	type args struct {
    58  		ctx context.Context
    59  	}
    60  	tests := []struct {
    61  		name  string
    62  		args  args
    63  		want  string
    64  		want1 bool
    65  	}{
    66  		{name: "Success", args: args{testCtx}, want: caller, want1: true},
    67  		{name: "Failure", args: args{context.Background()}, want: "", want1: false},
    68  		{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
    69  	}
    70  	for _, tt := range tests {
    71  		t.Run(tt.name, func(t *testing.T) {
    72  			got, got1 := GetCaller(tt.args.ctx)
    73  			if got != tt.want {
    74  				t.Errorf("GetCaller() got = %v, want %v", got, tt.want)
    75  			}
    76  			if got1 != tt.want1 {
    77  				t.Errorf("GetCaller() got1 = %v, want %v", got1, tt.want1)
    78  			}
    79  		})
    80  	}
    81  }
    82  
    83  func TestGetCallerAddr(t *testing.T) {
    84  	type args struct {
    85  		ctx context.Context
    86  	}
    87  	tests := []struct {
    88  		name  string
    89  		args  args
    90  		want  net.Addr
    91  		want1 bool
    92  	}{
    93  		{name: "Success", args: args{testCtx}, want: fromAddr, want1: true},
    94  		{name: "Failure", args: args{context.Background()}, want: nil, want1: false},
    95  		{name: "Panic recovered", args: args{panicCtx}, want: nil, want1: false},
    96  	}
    97  	for _, tt := range tests {
    98  		t.Run(tt.name, func(t *testing.T) {
    99  			got, got1 := GetCallerAddr(tt.args.ctx)
   100  			if !reflect.DeepEqual(got, tt.want) {
   101  				t.Errorf("GetCallerAddr() got = %v, want %v", got, tt.want)
   102  			}
   103  			if got1 != tt.want1 {
   104  				t.Errorf("GetCallerAddr() got1 = %v, want %v", got1, tt.want1)
   105  			}
   106  		})
   107  	}
   108  }
   109  
   110  func TestGetCallerIP(t *testing.T) {
   111  	ip, ok := GetCallerIP(testCtx)
   112  	test.Assert(t, ok)
   113  	test.Assert(t, ip == "127.0.0.1", ip)
   114  
   115  	ri := buildRPCInfo()
   116  	rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(utils.NewNetAddr("test", "127.0.0.1"))
   117  	ip, ok = GetCallerIP(rpcinfo.NewCtxWithRPCInfo(context.Background(), ri))
   118  	test.Assert(t, ok)
   119  	test.Assert(t, ip == "127.0.0.1", ip)
   120  
   121  	ip, ok = GetCallerIP(context.Background())
   122  	test.Assert(t, !ok)
   123  	test.Assert(t, ip == "", ip)
   124  
   125  	ip, ok = GetCallerIP(panicCtx)
   126  	test.Assert(t, !ok)
   127  	test.Assert(t, ip == "", ip)
   128  
   129  	rpcinfo.AsMutableEndpointInfo(ri.From()).SetAddress(utils.NewNetAddr("test", ""))
   130  	ip, ok = GetCallerIP(rpcinfo.NewCtxWithRPCInfo(context.Background(), ri))
   131  	test.Assert(t, !ok)
   132  	test.Assert(t, ip == "", ip)
   133  }
   134  
   135  func TestGetMethod(t *testing.T) {
   136  	type args struct {
   137  		ctx context.Context
   138  	}
   139  	tests := []struct {
   140  		name  string
   141  		args  args
   142  		want  string
   143  		want1 bool
   144  	}{
   145  		{name: "Success", args: args{testCtx}, want: method, want1: true},
   146  		{name: "Failure", args: args{context.Background()}, want: "", want1: false},
   147  	}
   148  	for _, tt := range tests {
   149  		t.Run(tt.name, func(t *testing.T) {
   150  			got, got1 := GetMethod(tt.args.ctx)
   151  			if got != tt.want {
   152  				t.Errorf("GetMethod() got = %v, want %v", got, tt.want)
   153  			}
   154  			if got1 != tt.want1 {
   155  				t.Errorf("GetMethod() got1 = %v, want %v", got1, tt.want1)
   156  			}
   157  		})
   158  	}
   159  }
   160  
   161  func TestGetCallerHandlerMethod(t *testing.T) {
   162  	type args struct {
   163  		ctx context.Context
   164  	}
   165  	tests := []struct {
   166  		name  string
   167  		args  args
   168  		want  string
   169  		want1 bool
   170  	}{
   171  		{name: "Success", args: args{testCtx}, want: fromMethod, want1: true},
   172  		{name: "Failure", args: args{context.Background()}, want: "", want1: false},
   173  		{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
   174  	}
   175  	for _, tt := range tests {
   176  		t.Run(tt.name, func(t *testing.T) {
   177  			got, got1 := GetCallerHandlerMethod(tt.args.ctx)
   178  			if !reflect.DeepEqual(got, tt.want) {
   179  				t.Errorf("GetCallerHandlerMethod() got = %v, want %v", got, tt.want)
   180  			}
   181  			if got1 != tt.want1 {
   182  				t.Errorf("GetCallerHandlerMethod() got1 = %v, want %v", got1, tt.want1)
   183  			}
   184  		})
   185  	}
   186  }
   187  
   188  func TestGetIDLServiceName(t *testing.T) {
   189  	type args struct {
   190  		ctx context.Context
   191  	}
   192  	tests := []struct {
   193  		name  string
   194  		args  args
   195  		want  string
   196  		want1 bool
   197  	}{
   198  		{name: "Success", args: args{testCtx}, want: idlServiceName, want1: true},
   199  		{name: "Failure", args: args{context.Background()}, want: "", want1: false},
   200  		{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
   201  	}
   202  	for _, tt := range tests {
   203  		t.Run(tt.name, func(t *testing.T) {
   204  			got, got1 := GetIDLServiceName(tt.args.ctx)
   205  			if !reflect.DeepEqual(got, tt.want) {
   206  				t.Errorf("GetCallerHandlerMethod() got = %v, want %v", got, tt.want)
   207  			}
   208  			if got1 != tt.want1 {
   209  				t.Errorf("GetCallerHandlerMethod() got1 = %v, want %v", got1, tt.want1)
   210  			}
   211  		})
   212  	}
   213  }
   214  
   215  func TestGetRPCInfo(t *testing.T) {
   216  	type args struct {
   217  		ctx context.Context
   218  	}
   219  	tests := []struct {
   220  		name  string
   221  		args  args
   222  		want  rpcinfo.RPCInfo
   223  		want1 bool
   224  	}{
   225  		{name: "Success", args: args{testCtx}, want: testRi, want1: true},
   226  		{name: "Failure", args: args{context.Background()}, want: nil, want1: false},
   227  	}
   228  	for _, tt := range tests {
   229  		t.Run(tt.name, func(t *testing.T) {
   230  			got, got1 := GetRPCInfo(tt.args.ctx)
   231  			if !reflect.DeepEqual(got, tt.want) {
   232  				t.Errorf("GetRPCInfo() got = %v, want %v", got, tt.want)
   233  			}
   234  			if got1 != tt.want1 {
   235  				t.Errorf("GetRPCInfo() got1 = %v, want %v", got1, tt.want1)
   236  			}
   237  		})
   238  	}
   239  }
   240  
   241  func TestGetCtxTransportProtocol(t *testing.T) {
   242  	type args struct {
   243  		ctx context.Context
   244  	}
   245  	tests := []struct {
   246  		name  string
   247  		args  args
   248  		want  string
   249  		want1 bool
   250  	}{
   251  		{name: "Success", args: args{testCtx}, want: tp.String(), want1: true},
   252  		{name: "Failure", args: args{context.Background()}, want: "", want1: false},
   253  		{name: "Panic recovered", args: args{panicCtx}, want: "", want1: false},
   254  	}
   255  	for _, tt := range tests {
   256  		t.Run(tt.name, func(t *testing.T) {
   257  			got, got1 := GetTransportProtocol(tt.args.ctx)
   258  			if got != tt.want {
   259  				t.Errorf("GetTransportProtocol() got = %v, want %v", got, tt.want)
   260  			}
   261  			if got1 != tt.want1 {
   262  				t.Errorf("GetTransportProtocol() got1 = %v, want %v", got1, tt.want1)
   263  			}
   264  		})
   265  	}
   266  }
   267  
   268  func TestGetRealRequest(t *testing.T) {
   269  	req := &mocks.MockReq{}
   270  	arg := &mocks.MockTestArgs{Req: req}
   271  	type args struct {
   272  		req interface{}
   273  	}
   274  	tests := []struct {
   275  		name string
   276  		args args
   277  		want interface{}
   278  	}{
   279  		{name: "success", args: args{arg}, want: req},
   280  		{name: "nil input", args: args{nil}, want: nil},
   281  		{name: "wrong interface", args: args{req}, want: nil},
   282  	}
   283  	for _, tt := range tests {
   284  		t.Run(tt.name, func(t *testing.T) {
   285  			if got := GetRealReqFromKitexArgs(tt.args.req); !reflect.DeepEqual(got, tt.want) {
   286  				t.Errorf("GetRealReqFromKitexArgs() = %v, want %v", got, tt.want)
   287  			}
   288  		})
   289  	}
   290  }
   291  
   292  func TestGetRealResponse(t *testing.T) {
   293  	success := "success"
   294  	result := &mocks.MockTestResult{Success: &success}
   295  	type args struct {
   296  		resp interface{}
   297  	}
   298  	tests := []struct {
   299  		name string
   300  		args args
   301  		want interface{}
   302  	}{
   303  		{name: "success", args: args{result}, want: &success},
   304  		{name: "nil input", args: args{nil}, want: nil},
   305  		{name: "wrong interface", args: args{success}, want: nil},
   306  	}
   307  	for _, tt := range tests {
   308  		t.Run(tt.name, func(t *testing.T) {
   309  			if got := GetRealRespFromKitexResult(tt.args.resp); !reflect.DeepEqual(got, tt.want) {
   310  				t.Errorf("GetRealRespFromKitexResult() = %v, want %v", got, tt.want)
   311  			}
   312  		})
   313  	}
   314  }
   315  
   316  func buildRPCInfo() rpcinfo.RPCInfo {
   317  	from := rpcinfo.NewEndpointInfo(caller, fromMethod, fromAddr, nil)
   318  	to := rpcinfo.NewEndpointInfo(callee, method, nil, nil)
   319  	ink := rpcinfo.NewInvocation(idlServiceName, method)
   320  	config := rpcinfo.NewRPCConfig()
   321  	config.(rpcinfo.MutableRPCConfig).SetTransportProtocol(tp)
   322  
   323  	stats := rpcinfo.NewRPCStats()
   324  	ri := rpcinfo.NewRPCInfo(from, to, ink, config, stats)
   325  	return ri
   326  }
   327  
   328  type panicRPCInfo struct{}
   329  
   330  func (m *panicRPCInfo) From() rpcinfo.EndpointInfo     { panic("Panic when invoke From") }
   331  func (m *panicRPCInfo) To() rpcinfo.EndpointInfo       { panic("Panic when invoke To") }
   332  func (m *panicRPCInfo) Invocation() rpcinfo.Invocation { panic("Panic when invoke Invocation") }
   333  func (m *panicRPCInfo) Config() rpcinfo.RPCConfig      { panic("Panic when invoke Config") }
   334  func (m *panicRPCInfo) Stats() rpcinfo.RPCStats        { panic("Panic when invoke Stats") }