github.com/weaviate/weaviate@v1.24.6/usecases/objects/authorization_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package objects 13 14 import ( 15 "context" 16 "errors" 17 "fmt" 18 "reflect" 19 "testing" 20 21 "github.com/go-openapi/strfmt" 22 "github.com/sirupsen/logrus/hooks/test" 23 "github.com/stretchr/testify/assert" 24 "github.com/stretchr/testify/require" 25 "github.com/weaviate/weaviate/entities/additional" 26 "github.com/weaviate/weaviate/entities/models" 27 "github.com/weaviate/weaviate/usecases/config" 28 ) 29 30 // A component-test like test suite that makes sure that every available UC is 31 // potentially protected with the Authorization plugin 32 33 func Test_Kinds_Authorization(t *testing.T) { 34 type testCase struct { 35 methodName string 36 additionalArgs []interface{} 37 expectedVerb string 38 expectedResource string 39 } 40 41 tests := []testCase{ 42 // single kind 43 { 44 methodName: "AddObject", 45 additionalArgs: []interface{}{(*models.Object)(nil)}, 46 expectedVerb: "create", 47 expectedResource: "objects", 48 }, 49 { 50 methodName: "ValidateObject", 51 additionalArgs: []interface{}{(*models.Object)(nil)}, 52 expectedVerb: "validate", 53 expectedResource: "objects", 54 }, 55 { 56 methodName: "GetObject", 57 additionalArgs: []interface{}{"", strfmt.UUID("foo"), additional.Properties{}}, 58 expectedVerb: "get", 59 expectedResource: "objects/foo", 60 }, 61 { 62 methodName: "DeleteObject", 63 additionalArgs: []interface{}{"class", strfmt.UUID("foo")}, 64 expectedVerb: "delete", 65 expectedResource: "objects/class/foo", 66 }, 67 { // deprecated by the one above 68 methodName: "DeleteObject", 69 additionalArgs: []interface{}{"", strfmt.UUID("foo")}, 70 expectedVerb: "delete", 71 expectedResource: "objects/foo", 72 }, 73 { 74 methodName: "UpdateObject", 75 additionalArgs: []interface{}{"class", strfmt.UUID("foo"), (*models.Object)(nil)}, 76 expectedVerb: "update", 77 expectedResource: "objects/class/foo", 78 }, 79 { // deprecated by the one above 80 methodName: "UpdateObject", 81 additionalArgs: []interface{}{"", strfmt.UUID("foo"), (*models.Object)(nil)}, 82 expectedVerb: "update", 83 expectedResource: "objects/foo", 84 }, 85 { 86 methodName: "MergeObject", 87 additionalArgs: []interface{}{ 88 &models.Object{Class: "class", ID: "foo"}, 89 (*additional.ReplicationProperties)(nil), 90 }, 91 expectedVerb: "update", 92 expectedResource: "objects/class/foo", 93 }, 94 { 95 methodName: "GetObjectsClass", 96 additionalArgs: []interface{}{strfmt.UUID("foo")}, 97 expectedVerb: "get", 98 expectedResource: "objects/foo", 99 }, 100 { 101 methodName: "GetObjectClassFromName", 102 additionalArgs: []interface{}{strfmt.UUID("foo")}, 103 expectedVerb: "get", 104 expectedResource: "objects/foo", 105 }, 106 { 107 methodName: "HeadObject", 108 additionalArgs: []interface{}{"class", strfmt.UUID("foo")}, 109 expectedVerb: "head", 110 expectedResource: "objects/class/foo", 111 }, 112 { // deprecated by the one above 113 methodName: "HeadObject", 114 additionalArgs: []interface{}{"", strfmt.UUID("foo")}, 115 expectedVerb: "head", 116 expectedResource: "objects/foo", 117 }, 118 119 // query objects 120 { 121 methodName: "Query", 122 additionalArgs: []interface{}{new(QueryParams)}, 123 expectedVerb: "list", 124 expectedResource: "objects", 125 }, 126 127 { // list objects is deprecated by query 128 methodName: "GetObjects", 129 additionalArgs: []interface{}{(*int64)(nil), (*int64)(nil), (*string)(nil), (*string)(nil), additional.Properties{}}, 130 expectedVerb: "list", 131 expectedResource: "objects", 132 }, 133 134 // reference on objects 135 { 136 methodName: "AddObjectReference", 137 additionalArgs: []interface{}{AddReferenceInput{Class: "class", ID: strfmt.UUID("foo"), Property: "some prop"}, (*models.SingleRef)(nil)}, 138 expectedVerb: "update", 139 expectedResource: "objects/class/foo", 140 }, 141 { 142 methodName: "DeleteObjectReference", 143 additionalArgs: []interface{}{strfmt.UUID("foo"), "some prop", (*models.SingleRef)(nil)}, 144 expectedVerb: "update", 145 expectedResource: "objects/foo", 146 }, 147 { 148 methodName: "UpdateObjectReferences", 149 additionalArgs: []interface{}{&PutReferenceInput{Class: "class", ID: strfmt.UUID("foo"), Property: "some prop"}}, 150 expectedVerb: "update", 151 expectedResource: "objects/class/foo", 152 }, 153 } 154 155 t.Run("verify that a test for every public method exists", func(t *testing.T) { 156 testedMethods := make([]string, len(tests)) 157 for i, test := range tests { 158 testedMethods[i] = test.methodName 159 } 160 161 for _, method := range allExportedMethods(&Manager{}) { 162 assert.Contains(t, testedMethods, method) 163 } 164 }) 165 166 t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) { 167 principal := &models.Principal{} 168 logger, _ := test.NewNullLogger() 169 for _, test := range tests { 170 if test.methodName != "MergeObject" { 171 continue 172 } 173 t.Run(test.methodName, func(t *testing.T) { 174 schemaManager := &fakeSchemaManager{} 175 locks := &fakeLocks{} 176 cfg := &config.WeaviateConfig{} 177 authorizer := &authDenier{} 178 vectorRepo := &fakeVectorRepo{} 179 manager := NewManager(locks, schemaManager, 180 cfg, logger, authorizer, 181 vectorRepo, getFakeModulesProvider(), nil) 182 183 args := append([]interface{}{context.Background(), principal}, test.additionalArgs...) 184 out, _ := callFuncByName(manager, test.methodName, args...) 185 186 require.Len(t, authorizer.calls, 1, "authorizer must be called") 187 aerr := out[len(out)-1].Interface().(error) 188 if err, ok := aerr.(*Error); !ok || !err.Forbidden() { 189 assert.Equal(t, errors.New("just a test fake"), aerr, 190 "execution must abort with authorizer error") 191 } 192 193 assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource}, 194 authorizer.calls[0], "correct parameters must have been used on authorizer") 195 }) 196 } 197 }) 198 } 199 200 func Test_BatchKinds_Authorization(t *testing.T) { 201 type testCase struct { 202 methodName string 203 additionalArgs []interface{} 204 expectedVerb string 205 expectedResource string 206 } 207 208 tests := []testCase{ 209 { 210 methodName: "AddObjects", 211 additionalArgs: []interface{}{ 212 []*models.Object{}, 213 []*string{}, 214 &additional.ReplicationProperties{}, 215 }, 216 expectedVerb: "create", 217 expectedResource: "batch/objects", 218 }, 219 220 { 221 methodName: "AddReferences", 222 additionalArgs: []interface{}{ 223 []*models.BatchReference{}, 224 &additional.ReplicationProperties{}, 225 }, 226 expectedVerb: "update", 227 expectedResource: "batch/*", 228 }, 229 230 { 231 methodName: "DeleteObjects", 232 additionalArgs: []interface{}{ 233 &models.BatchDeleteMatch{}, 234 (*bool)(nil), 235 (*string)(nil), 236 &additional.ReplicationProperties{}, 237 "", 238 }, 239 expectedVerb: "delete", 240 expectedResource: "batch/objects", 241 }, 242 { 243 methodName: "DeleteObjectsFromGRPC", 244 additionalArgs: []interface{}{ 245 BatchDeleteParams{}, 246 &additional.ReplicationProperties{}, 247 "", 248 }, 249 expectedVerb: "delete", 250 expectedResource: "batch/objects", 251 }, 252 } 253 254 t.Run("verify that a test for every public method exists", func(t *testing.T) { 255 testedMethods := make([]string, len(tests)) 256 for i, test := range tests { 257 testedMethods[i] = test.methodName 258 } 259 260 for _, method := range allExportedMethods(&BatchManager{}) { 261 assert.Contains(t, testedMethods, method) 262 } 263 }) 264 265 t.Run("verify the tested methods require correct permissions from the authorizer", func(t *testing.T) { 266 principal := &models.Principal{} 267 logger, _ := test.NewNullLogger() 268 for _, test := range tests { 269 schemaManager := &fakeSchemaManager{} 270 locks := &fakeLocks{} 271 cfg := &config.WeaviateConfig{} 272 authorizer := &authDenier{} 273 vectorRepo := &fakeVectorRepo{} 274 modulesProvider := getFakeModulesProvider() 275 manager := NewBatchManager(vectorRepo, modulesProvider, locks, schemaManager, cfg, logger, authorizer, nil) 276 277 args := append([]interface{}{context.Background(), principal}, test.additionalArgs...) 278 out, _ := callFuncByName(manager, test.methodName, args...) 279 280 require.Len(t, authorizer.calls, 1, "authorizer must be called") 281 assert.Equal(t, errors.New("just a test fake"), out[len(out)-1].Interface(), 282 "execution must abort with authorizer error") 283 assert.Equal(t, authorizeCall{principal, test.expectedVerb, test.expectedResource}, 284 authorizer.calls[0], "correct parameters must have been used on authorizer") 285 } 286 }) 287 } 288 289 type authorizeCall struct { 290 principal *models.Principal 291 verb string 292 resource string 293 } 294 295 type authDenier struct { 296 calls []authorizeCall 297 } 298 299 func (a *authDenier) Authorize(principal *models.Principal, verb, resource string) error { 300 a.calls = append(a.calls, authorizeCall{principal, verb, resource}) 301 return errors.New("just a test fake") 302 } 303 304 // inspired by https://stackoverflow.com/a/33008200 305 func callFuncByName(manager interface{}, funcName string, params ...interface{}) (out []reflect.Value, err error) { 306 managerValue := reflect.ValueOf(manager) 307 m := managerValue.MethodByName(funcName) 308 if !m.IsValid() { 309 return make([]reflect.Value, 0), fmt.Errorf("Method not found \"%s\"", funcName) 310 } 311 in := make([]reflect.Value, len(params)) 312 for i, param := range params { 313 in[i] = reflect.ValueOf(param) 314 } 315 out = m.Call(in) 316 return 317 } 318 319 func allExportedMethods(subject interface{}) []string { 320 var methods []string 321 subjectType := reflect.TypeOf(subject) 322 for i := 0; i < subjectType.NumMethod(); i++ { 323 name := subjectType.Method(i).Name 324 if name[0] >= 'A' && name[0] <= 'Z' { 325 methods = append(methods, name) 326 } 327 } 328 329 return methods 330 }