gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/go-grpc-middleware/auth/auth_test.go (about) 1 // Copyright 2016 Michal Witkowski. All Rights Reserved. 2 // See LICENSE for licensing terms. 3 4 package grpc_auth_test 5 6 import ( 7 "gitee.com/ks-custle/core-gm/grpc/status" 8 "testing" 9 10 "gitee.com/ks-custle/core-gm/grpc" 11 "github.com/stretchr/testify/suite" 12 13 "fmt" 14 15 "time" 16 17 grpcauth "gitee.com/ks-custle/core-gm/go-grpc-middleware/auth" 18 grpctesting "gitee.com/ks-custle/core-gm/go-grpc-middleware/testing" 19 pbtestproto "gitee.com/ks-custle/core-gm/go-grpc-middleware/testing/testproto" 20 "gitee.com/ks-custle/core-gm/go-grpc-middleware/util/metautils" 21 "gitee.com/ks-custle/core-gm/grpc/codes" 22 "gitee.com/ks-custle/core-gm/grpc/credentials/oauth" 23 "gitee.com/ks-custle/core-gm/grpc/metadata" 24 "gitee.com/ks-custle/core-gm/net/context" 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 "golang.org/x/oauth2" 28 ) 29 30 var ( 31 commonAuthToken = "some_good_token" 32 overrideAuthToken = "override_token" 33 34 authedMarker = "some_context_marker" 35 goodPing = &pbtestproto.PingRequest{Value: "something", SleepTimeMs: 9999} 36 ) 37 38 // TODO(mwitkow): Add auth from metadata client dialer, which requires TLS. 39 40 func buildDummyAuthFunction(expectedScheme string, expectedToken string) func(ctx context.Context) (context.Context, error) { 41 return func(ctx context.Context) (context.Context, error) { 42 token, err := grpcauth.AuthFromMD(ctx, expectedScheme) 43 if err != nil { 44 return nil, err 45 } 46 if token != expectedToken { 47 // `grpc.Errorf` is deprecated. use status.Errorf instead. 48 //return nil, grpc.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token") 49 return nil, status.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token") 50 } 51 return context.WithValue(ctx, authedMarker, "marker_exists"), nil 52 } 53 } 54 55 func assertAuthMarkerExists(t *testing.T, ctx context.Context) { 56 assert.Equal(t, "marker_exists", ctx.Value(authedMarker).(string), "auth marker from buildDummyAuthFunction must be passed around") 57 } 58 59 type assertingPingService struct { 60 pbtestproto.TestServiceServer 61 T *testing.T 62 } 63 64 func (s *assertingPingService) PingError(ctx context.Context, ping *pbtestproto.PingRequest) (*pbtestproto.Empty, error) { 65 assertAuthMarkerExists(s.T, ctx) 66 return s.TestServiceServer.PingError(ctx, ping) 67 } 68 69 func (s *assertingPingService) PingList(ping *pbtestproto.PingRequest, stream pbtestproto.TestService_PingListServer) error { 70 assertAuthMarkerExists(s.T, stream.Context()) 71 return s.TestServiceServer.PingList(ping, stream) 72 } 73 74 func ctxWithToken(ctx context.Context, scheme string, token string) context.Context { 75 md := metadata.Pairs("authorization", fmt.Sprintf("%s %v", scheme, token)) 76 nCtx := metautils.NiceMD(md).ToOutgoing(ctx) 77 return nCtx 78 } 79 80 func TestAuthTestSuite(t *testing.T) { 81 authFunc := buildDummyAuthFunction("bearer", commonAuthToken) 82 s := &AuthTestSuite{ 83 InterceptorTestSuite: &grpctesting.InterceptorTestSuite{ 84 TestService: &assertingPingService{&grpctesting.TestPingService{T: t}, t}, 85 ServerOpts: []grpc.ServerOption{ 86 grpc.StreamInterceptor(grpcauth.StreamServerInterceptor(authFunc)), 87 grpc.UnaryInterceptor(grpcauth.UnaryServerInterceptor(authFunc)), 88 }, 89 }, 90 } 91 suite.Run(t, s) 92 } 93 94 type AuthTestSuite struct { 95 *grpctesting.InterceptorTestSuite 96 } 97 98 func (s *AuthTestSuite) TestUnary_NoAuth() { 99 _, err := s.Client.Ping(s.SimpleCtx(), goodPing) 100 assert.Error(s.T(), err, "there must be an error") 101 // `grpc.Code` is deprecated. Use `status.Code` instead. 102 //assert.Equal(s.T(), codes.Unauthenticated, grpc.Code(err), "must error with unauthenticated") 103 assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated") 104 } 105 106 func (s *AuthTestSuite) TestUnary_BadAuth() { 107 _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing) 108 assert.Error(s.T(), err, "there must be an error") 109 // `grpc.Code` is deprecated. Use `status.Code` instead. 110 //assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err), "must error with permission denied") 111 assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied") 112 } 113 114 func (s *AuthTestSuite) TestUnary_PassesAuth() { 115 _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", commonAuthToken), goodPing) 116 require.NoError(s.T(), err, "no error must occur") 117 } 118 119 func (s *AuthTestSuite) TestUnary_PassesWithPerRpcCredentials() { 120 grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}} 121 client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds)) 122 _, err := client.Ping(s.SimpleCtx(), goodPing) 123 require.NoError(s.T(), err, "no error must occur") 124 } 125 126 func (s *AuthTestSuite) TestStream_NoAuth() { 127 stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) 128 require.NoError(s.T(), err, "should not fail on establishing the stream") 129 _, err = stream.Recv() 130 assert.Error(s.T(), err, "there must be an error") 131 // `grpc.Code` is deprecated. Use `status.Code` instead. 132 //assert.Equal(s.T(), codes.Unauthenticated, grpc.Code(err), "must error with unauthenticated") 133 assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated") 134 } 135 136 func (s *AuthTestSuite) TestStream_BadAuth() { 137 stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing) 138 require.NoError(s.T(), err, "should not fail on establishing the stream") 139 _, err = stream.Recv() 140 assert.Error(s.T(), err, "there must be an error") 141 // `grpc.Code` is deprecated. Use `status.Code` instead. 142 //assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err), "must error with permission denied") 143 assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied") 144 } 145 146 func (s *AuthTestSuite) TestStream_PassesAuth() { 147 stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", commonAuthToken), goodPing) 148 require.NoError(s.T(), err, "should not fail on establishing the stream") 149 pong, err := stream.Recv() 150 require.NoError(s.T(), err, "no error must occur") 151 require.NotNil(s.T(), pong, "pong must not be nil") 152 } 153 154 func (s *AuthTestSuite) TestStream_PassesWithPerRpcCredentials() { 155 grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}} 156 client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds)) 157 stream, err := client.PingList(s.SimpleCtx(), goodPing) 158 require.NoError(s.T(), err, "should not fail on establishing the stream") 159 pong, err := stream.Recv() 160 require.NoError(s.T(), err, "no error must occur") 161 require.NotNil(s.T(), pong, "pong must not be nil") 162 } 163 164 type authOverrideTestService struct { 165 pbtestproto.TestServiceServer 166 T *testing.T 167 } 168 169 func (s *authOverrideTestService) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) { 170 assert.NotEmpty(s.T, fullMethodName, "method name of caller is passed around") 171 return buildDummyAuthFunction("bearer", overrideAuthToken)(ctx) 172 } 173 174 func TestAuthOverrideTestSuite(t *testing.T) { 175 authFunc := buildDummyAuthFunction("bearer", commonAuthToken) 176 s := &AuthOverrideTestSuite{ 177 InterceptorTestSuite: &grpctesting.InterceptorTestSuite{ 178 TestService: &authOverrideTestService{&assertingPingService{&grpctesting.TestPingService{T: t}, t}, t}, 179 ServerOpts: []grpc.ServerOption{ 180 grpc.StreamInterceptor(grpcauth.StreamServerInterceptor(authFunc)), 181 grpc.UnaryInterceptor(grpcauth.UnaryServerInterceptor(authFunc)), 182 }, 183 }, 184 } 185 suite.Run(t, s) 186 } 187 188 type AuthOverrideTestSuite struct { 189 *grpctesting.InterceptorTestSuite 190 } 191 192 func (s *AuthOverrideTestSuite) TestUnary_PassesAuth() { 193 _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", overrideAuthToken), goodPing) 194 require.NoError(s.T(), err, "no error must occur") 195 } 196 197 func (s *AuthOverrideTestSuite) TestStream_PassesAuth() { 198 stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", overrideAuthToken), goodPing) 199 require.NoError(s.T(), err, "should not fail on establishing the stream") 200 pong, err := stream.Recv() 201 require.NoError(s.T(), err, "no error must occur") 202 require.NotNil(s.T(), pong, "pong must not be nil") 203 } 204 205 // fakeOAuth2TokenSource implements a fake oauth2.TokenSource for the purpose of credentials test. 206 type fakeOAuth2TokenSource struct { 207 accessToken string 208 } 209 210 func (ts *fakeOAuth2TokenSource) Token() (*oauth2.Token, error) { 211 t := &oauth2.Token{ 212 AccessToken: ts.accessToken, 213 Expiry: time.Now().Add(1 * time.Minute), 214 TokenType: "bearer", 215 } 216 return t, nil 217 }