github.com/weaviate/weaviate@v1.24.6/usecases/classification/authorization_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package classification 13 14 // import ( 15 // "context" 16 // "errors" 17 // "fmt" 18 // "reflect" 19 // "testing" 20 21 // "github.com/go-openapi/strfmt" 22 // "github.com/weaviate/weaviate/entities/models" 23 // "github.com/stretchr/testify/assert" 24 // "github.com/stretchr/testify/require" 25 // ) 26 27 // // A component-test like test suite that makes sure that every available UC is 28 // // potentially protected with the Authorization plugin 29 30 // func Test_Classifier_Authorization(t *testing.T) { 31 32 // type testCase struct { 33 // methodName string 34 // additionalArgs []interface{} 35 // expectedVerb string 36 // expectedResource string 37 // } 38 39 // tests := []testCase{ 40 // testCase{ 41 // methodName: "Get", 42 // additionalArgs: []interface{}{strfmt.UUID("")}, 43 // expectedVerb: "get", 44 // expectedResource: "classifications/*", 45 // }, 46 // testCase{ 47 // methodName: "Schedule", 48 // additionalArgs: []interface{}{models.Classification{}}, 49 // expectedVerb: "create", 50 // expectedResource: "classifications/*", 51 // }, 52 // } 53 54 // t.Run("verify that a test for every public method exists", func(t *testing.T) { 55 // testedMethods := make([]string, len(tests), len(tests)) 56 // for i, test := range tests { 57 // testedMethods[i] = test.methodName 58 // } 59 60 // for _, method := range allExportedMethods(&Classifier{}) { 61 // assert.Contains(t, testedMethods, method) 62 // } 63 // }) 64 65 // t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) { 66 // principal := &models.Principal{} 67 // // logger, _ := test.NewNullLogger() 68 // for _, test := range tests { 69 // authorizer := &authDenier{} 70 // repo := &fakeClassificationRepo{} 71 // vectorRepo := &fakeVectorRepoKNN{} 72 // schemaGetter := &fakeSchemaGetter{} 73 74 // classifier := New(schemaGetter, repo, vectorRepo, authorizer) 75 76 // args := append([]interface{}{context.Background(), principal}, test.additionalArgs...) 77 // out, _ := callFuncByName(classifier, test.methodName, args...) 78 79 // require.Len(t, authorizer.calls, 1, "authorizer must be called") 80 // assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(), 81 // "execution must abort with authorizer error") 82 // assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource}, 83 // authorizer.calls[0], "correct parameters must have been used on authorizer") 84 // } 85 // }) 86 // } 87 88 // type authorizeCall struct { 89 // principal *models.Principal 90 // verb string 91 // resource string 92 // } 93 94 // type authDenier struct { 95 // calls []authorizeCall 96 // } 97 98 // func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error { 99 // a.calls = append(a.calls, authorizeCall{principal, verb, resource}) 100 // return errors.New("just a test fake") 101 // } 102 103 // // inspired by https://stackoverflow.com/a/33008200 104 // func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) { 105 // managerValue := reflect.ValueOf(manager) 106 // m := managerValue.MethodByName(funcName) 107 // if !m.IsValid() { 108 // return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName) 109 // } 110 // in := make([]reflect.Value, len(params)) 111 // for i, param := range params { 112 // in[i] = reflect.ValueOf(param) 113 // } 114 // out = m.Call(in) 115 // return 116 // } 117 118 // func allExportedMethods(subject interface{}) []string { 119 // var methods []string 120 // subjectType := reflect.TypeOf(subject) 121 // for i := 0; i < subjectType.NumMethod(); i++ { 122 // name := subjectType.Method(i).Name 123 // if name[0] >= 'A' && name[0] <= 'Z' { 124 // methods = append(methods, name) 125 // } 126 // } 127 128 // return methods 129 // }