github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/go-grpc-middleware/auth/auth_test.go (about) 1 // Copyright 2016 Michal Witkowski. All Rights Reserved. 2 // See LICENSE for licensing terms. 3 4 package grpc_auth_test 5 6 import ( 7 "testing" 8 9 "github.com/hxx258456/ccgo/grpc" 10 "github.com/stretchr/testify/suite" 11 12 "fmt" 13 14 "time" 15 16 grpc_auth "github.com/hxx258456/ccgo/go-grpc-middleware/auth" 17 grpc_testing "github.com/hxx258456/ccgo/go-grpc-middleware/testing" 18 pb_testproto "github.com/hxx258456/ccgo/go-grpc-middleware/testing/testproto" 19 "github.com/hxx258456/ccgo/go-grpc-middleware/util/metautils" 20 "github.com/hxx258456/ccgo/grpc/codes" 21 "github.com/hxx258456/ccgo/grpc/credentials/oauth" 22 "github.com/hxx258456/ccgo/grpc/metadata" 23 "github.com/hxx258456/ccgo/net/context" 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 "golang.org/x/oauth2" 27 ) 28 29 var ( 30 commonAuthToken = "some_good_token" 31 overrideAuthToken = "override_token" 32 33 authedMarker = "some_context_marker" 34 goodPing = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999} 35 ) 36 37 // TODO(mwitkow): Add auth from metadata client dialer, which requires TLS. 38 39 func buildDummyAuthFunction(expectedScheme string, expectedToken string) func(ctx context.Context) (context.Context, error) { 40 return func(ctx context.Context) (context.Context, error) { 41 token, err := grpc_auth.AuthFromMD(ctx, expectedScheme) 42 if err != nil { 43 return nil, err 44 } 45 if token != expectedToken { 46 return nil, grpc.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token") 47 } 48 return context.WithValue(ctx, authedMarker, "marker_exists"), nil 49 } 50 } 51 52 func assertAuthMarkerExists(t *testing.T, ctx context.Context) { 53 assert.Equal(t, "marker_exists", ctx.Value(authedMarker).(string), "auth marker from buildDummyAuthFunction must be passed around") 54 } 55 56 type assertingPingService struct { 57 pb_testproto.TestServiceServer 58 T *testing.T 59 } 60 61 func (s *assertingPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) { 62 assertAuthMarkerExists(s.T, ctx) 63 return s.TestServiceServer.PingError(ctx, ping) 64 } 65 66 func (s *assertingPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error { 67 assertAuthMarkerExists(s.T, stream.Context()) 68 return s.TestServiceServer.PingList(ping, stream) 69 } 70 71 func ctxWithToken(ctx context.Context, scheme string, token string) context.Context { 72 md := metadata.Pairs("authorization", fmt.Sprintf("%s %v", scheme, token)) 73 nCtx := metautils.NiceMD(md).ToOutgoing(ctx) 74 return nCtx 75 } 76 77 func TestAuthTestSuite(t *testing.T) { 78 authFunc := buildDummyAuthFunction("bearer", commonAuthToken) 79 s := &AuthTestSuite{ 80 InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ 81 TestService: &assertingPingService{&grpc_testing.TestPingService{T: t}, t}, 82 ServerOpts: []grpc.ServerOption{ 83 grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)), 84 grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)), 85 }, 86 }, 87 } 88 suite.Run(t, s) 89 } 90 91 type AuthTestSuite struct { 92 *grpc_testing.InterceptorTestSuite 93 } 94 95 func (s *AuthTestSuite) TestUnary_NoAuth() { 96 _, err := s.Client.Ping(s.SimpleCtx(), goodPing) 97 assert.Error(s.T(), err, "there must be an error") 98 assert.Equal(s.T(), codes.Unauthenticated, grpc.Code(err), "must error with unauthenticated") 99 } 100 101 func (s *AuthTestSuite) TestUnary_BadAuth() { 102 _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing) 103 assert.Error(s.T(), err, "there must be an error") 104 assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err), "must error with permission denied") 105 } 106 107 func (s *AuthTestSuite) TestUnary_PassesAuth() { 108 _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", commonAuthToken), goodPing) 109 require.NoError(s.T(), err, "no error must occur") 110 } 111 112 func (s *AuthTestSuite) TestUnary_PassesWithPerRpcCredentials() { 113 grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}} 114 client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds)) 115 _, err := client.Ping(s.SimpleCtx(), goodPing) 116 require.NoError(s.T(), err, "no error must occur") 117 } 118 119 func (s *AuthTestSuite) TestStream_NoAuth() { 120 stream, err := s.Client.PingList(s.SimpleCtx(), goodPing) 121 require.NoError(s.T(), err, "should not fail on establishing the stream") 122 _, err = stream.Recv() 123 assert.Error(s.T(), err, "there must be an error") 124 assert.Equal(s.T(), codes.Unauthenticated, grpc.Code(err), "must error with unauthenticated") 125 } 126 127 func (s *AuthTestSuite) TestStream_BadAuth() { 128 stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing) 129 require.NoError(s.T(), err, "should not fail on establishing the stream") 130 _, err = stream.Recv() 131 assert.Error(s.T(), err, "there must be an error") 132 assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err), "must error with permission denied") 133 } 134 135 func (s *AuthTestSuite) TestStream_PassesAuth() { 136 stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", commonAuthToken), goodPing) 137 require.NoError(s.T(), err, "should not fail on establishing the stream") 138 pong, err := stream.Recv() 139 require.NoError(s.T(), err, "no error must occur") 140 require.NotNil(s.T(), pong, "pong must not be nil") 141 } 142 143 func (s *AuthTestSuite) TestStream_PassesWithPerRpcCredentials() { 144 grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}} 145 client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds)) 146 stream, err := client.PingList(s.SimpleCtx(), goodPing) 147 require.NoError(s.T(), err, "should not fail on establishing the stream") 148 pong, err := stream.Recv() 149 require.NoError(s.T(), err, "no error must occur") 150 require.NotNil(s.T(), pong, "pong must not be nil") 151 } 152 153 type authOverrideTestService struct { 154 pb_testproto.TestServiceServer 155 T *testing.T 156 } 157 158 func (s *authOverrideTestService) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) { 159 assert.NotEmpty(s.T, fullMethodName, "method name of caller is passed around") 160 return buildDummyAuthFunction("bearer", overrideAuthToken)(ctx) 161 } 162 163 func TestAuthOverrideTestSuite(t *testing.T) { 164 authFunc := buildDummyAuthFunction("bearer", commonAuthToken) 165 s := &AuthOverrideTestSuite{ 166 InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{ 167 TestService: &authOverrideTestService{&assertingPingService{&grpc_testing.TestPingService{T: t}, t}, t}, 168 ServerOpts: []grpc.ServerOption{ 169 grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)), 170 grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)), 171 }, 172 }, 173 } 174 suite.Run(t, s) 175 } 176 177 type AuthOverrideTestSuite struct { 178 *grpc_testing.InterceptorTestSuite 179 } 180 181 func (s *AuthOverrideTestSuite) TestUnary_PassesAuth() { 182 _, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", overrideAuthToken), goodPing) 183 require.NoError(s.T(), err, "no error must occur") 184 } 185 186 func (s *AuthOverrideTestSuite) TestStream_PassesAuth() { 187 stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", overrideAuthToken), goodPing) 188 require.NoError(s.T(), err, "should not fail on establishing the stream") 189 pong, err := stream.Recv() 190 require.NoError(s.T(), err, "no error must occur") 191 require.NotNil(s.T(), pong, "pong must not be nil") 192 } 193 194 // fakeOAuth2TokenSource implements a fake oauth2.TokenSource for the purpose of credentials test. 195 type fakeOAuth2TokenSource struct { 196 accessToken string 197 } 198 199 func (ts *fakeOAuth2TokenSource) Token() (*oauth2.Token, error) { 200 t := &oauth2.Token{ 201 AccessToken: ts.accessToken, 202 Expiry: time.Now().Add(1 * time.Minute), 203 TokenType: "bearer", 204 } 205 return t, nil 206 }