github.com/lingyao2333/mo-zero@v1.4.1/zrpc/internal/serverinterceptors/authinterceptor_test.go (about)

     1  package serverinterceptors
     2  
     3  import (
     4  	"context"
     5  	"testing"
     6  
     7  	"github.com/lingyao2333/mo-zero/core/stores/redis/redistest"
     8  	"github.com/lingyao2333/mo-zero/zrpc/internal/auth"
     9  	"github.com/stretchr/testify/assert"
    10  	"google.golang.org/grpc"
    11  	"google.golang.org/grpc/metadata"
    12  )
    13  
    14  func TestStreamAuthorizeInterceptor(t *testing.T) {
    15  	tests := []struct {
    16  		name     string
    17  		app      string
    18  		token    string
    19  		strict   bool
    20  		hasError bool
    21  	}{
    22  		{
    23  			name:     "strict=false",
    24  			strict:   false,
    25  			hasError: false,
    26  		},
    27  		{
    28  			name:     "strict=true",
    29  			strict:   true,
    30  			hasError: true,
    31  		},
    32  		{
    33  			name:     "strict=true,with token",
    34  			app:      "foo",
    35  			token:    "bar",
    36  			strict:   true,
    37  			hasError: false,
    38  		},
    39  		{
    40  			name:     "strict=true,with error token",
    41  			app:      "foo",
    42  			token:    "error",
    43  			strict:   true,
    44  			hasError: true,
    45  		},
    46  	}
    47  
    48  	store, clean, err := redistest.CreateRedis()
    49  	assert.Nil(t, err)
    50  	defer clean()
    51  
    52  	for _, test := range tests {
    53  		t.Run(test.name, func(t *testing.T) {
    54  			if len(test.app) > 0 {
    55  				assert.Nil(t, store.Hset("apps", test.app, test.token))
    56  				defer store.Hdel("apps", test.app)
    57  			}
    58  
    59  			authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
    60  			assert.Nil(t, err)
    61  			interceptor := StreamAuthorizeInterceptor(authenticator)
    62  			md := metadata.New(map[string]string{
    63  				"app":   "foo",
    64  				"token": "bar",
    65  			})
    66  			ctx := metadata.NewIncomingContext(context.Background(), md)
    67  			stream := mockedStream{ctx: ctx}
    68  			err = interceptor(nil, stream, nil, func(_ interface{}, _ grpc.ServerStream) error {
    69  				return nil
    70  			})
    71  			if test.hasError {
    72  				assert.NotNil(t, err)
    73  			} else {
    74  				assert.Nil(t, err)
    75  			}
    76  		})
    77  	}
    78  }
    79  
    80  func TestUnaryAuthorizeInterceptor(t *testing.T) {
    81  	tests := []struct {
    82  		name     string
    83  		app      string
    84  		token    string
    85  		strict   bool
    86  		hasError bool
    87  	}{
    88  		{
    89  			name:     "strict=false",
    90  			strict:   false,
    91  			hasError: false,
    92  		},
    93  		{
    94  			name:     "strict=true",
    95  			strict:   true,
    96  			hasError: true,
    97  		},
    98  		{
    99  			name:     "strict=true,with token",
   100  			app:      "foo",
   101  			token:    "bar",
   102  			strict:   true,
   103  			hasError: false,
   104  		},
   105  		{
   106  			name:     "strict=true,with error token",
   107  			app:      "foo",
   108  			token:    "error",
   109  			strict:   true,
   110  			hasError: true,
   111  		},
   112  	}
   113  
   114  	store, clean, err := redistest.CreateRedis()
   115  	assert.Nil(t, err)
   116  	defer clean()
   117  
   118  	for _, test := range tests {
   119  		t.Run(test.name, func(t *testing.T) {
   120  			if len(test.app) > 0 {
   121  				assert.Nil(t, store.Hset("apps", test.app, test.token))
   122  				defer store.Hdel("apps", test.app)
   123  			}
   124  
   125  			authenticator, err := auth.NewAuthenticator(store, "apps", test.strict)
   126  			assert.Nil(t, err)
   127  			interceptor := UnaryAuthorizeInterceptor(authenticator)
   128  			md := metadata.New(map[string]string{
   129  				"app":   "foo",
   130  				"token": "bar",
   131  			})
   132  			ctx := metadata.NewIncomingContext(context.Background(), md)
   133  			_, err = interceptor(ctx, nil, nil,
   134  				func(ctx context.Context, req interface{}) (interface{}, error) {
   135  					return nil, nil
   136  				})
   137  			if test.hasError {
   138  				assert.NotNil(t, err)
   139  			} else {
   140  				assert.Nil(t, err)
   141  			}
   142  			if test.strict {
   143  				_, err = interceptor(context.Background(), nil, nil,
   144  					func(ctx context.Context, req interface{}) (interface{}, error) {
   145  						return nil, nil
   146  					})
   147  				assert.NotNil(t, err)
   148  
   149  				var md metadata.MD
   150  				ctx := metadata.NewIncomingContext(context.Background(), md)
   151  				_, err = interceptor(ctx, nil, nil,
   152  					func(ctx context.Context, req interface{}) (interface{}, error) {
   153  						return nil, nil
   154  					})
   155  				assert.NotNil(t, err)
   156  
   157  				md = metadata.New(map[string]string{
   158  					"app":   "",
   159  					"token": "",
   160  				})
   161  				ctx = metadata.NewIncomingContext(context.Background(), md)
   162  				_, err = interceptor(ctx, nil, nil,
   163  					func(ctx context.Context, req interface{}) (interface{}, error) {
   164  						return nil, nil
   165  					})
   166  				assert.NotNil(t, err)
   167  			}
   168  		})
   169  	}
   170  }
   171  
   172  type mockedStream struct {
   173  	ctx context.Context
   174  }
   175  
   176  func (m mockedStream) SetHeader(md metadata.MD) error {
   177  	return nil
   178  }
   179  
   180  func (m mockedStream) SendHeader(md metadata.MD) error {
   181  	return nil
   182  }
   183  
   184  func (m mockedStream) SetTrailer(md metadata.MD) {
   185  }
   186  
   187  func (m mockedStream) Context() context.Context {
   188  	return m.ctx
   189  }
   190  
   191  func (m mockedStream) SendMsg(v interface{}) error {
   192  	return nil
   193  }
   194  
   195  func (m mockedStream) RecvMsg(v interface{}) error {
   196  	return nil
   197  }