github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/go-grpc-middleware/auth/auth_test.go (about)

     1  // Copyright 2016 Michal Witkowski. All Rights Reserved.
     2  // See LICENSE for licensing terms.
     3  
     4  package grpc_auth_test
     5  
     6  import (
     7  	"testing"
     8  
     9  	"github.com/hxx258456/ccgo/grpc"
    10  	"github.com/stretchr/testify/suite"
    11  
    12  	"fmt"
    13  
    14  	"time"
    15  
    16  	grpc_auth "github.com/hxx258456/ccgo/go-grpc-middleware/auth"
    17  	grpc_testing "github.com/hxx258456/ccgo/go-grpc-middleware/testing"
    18  	pb_testproto "github.com/hxx258456/ccgo/go-grpc-middleware/testing/testproto"
    19  	"github.com/hxx258456/ccgo/go-grpc-middleware/util/metautils"
    20  	"github.com/hxx258456/ccgo/grpc/codes"
    21  	"github.com/hxx258456/ccgo/grpc/credentials/oauth"
    22  	"github.com/hxx258456/ccgo/grpc/metadata"
    23  	"github.com/hxx258456/ccgo/net/context"
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  	"golang.org/x/oauth2"
    27  )
    28  
    29  var (
    30  	commonAuthToken   = "some_good_token"
    31  	overrideAuthToken = "override_token"
    32  
    33  	authedMarker = "some_context_marker"
    34  	goodPing     = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
    35  )
    36  
    37  // TODO(mwitkow): Add auth from metadata client dialer, which requires TLS.
    38  
    39  func buildDummyAuthFunction(expectedScheme string, expectedToken string) func(ctx context.Context) (context.Context, error) {
    40  	return func(ctx context.Context) (context.Context, error) {
    41  		token, err := grpc_auth.AuthFromMD(ctx, expectedScheme)
    42  		if err != nil {
    43  			return nil, err
    44  		}
    45  		if token != expectedToken {
    46  			return nil, grpc.Errorf(codes.PermissionDenied, "buildDummyAuthFunction bad token")
    47  		}
    48  		return context.WithValue(ctx, authedMarker, "marker_exists"), nil
    49  	}
    50  }
    51  
    52  func assertAuthMarkerExists(t *testing.T, ctx context.Context) {
    53  	assert.Equal(t, "marker_exists", ctx.Value(authedMarker).(string), "auth marker from buildDummyAuthFunction must be passed around")
    54  }
    55  
    56  type assertingPingService struct {
    57  	pb_testproto.TestServiceServer
    58  	T *testing.T
    59  }
    60  
    61  func (s *assertingPingService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) {
    62  	assertAuthMarkerExists(s.T, ctx)
    63  	return s.TestServiceServer.PingError(ctx, ping)
    64  }
    65  
    66  func (s *assertingPingService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
    67  	assertAuthMarkerExists(s.T, stream.Context())
    68  	return s.TestServiceServer.PingList(ping, stream)
    69  }
    70  
    71  func ctxWithToken(ctx context.Context, scheme string, token string) context.Context {
    72  	md := metadata.Pairs("authorization", fmt.Sprintf("%s %v", scheme, token))
    73  	nCtx := metautils.NiceMD(md).ToOutgoing(ctx)
    74  	return nCtx
    75  }
    76  
    77  func TestAuthTestSuite(t *testing.T) {
    78  	authFunc := buildDummyAuthFunction("bearer", commonAuthToken)
    79  	s := &AuthTestSuite{
    80  		InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
    81  			TestService: &assertingPingService{&grpc_testing.TestPingService{T: t}, t},
    82  			ServerOpts: []grpc.ServerOption{
    83  				grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)),
    84  				grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)),
    85  			},
    86  		},
    87  	}
    88  	suite.Run(t, s)
    89  }
    90  
    91  type AuthTestSuite struct {
    92  	*grpc_testing.InterceptorTestSuite
    93  }
    94  
    95  func (s *AuthTestSuite) TestUnary_NoAuth() {
    96  	_, err := s.Client.Ping(s.SimpleCtx(), goodPing)
    97  	assert.Error(s.T(), err, "there must be an error")
    98  	assert.Equal(s.T(), codes.Unauthenticated, grpc.Code(err), "must error with unauthenticated")
    99  }
   100  
   101  func (s *AuthTestSuite) TestUnary_BadAuth() {
   102  	_, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing)
   103  	assert.Error(s.T(), err, "there must be an error")
   104  	assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err), "must error with permission denied")
   105  }
   106  
   107  func (s *AuthTestSuite) TestUnary_PassesAuth() {
   108  	_, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", commonAuthToken), goodPing)
   109  	require.NoError(s.T(), err, "no error must occur")
   110  }
   111  
   112  func (s *AuthTestSuite) TestUnary_PassesWithPerRpcCredentials() {
   113  	grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}}
   114  	client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds))
   115  	_, err := client.Ping(s.SimpleCtx(), goodPing)
   116  	require.NoError(s.T(), err, "no error must occur")
   117  }
   118  
   119  func (s *AuthTestSuite) TestStream_NoAuth() {
   120  	stream, err := s.Client.PingList(s.SimpleCtx(), goodPing)
   121  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   122  	_, err = stream.Recv()
   123  	assert.Error(s.T(), err, "there must be an error")
   124  	assert.Equal(s.T(), codes.Unauthenticated, grpc.Code(err), "must error with unauthenticated")
   125  }
   126  
   127  func (s *AuthTestSuite) TestStream_BadAuth() {
   128  	stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "bearer", "bad_token"), goodPing)
   129  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   130  	_, err = stream.Recv()
   131  	assert.Error(s.T(), err, "there must be an error")
   132  	assert.Equal(s.T(), codes.PermissionDenied, grpc.Code(err), "must error with permission denied")
   133  }
   134  
   135  func (s *AuthTestSuite) TestStream_PassesAuth() {
   136  	stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", commonAuthToken), goodPing)
   137  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   138  	pong, err := stream.Recv()
   139  	require.NoError(s.T(), err, "no error must occur")
   140  	require.NotNil(s.T(), pong, "pong must not be nil")
   141  }
   142  
   143  func (s *AuthTestSuite) TestStream_PassesWithPerRpcCredentials() {
   144  	grpcCreds := oauth.TokenSource{TokenSource: &fakeOAuth2TokenSource{accessToken: commonAuthToken}}
   145  	client := s.NewClient(grpc.WithPerRPCCredentials(grpcCreds))
   146  	stream, err := client.PingList(s.SimpleCtx(), goodPing)
   147  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   148  	pong, err := stream.Recv()
   149  	require.NoError(s.T(), err, "no error must occur")
   150  	require.NotNil(s.T(), pong, "pong must not be nil")
   151  }
   152  
   153  type authOverrideTestService struct {
   154  	pb_testproto.TestServiceServer
   155  	T *testing.T
   156  }
   157  
   158  func (s *authOverrideTestService) AuthFuncOverride(ctx context.Context, fullMethodName string) (context.Context, error) {
   159  	assert.NotEmpty(s.T, fullMethodName, "method name of caller is passed around")
   160  	return buildDummyAuthFunction("bearer", overrideAuthToken)(ctx)
   161  }
   162  
   163  func TestAuthOverrideTestSuite(t *testing.T) {
   164  	authFunc := buildDummyAuthFunction("bearer", commonAuthToken)
   165  	s := &AuthOverrideTestSuite{
   166  		InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
   167  			TestService: &authOverrideTestService{&assertingPingService{&grpc_testing.TestPingService{T: t}, t}, t},
   168  			ServerOpts: []grpc.ServerOption{
   169  				grpc.StreamInterceptor(grpc_auth.StreamServerInterceptor(authFunc)),
   170  				grpc.UnaryInterceptor(grpc_auth.UnaryServerInterceptor(authFunc)),
   171  			},
   172  		},
   173  	}
   174  	suite.Run(t, s)
   175  }
   176  
   177  type AuthOverrideTestSuite struct {
   178  	*grpc_testing.InterceptorTestSuite
   179  }
   180  
   181  func (s *AuthOverrideTestSuite) TestUnary_PassesAuth() {
   182  	_, err := s.Client.Ping(ctxWithToken(s.SimpleCtx(), "bearer", overrideAuthToken), goodPing)
   183  	require.NoError(s.T(), err, "no error must occur")
   184  }
   185  
   186  func (s *AuthOverrideTestSuite) TestStream_PassesAuth() {
   187  	stream, err := s.Client.PingList(ctxWithToken(s.SimpleCtx(), "Bearer", overrideAuthToken), goodPing)
   188  	require.NoError(s.T(), err, "should not fail on establishing the stream")
   189  	pong, err := stream.Recv()
   190  	require.NoError(s.T(), err, "no error must occur")
   191  	require.NotNil(s.T(), pong, "pong must not be nil")
   192  }
   193  
   194  // fakeOAuth2TokenSource implements a fake oauth2.TokenSource for the purpose of credentials test.
   195  type fakeOAuth2TokenSource struct {
   196  	accessToken string
   197  }
   198  
   199  func (ts *fakeOAuth2TokenSource) Token() (*oauth2.Token, error) {
   200  	t := &oauth2.Token{
   201  		AccessToken: ts.accessToken,
   202  		Expiry:      time.Now().Add(1 * time.Minute),
   203  		TokenType:   "bearer",
   204  	}
   205  	return t, nil
   206  }