github.com/erda-project/erda-infra@v1.0.10-0.20240327085753-f3a249292aeb/pkg/transport/transport_test.go (about)

     1  // Copyright (c) 2021 Terminus, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package transport
    16  
    17  import (
    18  	"context"
    19  	"reflect"
    20  	"testing"
    21  
    22  	transgrpc "github.com/erda-project/erda-infra/pkg/transport/grpc"
    23  	transhttp "github.com/erda-project/erda-infra/pkg/transport/http"
    24  	"github.com/erda-project/erda-infra/pkg/transport/interceptor"
    25  )
    26  
    27  type testInterceptor struct {
    28  	key    string
    29  	append func(v string)
    30  }
    31  
    32  func (i testInterceptor) Wrap(h interceptor.Handler) interceptor.Handler {
    33  	return func(ctx context.Context, req interface{}) (interface{}, error) {
    34  		i.append(i.key + "->")
    35  		h(ctx, req)
    36  		i.append("<-" + i.key)
    37  		return nil, nil
    38  	}
    39  }
    40  
    41  func TestWithInterceptors(t *testing.T) {
    42  	tests := []struct {
    43  		name   string
    44  		handle string
    45  		inters [][]*testInterceptor
    46  		want   []string
    47  	}{
    48  		{
    49  			handle: "handle",
    50  			inters: [][]*testInterceptor{
    51  				{{key: "a"}},
    52  				{{key: "b"}},
    53  				{{key: "c"}},
    54  			},
    55  			want: []string{"a->", "b->", "c->", "handle", "<-c", "<-b", "<-a"},
    56  		},
    57  		{
    58  			handle: "handle",
    59  			inters: [][]*testInterceptor{
    60  				{{key: "a"}, {key: "b"}},
    61  				{{key: "c"}},
    62  			},
    63  			want: []string{"a->", "b->", "c->", "handle", "<-c", "<-b", "<-a"},
    64  		},
    65  		{
    66  			handle: "handle",
    67  			inters: [][]*testInterceptor{
    68  				{{key: "a"}, {key: "b"}},
    69  				{{key: "c"}},
    70  				{{key: "d"}, {key: "e"}},
    71  			},
    72  			want: []string{"a->", "b->", "c->", "d->", "e->", "handle", "<-e", "<-d", "<-c", "<-b", "<-a"},
    73  		},
    74  		{
    75  			handle: "handle",
    76  			want:   []string{"handle"},
    77  		},
    78  		{
    79  			handle: "handle",
    80  			inters: [][]*testInterceptor{
    81  				{{key: "a"}},
    82  			},
    83  			want: []string{"a->", "handle", "<-a"},
    84  		},
    85  		{
    86  			handle: "handle",
    87  			inters: [][]*testInterceptor{
    88  				{{key: "a"}},
    89  				{{key: "b"}},
    90  			},
    91  			want: []string{"a->", "b->", "handle", "<-b", "<-a"},
    92  		},
    93  	}
    94  	for _, tt := range tests {
    95  		t.Run(tt.name, func(t *testing.T) {
    96  			getOpts := func(add func(v string)) (*ServiceOptions, interceptor.Handler) {
    97  				handler := func(ctx context.Context, req interface{}) (interface{}, error) {
    98  					add(tt.handle)
    99  					return nil, nil
   100  				}
   101  				opts := DefaultServiceOptions()
   102  				for _, inters := range tt.inters {
   103  					var list []interceptor.Interceptor
   104  					for _, i := range inters {
   105  						i.append = add
   106  						list = append(list, i.Wrap)
   107  					}
   108  					WithInterceptors(list...)(opts)
   109  				}
   110  				return opts, handler
   111  			}
   112  			var grpcResults []string
   113  			opts, handler := getOpts(func(v string) {
   114  				grpcResults = append(grpcResults, v)
   115  			})
   116  			grpcOpts := transgrpc.DefaultHandleOptions()
   117  			for _, opt := range opts.GRPC {
   118  				opt(grpcOpts)
   119  			}
   120  			if grpcOpts.Interceptor != nil {
   121  				handler = grpcOpts.Interceptor(handler)
   122  			}
   123  			handler(nil, nil)
   124  			if !reflect.DeepEqual(grpcResults, tt.want) {
   125  				t.Errorf("wrapped grpc handler got %v, want %v", grpcResults, tt.want)
   126  			}
   127  
   128  			var httpResults []string
   129  			opts, handler = getOpts(func(v string) {
   130  				httpResults = append(httpResults, v)
   131  			})
   132  			httpOpts := transhttp.DefaultHandleOptions()
   133  			for _, opt := range opts.HTTP {
   134  				opt(httpOpts)
   135  			}
   136  			if httpOpts.Interceptor != nil {
   137  				handler = httpOpts.Interceptor(handler)
   138  			}
   139  			handler(nil, nil)
   140  			if !reflect.DeepEqual(httpResults, tt.want) {
   141  				t.Errorf("wrapped http handler got %v, want %v", httpResults, tt.want)
   142  			}
   143  		})
   144  	}
   145  }