gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/go-grpc-middleware/auth/auth_test.go (about)

     1  // Copyright 2016 Michal Witkowski. All Rights Reserved.
     2  // See LICENSE for licensing terms.
     3  
     4  package grpc_auth_test
     5  
     6  import (
     7  	"gitee.com/ks-custle/core-gm/grpc/status"
     8  	"testing"
     9  
    10  	"gitee.com/ks-custle/core-gm/grpc"
    11  	"github.com/stretchr/testify/suite"
    12  
    13  	"fmt"
    14  
    15  	"time"
    16  
    17  	grpcauth "gitee.com/ks-custle/core-gm/go-grpc-middleware/auth"
    18  	grpctesting "gitee.com/ks-custle/core-gm/go-grpc-middleware/testing"
    19  	pbtestproto "gitee.com/ks-custle/core-gm/go-grpc-middleware/testing/testproto"
    20  	"gitee.com/ks-custle/core-gm/go-grpc-middleware/util/metautils"
    21  	"gitee.com/ks-custle/core-gm/grpc/codes"
    22  	"gitee.com/ks-custle/core-gm/grpc/credentials/oauth"
    23  	"gitee.com/ks-custle/core-gm/grpc/metadata"
    24  	"gitee.com/ks-custle/core-gm/net/context"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  	"golang.org/x/oauth2"
    28  )
    29  
    30  var (
    31  	commonAuthToken   = "some_good_token"
    32  	overrideAuthToken = "override_token"
    33  
    34  	authedMarker = "some_context_marker"
    35  	goodPing     = &pbtestproto.PingRequest{Value: "something", SleepTimeMs: 9999}
    36  )
    37  
    38  // TODO(mwitkow): Add auth from metadata client dialer, which requires TLS.
    39  
    40  func buildDummyAuthFunction(expectedScheme string, expectedToken string) func(ctx context.Context) (context.Context, error) {
    41  	return func(ctx context.Context) (context.Context, error) {
    42  		token, err := grpcauth.AuthFromMD(ctx, expectedScheme)
    43  		if err != nil {
    44  			return nil, err
    45  		}
    46  		if token != expectedToken {
    47  			// `grpc.Errorf` is deprecated. use status.Errorf instead.
    48  			//return nil, grpc.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token")
    49  			return nil, status.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token")
    50  		}
    51  		return context.WithValue(ctx, authedMarker, "marker_exists"), nil
    52  	}
    53  }
    54  
    55  func assertAuthMarkerExists(t *testing.T, ctx context.Context) {
    56  	assert.Equal(t, "marker_exists", ctx.Value(authedMarker).(string), "auth marker from buildDummyAuthFunction must be passed around")
    57  }
    58  
    59  type assertingPingService struct {
    60  	pbtestproto.TestServiceServer
    61  	T *testing.T
    62  }
    63  
    64  func (s *assertingPingService) PingError(ctx context.Context, ping *pbtestproto.PingRequest) (*pbtestproto.Empty, error) {
    65  	assertAuthMarkerExists(s.T, ctx)
    66  	return s.TestServiceServer.PingError(ctx, ping)
    67  }
    68  
    69  func (s *assertingPingService) PingList(ping *pbtestproto.PingRequest, stream pbtestproto.TestService_PingListServer) error {
    70  	assertAuthMarkerExists(s.T, stream.Context())
    71  	return s.TestServiceServer.PingList(ping, stream)
    72  }
    73  
    74  func ctxWithToken(ctx context.Context, scheme string, token string) context.Context {
    75  	md := metadata.Pairs("authorization", fmt.Sprintf("%s %v", scheme, token))
    76  	nCtx := metautils.NiceMD(md).ToOutgoing(ctx)
    77  	return nCtx
    78  }
    79  
    80  func TestAuthTestSuite(t *testing.T) {
    81  	authFunc := buildDummyAuthFunction("bearer", commonAuthToken)
    82  	s := &AuthTestSuite{
    83  		InterceptorTestSuite: &grpctesting.InterceptorTestSuite{
    84  			TestService: &assertingPingService{&grpctesting.TestPingService{T: t}, t},
    85  			ServerOpts: []grpc.ServerOption{
    86  				grpc.StreamInterceptor(grpcauth.StreamServerInterceptor(authFunc)),
    87  				grpc.UnaryInterceptor(grpcauth.UnaryServerInterceptor(authFunc)),
    88  			},
    89  		},
    90  	}
    91  	suite.Run(t, s)
    92  }
    93  
    94  type AuthTestSuite struct {
    95  	*grpctesting.InterceptorTestSuite
    96  }
    97  
    98  func (s *AuthTestSuite) TestUnary_NoAuth() {
    99  	_, err := s.Client.Ping(s.SimpleCtx(), goodPing)
   100  	assert.Error(s.T(), err, "there must be an error")
   101  	// `grpc.Code` is deprecated. Use `status.Code` instead.
   102  	//assert.Equal(s.T(), codes.Unauthenticated, grpc.Code(err), "must error with unauthenticated")
   103  	assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated")
   104  }
   105  
   106  func (s *AuthTestSuite) TestUnary_BadAuth() {
   107  	_, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing)
   108  	assert.Error(s.T(), err, "there must be an error")
   109  	// `grpc.Code` is deprecated. Use `status.Code` instead.
   110  	//assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err), "must error with permission denied")
   111  	assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied")
   112  }
   113  
   114  func (s *AuthTestSuite) TestUnary_PassesAuth() {
   115  	_, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", commonAuthToken), goodPing)
   116  	require.NoError(s.T(), err, "no error must occur")
   117  }
   118  
   119  func (s *AuthTestSuite) TestUnary_PassesWithPerRpcCredentials() {
   120  	grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}}
   121  	client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds))
   122  	_, err := client.Ping(s.SimpleCtx(), goodPing)
   123  	require.NoError(s.T(), err, "no error must occur")
   124  }
   125  
   126  func (s *AuthTestSuite) TestStream_NoAuth() {
   127  	stream, err := s.Client.PingList(s.SimpleCtx(), goodPing)
   128  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   129  	_, err = stream.Recv()
   130  	assert.Error(s.T(), err, "there must be an error")
   131  	// `grpc.Code` is deprecated. Use `status.Code` instead.
   132  	//assert.Equal(s.T(), codes.Unauthenticated, grpc.Code(err), "must error with unauthenticated")
   133  	assert.Equal(s.T(), codes.Unauthenticated, status.Code(err), "must error with unauthenticated")
   134  }
   135  
   136  func (s *AuthTestSuite) TestStream_BadAuth() {
   137  	stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing)
   138  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   139  	_, err = stream.Recv()
   140  	assert.Error(s.T(), err, "there must be an error")
   141  	// `grpc.Code` is deprecated. Use `status.Code` instead.
   142  	//assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err), "must error with permission denied")
   143  	assert.Equal(s.T(), codes.PermissionDenied, status.Code(err), "must error with permission denied")
   144  }
   145  
   146  func (s *AuthTestSuite) TestStream_PassesAuth() {
   147  	stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", commonAuthToken), goodPing)
   148  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   149  	pong, err := stream.Recv()
   150  	require.NoError(s.T(), err, "no error must occur")
   151  	require.NotNil(s.T(), pong, "pong must not be nil")
   152  }
   153  
   154  func (s *AuthTestSuite) TestStream_PassesWithPerRpcCredentials() {
   155  	grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}}
   156  	client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds))
   157  	stream, err := client.PingList(s.SimpleCtx(), goodPing)
   158  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   159  	pong, err := stream.Recv()
   160  	require.NoError(s.T(), err, "no error must occur")
   161  	require.NotNil(s.T(), pong, "pong must not be nil")
   162  }
   163  
   164  type authOverrideTestService struct {
   165  	pbtestproto.TestServiceServer
   166  	T *testing.T
   167  }
   168  
   169  func (s *authOverrideTestService) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) {
   170  	assert.NotEmpty(s.T, fullMethodName, "method name of caller is passed around")
   171  	return buildDummyAuthFunction("bearer", overrideAuthToken)(ctx)
   172  }
   173  
   174  func TestAuthOverrideTestSuite(t *testing.T) {
   175  	authFunc := buildDummyAuthFunction("bearer", commonAuthToken)
   176  	s := &AuthOverrideTestSuite{
   177  		InterceptorTestSuite: &grpctesting.InterceptorTestSuite{
   178  			TestService: &authOverrideTestService{&assertingPingService{&grpctesting.TestPingService{T: t}, t}, t},
   179  			ServerOpts: []grpc.ServerOption{
   180  				grpc.StreamInterceptor(grpcauth.StreamServerInterceptor(authFunc)),
   181  				grpc.UnaryInterceptor(grpcauth.UnaryServerInterceptor(authFunc)),
   182  			},
   183  		},
   184  	}
   185  	suite.Run(t, s)
   186  }
   187  
   188  type AuthOverrideTestSuite struct {
   189  	*grpctesting.InterceptorTestSuite
   190  }
   191  
   192  func (s *AuthOverrideTestSuite) TestUnary_PassesAuth() {
   193  	_, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", overrideAuthToken), goodPing)
   194  	require.NoError(s.T(), err, "no error must occur")
   195  }
   196  
   197  func (s *AuthOverrideTestSuite) TestStream_PassesAuth() {
   198  	stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", overrideAuthToken), goodPing)
   199  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   200  	pong, err := stream.Recv()
   201  	require.NoError(s.T(), err, "no error must occur")
   202  	require.NotNil(s.T(), pong, "pong must not be nil")
   203  }
   204  
   205  // fakeOAuth2TokenSource implements a fake oauth2.TokenSource for the purpose of credentials test.
   206  type fakeOAuth2TokenSource struct {
   207  	accessToken string
   208  }
   209  
   210  func (ts *fakeOAuth2TokenSource) Token() (*oauth2.Token, error) {
   211  	t := &oauth2.Token{
   212  		AccessToken: ts.accessToken,
   213  		Expiry:      time.Now().Add(1 * time.Minute),
   214  		TokenType:   "bearer",
   215  	}
   216  	return t, nil
   217  }