github.com/lingyao2333/mo-zero@v1.4.1/zrpc/internal/serverinterceptors/authinterceptor_test.go (about) 1 package serverinterceptors 2 3 import ( 4 "context" 5 "testing" 6 7 "github.com/lingyao2333/mo-zero/core/stores/redis/redistest" 8 "github.com/lingyao2333/mo-zero/zrpc/internal/auth" 9 "github.com/stretchr/testify/assert" 10 "google.golang.org/grpc" 11 "google.golang.org/grpc/metadata" 12 ) 13 14 func TestStreamAuthorizeInterceptor(t *testing.T) { 15 tests := []struct { 16 name string 17 app string 18 token string 19 strict bool 20 hasError bool 21 }{ 22 { 23 name: "strict=false", 24 strict: false, 25 hasError: false, 26 }, 27 { 28 name: "strict=true", 29 strict: true, 30 hasError: true, 31 }, 32 { 33 name: "strict=true,with token", 34 app: "foo", 35 token: "bar", 36 strict: true, 37 hasError: false, 38 }, 39 { 40 name: "strict=true,with error token", 41 app: "foo", 42 token: "error", 43 strict: true, 44 hasError: true, 45 }, 46 } 47 48 store, clean, err := redistest.CreateRedis() 49 assert.Nil(t, err) 50 defer clean() 51 52 for _, test := range tests { 53 t.Run(test.name, func(t *testing.T) { 54 if len(test.app) > 0 { 55 assert.Nil(t, store.Hset("apps", test.app, test.token)) 56 defer store.Hdel("apps", test.app) 57 } 58 59 authenticator, err := auth.NewAuthenticator(store, "apps", test.strict) 60 assert.Nil(t, err) 61 interceptor := StreamAuthorizeInterceptor(authenticator) 62 md := metadata.New(map[string]string{ 63 "app": "foo", 64 "token": "bar", 65 }) 66 ctx := metadata.NewIncomingContext(context.Background(), md) 67 stream := mockedStream{ctx: ctx} 68 err = interceptor(nil, stream, nil, func(_ interface{}, _ grpc.ServerStream) error { 69 return nil 70 }) 71 if test.hasError { 72 assert.NotNil(t, err) 73 } else { 74 assert.Nil(t, err) 75 } 76 }) 77 } 78 } 79 80 func TestUnaryAuthorizeInterceptor(t *testing.T) { 81 tests := []struct { 82 name string 83 app string 84 token string 85 strict bool 86 hasError bool 87 }{ 88 { 89 name: "strict=false", 90 strict: false, 91 hasError: false, 92 }, 93 { 94 name: "strict=true", 95 strict: true, 96 hasError: true, 97 }, 98 { 99 name: "strict=true,with token", 100 app: "foo", 101 token: "bar", 102 strict: true, 103 hasError: false, 104 }, 105 { 106 name: "strict=true,with error token", 107 app: "foo", 108 token: "error", 109 strict: true, 110 hasError: true, 111 }, 112 } 113 114 store, clean, err := redistest.CreateRedis() 115 assert.Nil(t, err) 116 defer clean() 117 118 for _, test := range tests { 119 t.Run(test.name, func(t *testing.T) { 120 if len(test.app) > 0 { 121 assert.Nil(t, store.Hset("apps", test.app, test.token)) 122 defer store.Hdel("apps", test.app) 123 } 124 125 authenticator, err := auth.NewAuthenticator(store, "apps", test.strict) 126 assert.Nil(t, err) 127 interceptor := UnaryAuthorizeInterceptor(authenticator) 128 md := metadata.New(map[string]string{ 129 "app": "foo", 130 "token": "bar", 131 }) 132 ctx := metadata.NewIncomingContext(context.Background(), md) 133 _, err = interceptor(ctx, nil, nil, 134 func(ctx context.Context, req interface{}) (interface{}, error) { 135 return nil, nil 136 }) 137 if test.hasError { 138 assert.NotNil(t, err) 139 } else { 140 assert.Nil(t, err) 141 } 142 if test.strict { 143 _, err = interceptor(context.Background(), nil, nil, 144 func(ctx context.Context, req interface{}) (interface{}, error) { 145 return nil, nil 146 }) 147 assert.NotNil(t, err) 148 149 var md metadata.MD 150 ctx := metadata.NewIncomingContext(context.Background(), md) 151 _, err = interceptor(ctx, nil, nil, 152 func(ctx context.Context, req interface{}) (interface{}, error) { 153 return nil, nil 154 }) 155 assert.NotNil(t, err) 156 157 md = metadata.New(map[string]string{ 158 "app": "", 159 "token": "", 160 }) 161 ctx = metadata.NewIncomingContext(context.Background(), md) 162 _, err = interceptor(ctx, nil, nil, 163 func(ctx context.Context, req interface{}) (interface{}, error) { 164 return nil, nil 165 }) 166 assert.NotNil(t, err) 167 } 168 }) 169 } 170 } 171 172 type mockedStream struct { 173 ctx context.Context 174 } 175 176 func (m mockedStream) SetHeader(md metadata.MD) error { 177 return nil 178 } 179 180 func (m mockedStream) SendHeader(md metadata.MD) error { 181 return nil 182 } 183 184 func (m mockedStream) SetTrailer(md metadata.MD) { 185 } 186 187 func (m mockedStream) Context() context.Context { 188 return m.ctx 189 } 190 191 func (m mockedStream) SendMsg(v interface{}) error { 192 return nil 193 } 194 195 func (m mockedStream) RecvMsg(v interface{}) error { 196 return nil 197 }