github.com/weaviate/weaviate@v1.24.6/usecases/schema/authorization_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package schema
    13  
    14  import (
    15  	"context"
    16  	"errors"
    17  	"fmt"
    18  	"reflect"
    19  	"testing"
    20  
    21  	"github.com/sirupsen/logrus/hooks/test"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  	"github.com/weaviate/weaviate/entities/models"
    25  	"github.com/weaviate/weaviate/usecases/config"
    26  )
    27  
// A component-test-like test suite that makes sure that every available use
// case (UC) can be protected with the Authorization plugin
    30  
    31  func Test_Schema_Authorization(t *testing.T) {
    32  	type testCase struct {
    33  		methodName       string
    34  		additionalArgs   []interface{}
    35  		expectedVerb     string
    36  		expectedResource string
    37  	}
    38  
    39  	tests := []testCase{
    40  		{
    41  			methodName:       "GetSchema",
    42  			expectedVerb:     "list",
    43  			expectedResource: "schema/*",
    44  		},
    45  		{
    46  			methodName:       "GetClass",
    47  			additionalArgs:   []interface{}{"classname"},
    48  			expectedVerb:     "list",
    49  			expectedResource: "schema/*",
    50  		},
    51  		{
    52  			methodName:       "GetShardsStatus",
    53  			additionalArgs:   []interface{}{"className", "tenant"},
    54  			expectedVerb:     "list",
    55  			expectedResource: "schema/className/shards",
    56  		},
    57  		{
    58  			methodName:       "AddClass",
    59  			additionalArgs:   []interface{}{&models.Class{}},
    60  			expectedVerb:     "create",
    61  			expectedResource: "schema/objects",
    62  		},
    63  		{
    64  			methodName:       "UpdateClass",
    65  			additionalArgs:   []interface{}{"somename", &models.Class{}},
    66  			expectedVerb:     "update",
    67  			expectedResource: "schema/objects",
    68  		},
    69  		{
    70  			methodName:       "DeleteClass",
    71  			additionalArgs:   []interface{}{"somename"},
    72  			expectedVerb:     "delete",
    73  			expectedResource: "schema/objects",
    74  		},
    75  		{
    76  			methodName:       "AddClassProperty",
    77  			additionalArgs:   []interface{}{"somename", &models.Property{}},
    78  			expectedVerb:     "update",
    79  			expectedResource: "schema/objects",
    80  		},
    81  		{
    82  			methodName:       "MergeClassObjectProperty",
    83  			additionalArgs:   []interface{}{"somename", &models.Property{}},
    84  			expectedVerb:     "update",
    85  			expectedResource: "schema/objects",
    86  		},
    87  		{
    88  			methodName:       "DeleteClassProperty",
    89  			additionalArgs:   []interface{}{"somename", "someprop"},
    90  			expectedVerb:     "update",
    91  			expectedResource: "schema/objects",
    92  		},
    93  		{
    94  			methodName:       "UpdateShardStatus",
    95  			additionalArgs:   []interface{}{"className", "shardName", "targetStatus"},
    96  			expectedVerb:     "update",
    97  			expectedResource: "schema/className/shards/shardName",
    98  		},
    99  		{
   100  			methodName:       "AddTenants",
   101  			additionalArgs:   []interface{}{"className", []*models.Tenant{{Name: "P1"}}},
   102  			expectedVerb:     "update",
   103  			expectedResource: tenantsPath,
   104  		},
   105  		{
   106  			methodName: "UpdateTenants",
   107  			additionalArgs: []interface{}{"className", []*models.Tenant{
   108  				{Name: "P1", ActivityStatus: models.TenantActivityStatusHOT},
   109  			}},
   110  			expectedVerb:     "update",
   111  			expectedResource: tenantsPath,
   112  		},
   113  		{
   114  			methodName:       "DeleteTenants",
   115  			additionalArgs:   []interface{}{"className", []string{"P1"}},
   116  			expectedVerb:     "delete",
   117  			expectedResource: tenantsPath,
   118  		},
   119  		{
   120  			methodName:       "GetTenants",
   121  			additionalArgs:   []interface{}{"className"},
   122  			expectedVerb:     "get",
   123  			expectedResource: tenantsPath,
   124  		},
   125  	}
   126  
   127  	t.Run("verify that a test for every public method exists", func(t *testing.T) {
   128  		testedMethods := make([]string, len(tests))
   129  		for i, test := range tests {
   130  			testedMethods[i] = test.methodName
   131  		}
   132  
   133  		for _, method := range allExportedMethods(&Manager{}) {
   134  			switch method {
   135  			case "RegisterSchemaUpdateCallback",
   136  				"UpdateMeta", "GetSchemaSkipAuth", "IndexedInverted", "RLock", "RUnlock", "Lock", "Unlock",
   137  				"TryLock", "RLocker", "TryRLock", // introduced by sync.Mutex in go 1.18
   138  				"Nodes", "NodeName", "ClusterHealthScore", "ClusterStatus", "ResolveParentNodes",
   139  				"CopyShardingState", "TxManager", "RestoreClass",
   140  				"ShardOwner", "TenantShard", "ShardFromUUID", "LockGuard", "RLockGuard", "ShardReplicas",
   141  				"StartServing", "Shutdown": // internal methods to indicate readiness state
   142  				// don't require auth on methods which are exported because other
   143  				// packages need to call them for maintenance and other regular jobs,
   144  				// but aren't user facing
   145  				continue
   146  			}
   147  			assert.Contains(t, testedMethods, method)
   148  		}
   149  	})
   150  
   151  	t.Run("verify the tested methods require correct permissions from the Authorizer", func(t *testing.T) {
   152  		principal := &models.Principal{}
   153  		logger, _ := test.NewNullLogger()
   154  		for _, test := range tests {
   155  			t.Run(test.methodName, func(t *testing.T) {
   156  				authorizer := &authDenier{}
   157  				manager, err := NewManager(&NilMigrator{}, newFakeRepo(),
   158  					logger, authorizer, config.Config{},
   159  					dummyParseVectorConfig, &fakeVectorizerValidator{},
   160  					dummyValidateInvertedConfig, &fakeModuleConfig{},
   161  					&fakeClusterState{hosts: []string{"node1"}}, &fakeTxClient{},
   162  					&fakeTxPersistence{}, &fakeScaleOutManager{})
   163  				require.Nil(t, err)
   164  
   165  				var args []interface{}
   166  				if test.methodName == "GetSchema" {
   167  					// no context on this method
   168  					args = append([]interface{}{principal}, test.additionalArgs...)
   169  				} else {
   170  					args = append([]interface{}{context.Background(), principal}, test.additionalArgs...)
   171  				}
   172  				out, _ := callFuncByName(manager, test.methodName, args...)
   173  
   174  				require.Len(t, authorizer.calls, 1, "Authorizer must be called")
   175  				assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(),
   176  					"execution must abort with Authorizer error")
   177  				assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource},
   178  					authorizer.calls[0], "correct parameters must have been used on Authorizer")
   179  			})
   180  		}
   181  	})
   182  }
   183  
   184  type authorizeCall struct {
   185  	principal *models.Principal
   186  	verb      string
   187  	resource  string
   188  }
   189  
   190  type authDenier struct {
   191  	calls []authorizeCall
   192  }
   193  
   194  func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error {
   195  	a.calls = append(a.calls, authorizeCall{principal, verb, resource})
   196  	return errors.New("just a test fake")
   197  }
   198  
   199  // inspired by https://stackoverflow.com/a/33008200
   200  func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) {
   201  	managerValue := reflect.ValueOf(manager)
   202  	m := managerValue.MethodByName(funcName)
   203  	if !m.IsValid() {
   204  		return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName)
   205  	}
   206  	in := make([]reflect.Value, len(params))
   207  	for i, param := range params {
   208  		in[i] = reflect.ValueOf(param)
   209  	}
   210  	out = m.Call(in)
   211  	return
   212  }
   213  
   214  func allExportedMethods(subject interface{}) []string {
   215  	var methods []string
   216  	subjectType := reflect.TypeOf(subject)
   217  	for i := 0; i < subjectType.NumMethod(); i++ {
   218  		name := subjectType.Method(i).Name
   219  		if name[0] >= 'A' && name[0] <= 'Z' {
   220  			methods = append(methods, name)
   221  		}
   222  	}
   223  
   224  	return methods
   225  }