github.com/weaviate/weaviate@v1.24.6/usecases/traverser/authorization_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package traverser 13 14 import ( 15 "context" 16 "errors" 17 "fmt" 18 "reflect" 19 "testing" 20 21 "github.com/sirupsen/logrus/hooks/test" 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 "github.com/weaviate/weaviate/entities/aggregation" 25 "github.com/weaviate/weaviate/entities/dto" 26 "github.com/weaviate/weaviate/entities/models" 27 "github.com/weaviate/weaviate/usecases/config" 28 ) 29 30 // A component-test like test suite that makes sure that every available UC is 31 // potentially protected with the Authorization plugin 32 33 func Test_Traverser_Authorization(t *testing.T) { 34 type testCase struct { 35 methodName string 36 additionalArgs []interface{} 37 expectedVerb string 38 expectedResource string 39 } 40 41 tests := []testCase{ 42 { 43 methodName: "GetClass", 44 additionalArgs: []interface{}{dto.GetParams{}}, 45 expectedVerb: "get", 46 expectedResource: "traversal/*", 47 }, 48 49 { 50 methodName: "Aggregate", 51 additionalArgs: []interface{}{&aggregation.Params{}}, 52 expectedVerb: "get", 53 expectedResource: "traversal/*", 54 }, 55 56 { 57 methodName: "Explore", 58 additionalArgs: []interface{}{ExploreParams{}}, 59 expectedVerb: "get", 60 expectedResource: "traversal/*", 61 }, 62 } 63 64 t.Run("verify that a test for every public method exists", func(t *testing.T) { 65 testedMethods := make([]string, len(tests)) 66 for i, test := range tests { 67 testedMethods[i] = test.methodName 68 } 69 70 for _, method := range allExportedMethods(&Traverser{}) { 71 assert.Contains(t, testedMethods, method) 72 } 73 }) 74 75 t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) { 76 principal := &models.Principal{} 77 logger, _ := test.NewNullLogger() 78 for _, test := range tests { 79 locks := &fakeLocks{} 80 authorizer := &authDenier{} 81 vectorRepo := &fakeVectorRepo{} 82 explorer := &fakeExplorer{} 83 schemaGetter := &fakeSchemaGetter{} 84 85 manager := NewTraverser(&config.WeaviateConfig{}, locks, logger, authorizer, 86 vectorRepo, explorer, schemaGetter, nil, nil, -1) 87 88 args := append([]interface{}{context.Background(), principal}, test.additionalArgs...) 89 out, _ := callFuncByName(manager, test.methodName, args...) 90 91 require.Len(t, authorizer.calls, 1, "authorizer must be called") 92 assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(), 93 "execution must abort with authorizer error") 94 assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource}, 95 authorizer.calls[0], "correct parameters must have been used on authorizer") 96 } 97 }) 98 } 99 100 type authorizeCall struct { 101 principal *models.Principal 102 verb string 103 resource string 104 } 105 106 type authDenier struct { 107 calls []authorizeCall 108 } 109 110 func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error { 111 a.calls = append(a.calls, authorizeCall{principal, verb, resource}) 112 return errors.New("just a test fake") 113 } 114 115 // inspired by https://stackoverflow.com/a/33008200 116 func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) { 117 managerValue := reflect.ValueOf(manager) 118 m := managerValue.MethodByName(funcName) 119 if !m.IsValid() { 120 return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName) 121 } 122 in := make([]reflect.Value, len(params)) 123 for i, param := range params { 124 in[i] = reflect.ValueOf(param) 125 } 126 out = m.Call(in) 127 return 128 } 129 130 func allExportedMethods(subject interface{}) []string { 131 var methods []string 132 subjectType := reflect.TypeOf(subject) 133 for i := 0; i < subjectType.NumMethod(); i++ { 134 name := subjectType.Method(i).Name 135 if name[0] >= 'A' && name[0] <= 'Z' { 136 methods = append(methods, name) 137 } 138 } 139 140 return methods 141 }