github.com/kisexp/xdchain@v0.0.0-20211206025815-490d6b732aa7/rpc/security_test.go (about)

     1  package rpc
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net/http"
     7  	"os"
     8  	"strconv"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/kisexp/xdchain/core/types"
    13  	"github.com/golang/protobuf/ptypes"
    14  	"github.com/jpmorganchase/quorum-security-plugin-sdk-go/proto"
    15  	testifyassert "github.com/stretchr/testify/assert"
    16  )
    17  
    18  func TestVerifyAccess_whenNotMatch(t *testing.T) {
    19  	assert := testifyassert.New(t)
    20  
    21  	assert.Error(verifyAccess("xyz", "abc", []*proto.GrantedAuthority{
    22  		{
    23  			Service: "bar",
    24  			Method:  "foo",
    25  		},
    26  	}))
    27  }
    28  
    29  func TestVerifyAccess_whenEmpty(t *testing.T) {
    30  	assert := testifyassert.New(t)
    31  
    32  	assert.Error(verifyAccess("xyz", "abc", nil))
    33  }
    34  
    35  func TestVerifyAccess_whenExactMatch(t *testing.T) {
    36  	assert := testifyassert.New(t)
    37  
    38  	assert.NoError(verifyAccess("bar", "foo", []*proto.GrantedAuthority{
    39  		{
    40  			Service: "xyz",
    41  			Method:  "abc",
    42  		},
    43  		{
    44  			Service: "bar",
    45  			Method:  "foo",
    46  		},
    47  	}))
    48  }
    49  
    50  func TestVerifyAccess_whenWildcardServiceMatch(t *testing.T) {
    51  	assert := testifyassert.New(t)
    52  
    53  	assert.NoError(verifyAccess("bar", "foo", []*proto.GrantedAuthority{
    54  		{
    55  			Service: "xyz",
    56  			Method:  "abc",
    57  		},
    58  		{
    59  			Service: "*",
    60  			Method:  "foo",
    61  		},
    62  	}))
    63  }
    64  
    65  func TestVerifyAccess_whenWildcardMethodMatch(t *testing.T) {
    66  	assert := testifyassert.New(t)
    67  
    68  	assert.NoError(verifyAccess("bar", "foo", []*proto.GrantedAuthority{
    69  		{
    70  			Service: "xyz",
    71  			Method:  "abc",
    72  		},
    73  		{
    74  			Service: "bar",
    75  			Method:  "*",
    76  		},
    77  	}))
    78  }
    79  
    80  func TestVerifyAccess_whenWildcardMatch(t *testing.T) {
    81  	assert := testifyassert.New(t)
    82  
    83  	assert.NoError(verifyAccess("bar", "foo", []*proto.GrantedAuthority{
    84  		{
    85  			Service: "*",
    86  			Method:  "*",
    87  		},
    88  	}))
    89  }
    90  
    91  func TestVerifyExpiration_whenTypical(t *testing.T) {
    92  	assert := testifyassert.New(t)
    93  	expiredAt, _ := ptypes.TimestampProto(time.Now().Add(1 * time.Minute))
    94  	assert.NoError(verifyExpiration(&proto.PreAuthenticatedAuthenticationToken{
    95  		ExpiredAt: expiredAt,
    96  	}))
    97  }
    98  
    99  func TestVerifyExpiration_whenExpired(t *testing.T) {
   100  	assert := testifyassert.New(t)
   101  	expiredAt, _ := ptypes.TimestampProto(time.Now().Add(-1 * time.Minute))
   102  	assert.Error(verifyExpiration(&proto.PreAuthenticatedAuthenticationToken{
   103  		ExpiredAt: expiredAt,
   104  	}))
   105  }
   106  
   107  func TestExtractToken_whenTypical(t *testing.T) {
   108  	assert := testifyassert.New(t)
   109  	req, _ := http.NewRequest("POST", "", nil)
   110  	arbitraryValue := "xyz"
   111  	req.Header.Set(HttpAuthorizationHeader, arbitraryValue)
   112  
   113  	token, ok := extractToken(req)
   114  
   115  	assert.True(ok)
   116  	assert.Equal(arbitraryValue, token)
   117  }
   118  
   119  func TestExtractToken_whenEmpty(t *testing.T) {
   120  	assert := testifyassert.New(t)
   121  	req, _ := http.NewRequest("POST", "", nil)
   122  
   123  	_, ok := extractToken(req)
   124  
   125  	assert.False(ok)
   126  }
   127  
   128  func TestSecureCall_whenThereIsAuthenticationError(t *testing.T) {
   129  	assert := testifyassert.New(t)
   130  	arbitraryError := errors.New("arbitrary error")
   131  	stubSecurityContextResolver := newStubSecurityContextResolver([]struct{ k, v interface{} }{
   132  		{ctxAuthenticationError, arbitraryError},
   133  	})
   134  
   135  	_, err := SecureCall(stubSecurityContextResolver, "")
   136  
   137  	assert.EqualError(err, arbitraryError.Error())
   138  }
   139  
   140  func TestSecureCall_whenTokenExpired(t *testing.T) {
   141  	assert := testifyassert.New(t)
   142  	expiredAt, _ := ptypes.TimestampProto(time.Now().Add(-1 * time.Hour))
   143  	stubSecurityContextResolver := newStubSecurityContextResolver([]struct{ k, v interface{} }{
   144  		{ctxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
   145  			ExpiredAt: expiredAt,
   146  		}},
   147  	})
   148  
   149  	_, err := SecureCall(stubSecurityContextResolver, "")
   150  
   151  	assert.EqualError(err, "token expired")
   152  }
   153  
   154  func TestSecureCall_whenTypical(t *testing.T) {
   155  	assert := testifyassert.New(t)
   156  	expiredAt, _ := ptypes.TimestampProto(time.Now().Add(1 * time.Hour))
   157  	stubSecurityContextResolver := newStubSecurityContextResolver([]struct{ k, v interface{} }{
   158  		{ctxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
   159  			ExpiredAt: expiredAt,
   160  			Authorities: []*proto.GrantedAuthority{
   161  				{
   162  					Service: "eth",
   163  					Method:  "blockNumber",
   164  				},
   165  			},
   166  		}},
   167  	})
   168  
   169  	_, err := SecureCall(stubSecurityContextResolver, "eth_blockNumber")
   170  
   171  	assert.NoError(err)
   172  }
   173  
   174  func TestSecureCall_whenAccessDenied(t *testing.T) {
   175  	assert := testifyassert.New(t)
   176  	expiredAt, _ := ptypes.TimestampProto(time.Now().Add(1 * time.Hour))
   177  	stubSecurityContextResolver := newStubSecurityContextResolver([]struct{ k, v interface{} }{
   178  		{ctxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
   179  			ExpiredAt: expiredAt,
   180  			Authorities: []*proto.GrantedAuthority{
   181  				{
   182  					Service: "eth",
   183  					Method:  "blockNumber",
   184  				},
   185  			},
   186  		}},
   187  	})
   188  
   189  	_, err := SecureCall(stubSecurityContextResolver, "eth_someMethod")
   190  
   191  	assert.EqualError(err, "eth_someMethod - access denied")
   192  }
   193  
   194  func TestSecureCall_whenMethodInJSONMessageIsNotSupported(t *testing.T) {
   195  	assert := testifyassert.New(t)
   196  	expiredAt, _ := ptypes.TimestampProto(time.Now().Add(1 * time.Hour))
   197  	stubSecurityContextResolver := newStubSecurityContextResolver([]struct{ k, v interface{} }{
   198  		{ctxPreauthenticatedToken, &proto.PreAuthenticatedAuthenticationToken{
   199  			ExpiredAt: expiredAt,
   200  		}},
   201  	})
   202  
   203  	_, err := SecureCall(stubSecurityContextResolver, "arbitrary method")
   204  
   205  	assert.NoError(err)
   206  }
   207  
   208  type stubSecurityContextResolver struct {
   209  	ctx SecurityContext
   210  }
   211  
   212  func newStubSecurityContextResolver(ctx []struct{ k, v interface{} }) *stubSecurityContextResolver {
   213  	sc := SecurityContext(context.Background())
   214  	for _, kv := range ctx {
   215  		sc = context.WithValue(sc, kv.k, kv.v)
   216  	}
   217  	return &stubSecurityContextResolver{sc}
   218  }
   219  
   220  func (sr *stubSecurityContextResolver) Resolve() SecurityContext {
   221  	return sr.ctx
   222  }
   223  
   224  func TestResolvePSIProvider_whenTypicalEndpoints(t *testing.T) {
   225  	testCases := []struct {
   226  		endpoint    string
   227  		expectedPSI types.PrivateStateIdentifier
   228  	}{
   229  		{
   230  			endpoint:    "http://aritraryhost?PSI=PS1",
   231  			expectedPSI: types.ToPrivateStateIdentifier("PS1"),
   232  		},
   233  		{
   234  			endpoint:    "https://aritraryhost?PSI=PS2",
   235  			expectedPSI: types.ToPrivateStateIdentifier("PS2"),
   236  		},
   237  		{
   238  			endpoint:    "ws://aritraryhost?PSI=PS3",
   239  			expectedPSI: types.ToPrivateStateIdentifier("PS3"),
   240  		},
   241  		{
   242  			endpoint:    "wss://aritraryhost?PSI=PS4",
   243  			expectedPSI: types.ToPrivateStateIdentifier("PS4"),
   244  		},
   245  	}
   246  	for _, tc := range testCases {
   247  		actualCtx := resolvePSIProvider(context.Background(), tc.endpoint)
   248  
   249  		f := PSIProviderFromContext(actualCtx)
   250  		testifyassert.NotNil(t, f)
   251  		actualPSI, err := f(context.Background())
   252  		testifyassert.NoError(t, err)
   253  		testifyassert.Equal(t, tc.expectedPSI, actualPSI)
   254  	}
   255  }
   256  
   257  func TestResolvePSIProvider_whenEnvVariableTakesPrecedence(t *testing.T) {
   258  	_ = os.Setenv(EnvVarPrivateStateIdentifier, "ENV_PS1")
   259  	defer func() { _ = os.Unsetenv(EnvVarPrivateStateIdentifier) }()
   260  
   261  	endpoint := "http://aritraryhost?PSI=PS1"
   262  	actualCtx := resolvePSIProvider(context.Background(), endpoint)
   263  
   264  	f := PSIProviderFromContext(actualCtx)
   265  	testifyassert.NotNil(t, f)
   266  	actualPSI, err := f(context.Background())
   267  	testifyassert.NoError(t, err)
   268  	testifyassert.Equal(t, types.ToPrivateStateIdentifier("ENV_PS1"), actualPSI)
   269  }
   270  
   271  func TestResolvePSIProvider_whenNoPSI(t *testing.T) {
   272  	endpoint := "data/geth.ipc"
   273  	actualCtx := resolvePSIProvider(context.Background(), endpoint)
   274  
   275  	testifyassert.Nil(t, PSIProviderFromContext(actualCtx))
   276  }
   277  
   278  func TestEncodePSI_whenTypical(t *testing.T) {
   279  	actual := encodePSI(strconv.AppendUint(nil, 32, 10), "ARBITRARY")
   280  
   281  	testifyassert.Equal(t, "\"ARBITRARY/32\"", string(actual))
   282  }
   283  
   284  func TestEncodePSI_whenNoPSI(t *testing.T) {
   285  	actual := encodePSI(strconv.AppendUint(nil, 32, 10), "")
   286  
   287  	testifyassert.Equal(t, "32", string(actual))
   288  }
   289  
   290  func TestDecodePSI_whenTypical(t *testing.T) {
   291  	input := "\"ARBITRARY/1\""
   292  
   293  	psi := decodePSI([]byte(input))
   294  
   295  	testifyassert.Equal(t, types.PrivateStateIdentifier("ARBITRARY"), psi)
   296  }
   297  
   298  func TestDecodePSI_whenNoPSI(t *testing.T) {
   299  	inputs := []string{
   300  		"1",
   301  		"\"1",
   302  		"1\"",
   303  		"\"xyz\"",
   304  	}
   305  	for _, input := range inputs {
   306  		psi := decodePSI([]byte(input))
   307  
   308  		testifyassert.Equal(t, types.DefaultPrivateStateIdentifier, psi, "input: %s", input)
   309  	}
   310  }