github.com/weaviate/weaviate@v1.24.6/usecases/backup/auth_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package backup 13 14 import ( 15 "context" 16 "errors" 17 "fmt" 18 "reflect" 19 "testing" 20 21 "github.com/sirupsen/logrus/hooks/test" 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 "github.com/weaviate/weaviate/entities/models" 25 ) 26 27 // A component-test like test suite that makes sure that every available UC is 28 // potentially protected with the Authorization plugin 29 30 func Test_Authorization(t *testing.T) { 31 req := &BackupRequest{ID: "123", Backend: "s3"} 32 type testCase struct { 33 methodName string 34 additionalArgs []interface{} 35 expectedVerb string 36 expectedResource string 37 } 38 39 tests := []testCase{ 40 { 41 methodName: "Backup", 42 additionalArgs: []interface{}{req}, 43 expectedVerb: "add", 44 expectedResource: "backups/s3/123", 45 }, 46 { 47 methodName: "BackupStatus", 48 additionalArgs: []interface{}{"s3", "123"}, 49 expectedVerb: "get", 50 expectedResource: "backups/s3/123", 51 }, 52 { 53 methodName: "Restore", 54 additionalArgs: []interface{}{req}, 55 expectedVerb: "restore", 56 expectedResource: "backups/s3/123/restore", 57 }, 58 { 59 methodName: "RestorationStatus", 60 additionalArgs: []interface{}{"s3", "123"}, 61 expectedVerb: "get", 62 expectedResource: "backups/s3/123/restore", 63 }, 64 } 65 66 t.Run("verify that a test for every public method exists", func(t *testing.T) { 67 testedMethods := make([]string, len(tests)) 68 for i, test := range tests { 69 testedMethods[i] = test.methodName 70 } 71 72 for _, method := range allExportedMethods(&Scheduler{}) { 73 switch method { 74 case "OnCommit", "OnAbort", "OnCanCommit", "OnStatus": 75 continue 76 } 77 assert.Contains(t, testedMethods, method) 78 } 79 }) 80 81 t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) { 82 principal := &models.Principal{} 83 logger, _ := test.NewNullLogger() 84 for _, test := range tests { 85 t.Run(test.methodName, func(t *testing.T) { 86 authorizer := &authDenier{} 87 s := NewScheduler(authorizer, nil, nil, nil, nil, logger) 88 require.NotNil(t, s) 89 90 args := append([]interface{}{context.Background(), principal}, test.additionalArgs...) 91 out, _ := callFuncByName(s, test.methodName, args...) 92 93 require.Len(t, authorizer.calls, 1, "authorizer must be called") 94 assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(), 95 "execution must abort with authorizer error") 96 assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource}, 97 authorizer.calls[0], "correct parameters must have been used on authorizer") 98 }) 99 } 100 }) 101 } 102 103 type authorizeCall struct { 104 principal *models.Principal 105 verb string 106 resource string 107 } 108 109 type authDenier struct { 110 calls []authorizeCall 111 } 112 113 func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error { 114 a.calls = append(a.calls, authorizeCall{principal, verb, resource}) 115 return errors.New("just a test fake") 116 } 117 118 // inspired by https://stackoverflow.com/a/33008200 119 func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) { 120 managerValue := reflect.ValueOf(manager) 121 m := managerValue.MethodByName(funcName) 122 if !m.IsValid() { 123 return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName) 124 } 125 in := make([]reflect.Value, len(params)) 126 for i, param := range params { 127 in[i] = reflect.ValueOf(param) 128 } 129 out = m.Call(in) 130 return 131 } 132 133 func allExportedMethods(subject interface{}) []string { 134 var methods []string 135 subjectType := reflect.TypeOf(subject) 136 for i := 0; i < subjectType.NumMethod(); i++ { 137 name := subjectType.Method(i).Name 138 if name[0] >= 'A' && name[0] <= 'Z' { 139 methods = append(methods, name) 140 } 141 } 142 143 return methods 144 }